/*
 * Copyright (c) 2022-2025, Gregory Bertilson <gregory@ladybird.org>
 *
 * SPDX-License-Identifier: BSD-2-Clause
 */

#pragma once

#include <AK/Function.h>
#include <AK/Optional.h>
#include <LibCore/EventLoop.h>
#include <LibCore/File.h>
#include <LibMedia/Audio/ChannelMap.h>
#include <LibMedia/Containers/Matroska/MatroskaDemuxer.h>
#include <LibMedia/Containers/Matroska/Reader.h>
#include <LibMedia/Demuxer.h>
#include <LibMedia/FFmpeg/FFmpegDemuxer.h>
#include <LibMedia/PipelineStatus.h>
#include <LibMedia/Producers/DecodedAudioProducer.h>
#include <LibMedia/VideoDecoder.h>
#include <LibMedia/VideoFrame.h>
#include <LibTest/TestCase.h>

// Decodes the first video track of the Matroska file at `path` with the decoder returned by
// `create_decoder`, and checks that:
//  - the file contains a video track,
//  - every block read from the track carries a timestamp,
//  - frames come out of the decoder with non-decreasing timestamps,
//  - the number of coded frames fed to the decoder equals `expected_frame_count`.
// Problems are reported through LibTest (FAIL / EXPECT*) so that a broken file fails the test
// rather than aborting the whole test binary.
template<typename T>
static inline void decode_video(StringView path, size_t expected_frame_count, T create_decoder)
{
    auto file = MUST(Core::File::open(path, Core::File::OpenMode::Read));
    auto stream = Media::IncrementallyPopulatedStream::create_from_buffer(MUST(file->read_until_eof()));
    auto matroska_reader = MUST(Media::Matroska::Reader::from_stream(stream->create_cursor()));

    RefPtr<Media::Matroska::TrackEntry const> video_track_entry;
    MUST(matroska_reader.for_each_track_of_type(Media::Matroska::TrackEntry::TrackType::Video, [&](Media::Matroska::TrackEntry const& track_entry) -> Media::DecoderErrorOr<IterationDecision> {
        // Only the first video track is exercised.
        video_track_entry = track_entry;
        return IterationDecision::Break;
    }));
    // Bail out here: continuing would dereference a null track entry below and crash.
    if (!video_track_entry) {
        FAIL("No video track was found.");
        return;
    }

    auto iterator = MUST(matroska_reader.create_sample_iterator(stream->create_cursor(), video_track_entry->track_number()));
    NonnullOwnPtr<Media::VideoDecoder> decoder = create_decoder(*video_track_entry);

    size_t frame_count = 0;
    auto last_timestamp = AK::Duration::min();

    // Stop reading once more frames than expected have been fed; the check after the loop reports that.
    while (frame_count <= expected_frame_count) {
        auto block_result = iterator.next_block();
        if (block_result.is_error()) {
            if (block_result.error().category() == Media::DecoderErrorCategory::EndOfStream) {
                EXPECT_EQ(frame_count, expected_frame_count);
                return;
            }
            FAIL("Reading the next block failed.");
            return;
        }

        auto block = block_result.release_value();
        if (!block.timestamp().has_value()) {
            FAIL("Encountered a block without a timestamp.");
            return;
        }

        auto frames = MUST(iterator.get_frames(block));
        for (auto const& frame : frames) {
            MUST(decoder->receive_coded_data(block.timestamp().value(), block.duration().value_or(AK::Duration::zero()), frame));

            // Drain every frame the decoder can currently produce before feeding it more data.
            while (true) {
                auto frame_result = decoder->get_decoded_frame({});
                if (frame_result.is_error()) {
                    if (frame_result.error().category() == Media::DecoderErrorCategory::NeedsMoreInput)
                        break;
                    FAIL("Decoding a frame failed.");
                    return;
                }
                auto timestamp = frame_result.value()->timestamp();
                EXPECT(last_timestamp <= timestamp);
                last_timestamp = timestamp;
            }
            frame_count++;
        }
    }

    // Only reached when more frames were read than expected, so this reports the mismatch.
    EXPECT_EQ(frame_count, expected_frame_count);
}

// Decodes the first audio track of the file at `path` through a DecodedAudioProducer. The file is
// demuxed as Matroska when that succeeds, otherwise with the FFmpeg demuxer. The helper checks that:
//  - every produced block is non-empty and has the expected sample rate and channel count,
//  - every block matches `expected_channel_map`, if one is given,
//  - block timestamps never go backwards or overlap the previous block,
//  - the total number of decoded frames equals `expected_frame_count`.
// Decoding must reach end of stream within one second. Problems are reported through LibTest.
static inline void decode_audio(StringView path, u32 sample_rate, u8 channel_count, size_t expected_frame_count, Optional<Audio::ChannelMap> expected_channel_map = {})
{
    // The producer is handed this loop via current_weak(), and the loop is pumped below while waiting for data.
    Core::EventLoop loop;

    auto file = MUST(Core::File::open(path, Core::File::OpenMode::Read));
    auto stream = Media::IncrementallyPopulatedStream::create_from_buffer(MUST(file->read_until_eof()));
    // Prefer the Matroska demuxer and fall back to FFmpeg when the stream is not Matroska.
    auto demuxer = MUST([&] -> Media::DecoderErrorOr<NonnullRefPtr<Media::Demuxer>> {
        auto matroska_result = Media::Matroska::MatroskaDemuxer::from_stream(stream);
        if (!matroska_result.is_error())
            return matroska_result.release_value();
        return Media::FFmpeg::FFmpegDemuxer::from_stream(stream);
    }());
    auto tracks = TRY_OR_FAIL(demuxer->get_tracks_for_type(Media::TrackType::Audio));
    if (tracks.is_empty()) {
        FAIL("No audio track was found.");
        return;
    }
    auto producer = TRY_OR_FAIL(Media::DecodedAudioProducer::try_create(Core::EventLoop::current_weak(), demuxer, tracks[0]));

    // The handler is stored by the producer, so it deliberately captures nothing.
    producer->set_error_handler([](Media::DecoderError&&) {
        FAIL("An error occurred while decoding.");
    });
    producer->start();

    auto const time_limit = AK::Duration::from_seconds(1);
    auto const start_time = MonotonicTime::now_coarse();

    // One past the last frame of the previous block; the next block must not start before it.
    i64 last_frame = 0;
    size_t frame_count = 0;

    while (true) {
        Media::AudioBlock block;
        auto status = producer->pull(block);
        if (status == Media::PipelineStatus::EndOfStream)
            break;

        if (status == Media::PipelineStatus::HaveData) {
            EXPECT(!block.is_empty());
            EXPECT_EQ(block.sample_rate(), sample_rate);
            EXPECT_EQ(block.channel_count(), channel_count);
            if (expected_channel_map.has_value())
                EXPECT_EQ(block.sample_specification().channel_map(), expected_channel_map.value());

            EXPECT(frame_count == 0 || last_frame <= block.timestamp_in_frames());
            last_frame = block.timestamp_in_frames() + static_cast<i64>(block.frame_count());

            frame_count += block.frame_count();
        }

        if (MonotonicTime::now_coarse() - start_time >= time_limit) {
            FAIL("Decoding timed out.");
            return;
        }

        loop.pump(Core::EventLoop::WaitMode::PollForEvents);
    }

    EXPECT_EQ(frame_count, expected_frame_count);
}
