// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include <sstream>
#include <gtest/gtest.h>

#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"

/// Derive the relative/absolute comparison tolerances for a (possibly split-K) GEMM result.
///
/// Two pairs of thresholds are obtained from ck_tile::get_relative_threshold /
/// ck_tile::get_absolute_threshold: one for the per-batch GEMM accumulation and one for the
/// split-K accumulation across batches. The larger value of each pair is returned.
///
/// @tparam ADataType   element type of A
/// @tparam BDataType   element type of B
/// @tparam AccDataType accumulator type
/// @tparam CDataType   element type of C
/// @param K                     full reduction length
/// @param kbatch                number of split-K batches (used as a divisor)
/// @param max_accumulated_value value forwarded to the absolute-threshold helpers
/// @return ck_tile tuple `(rtol, atol)`
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // The strictly smaller of the two input element types is used as compute type; B wins a tie.
    using ComputeType =
        std::conditional_t<(sizeof(ADataType) < sizeof(BDataType)), ADataType, BDataType>;

    // Reduction length handled by one split-K batch (rounded up).
    const auto k_per_batch = ck_tile::integer_divide_ceil(K, kbatch);

    // Thresholds for the GEMM accumulation inside a single batch.
    const auto gemm_rtol =
        ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(k_per_batch);
    const auto gemm_atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, k_per_batch);

    // Thresholds for the error introduced by accumulating the kbatch partial results.
    const auto splitk_rtol =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto splitk_atol = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);

    // Report the looser (larger) bound of each kind.
    return ck_tile::make_tuple(std::max(gemm_rtol, splitk_rtol),
                               std::max(gemm_atol, splitk_atol));
}

/// Tag identifying which GEMM pipeline a test instantiation exercises.
/// Only the A-quantized pipeline is listed at present.
enum class GemmPipelineType
{
    Aquant = 0
};

template <GemmPipelineType PT, typename Problem>
struct GemmPipelineTypeSelector;

template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::Aquant, Problem>
{
    using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
    using pipeline      = ck_tile::AQuantGemmPipelineAgBgCrCompV3<Problem>;
};

/// Typed GoogleTest fixture that runs the ck_tile A-quantized GEMM kernel on the device and
/// compares the result with ck_tile::reference_gemm_quant computed on the host.
///
/// @tparam Tuple tuple of configuration types, read by index:
///   - 0..2  : ALayout, BLayout, CLayout
///   - 3..6  : ADataType, BDataType, AccDataType, CDataType
///   - 7     : type whose `::value` is the pipeline scheduler (must be Intrawave)
///   - 8     : type whose `::value` is the GemmPipelineType
///   - 9..10 : AQLayout, AQDataType (layout / element type of the AQ tensor)
///   - 11    : type whose `::value` is the quantization group size along K
///   - 12    : type whose `::value` is TransposedWarpGemm
///   - 13    : type whose `::value` is TransposeC
template <typename Tuple>
class TestCkTileGemmPipeline : public ::testing::Test
{
    protected:
    using ALayout                            = std::tuple_element_t<0, Tuple>;
    using BLayout                            = std::tuple_element_t<1, Tuple>;
    using CLayout                            = std::tuple_element_t<2, Tuple>;
    using ADataType                          = std::tuple_element_t<3, Tuple>;
    using BDataType                          = std::tuple_element_t<4, Tuple>;
    using AccDataType                        = std::tuple_element_t<5, Tuple>;
    using CDataType                          = std::tuple_element_t<6, Tuple>;
    static constexpr auto Scheduler          = std::tuple_element_t<7, Tuple>::value;
    static constexpr auto PipelineType       = std::tuple_element_t<8, Tuple>::value;
    using AQLayout                           = std::tuple_element_t<9, Tuple>;
    using AQDataType                         = std::tuple_element_t<10, Tuple>;
    static constexpr auto QuantGroupSize     = std::tuple_element_t<11, Tuple>::value;
    static constexpr bool TransposedWarpGemm = std::tuple_element_t<12, Tuple>::value;
    static constexpr bool TransposeC         = std::tuple_element_t<13, Tuple>::value;

    static_assert(Scheduler == ck_tile::GemmPipelineScheduler::Intrawave,
                  "Aquant Gemm only supports Intrawave scheduler.");

    /// Build the kernel type for a fixed tile configuration and launch it.
    ///
    /// The hot-loop flag and tail number are computed at run time from the per-batch K extent
    /// and then mapped onto compile-time constants, because the pipeline problem takes them as
    /// template arguments.
    ///
    /// @tparam PadM,PadN,PadK padding flags forwarded to the GEMM traits
    /// @param args host-side GEMM arguments (device pointers, sizes, strides, k_batch)
    /// @param s    stream configuration used for the launch
    /// @throws std::runtime_error if the kernel rejects the arguments or the computed
    ///         hot-loop / tail-number combination has no dispatch branch
    template <bool PadM, bool PadN, bool PadK>
    void invoke_gemm(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
    {
        // TODO: This should be parameterized in tests
        // Block tile extents.
        constexpr ck_tile::index_t M_Tile = 128;
        constexpr ck_tile::index_t N_Tile = 128;
        constexpr ck_tile::index_t K_Tile = 128;

        // Number of warps per block along each dimension.
        constexpr ck_tile::index_t M_Warp = 2;
        constexpr ck_tile::index_t N_Warp = 2;
        constexpr ck_tile::index_t K_Warp = 1;

        // Warp tile extents.
        constexpr ck_tile::index_t M_Warp_Tile = 32;
        constexpr ck_tile::index_t N_Warp_Tile = 32;
        constexpr ck_tile::index_t K_Warp_Tile = 16;

        constexpr bool kPadM = PadM;
        constexpr bool kPadN = PadN;
        constexpr bool kPadK = PadK;

        constexpr int kBlockPerCu                         = 1;
        constexpr ck_tile::index_t TileParitionerGroupNum = 8;
        constexpr ck_tile::index_t TileParitionerM01      = 4;

        using GemmShape =
            ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                                   ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                   ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
        using TilePartitioner = ck_tile::
            GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;

        using CodegenGemmTraits =
            ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, AQLayout>;

        // Problem type used only to query the base pipeline for loop/tail information.
        using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
                                                                     BDataType,
                                                                     AccDataType,
                                                                     GemmShape,
                                                                     CodegenGemmTraits,
                                                                     BDataType>;

        using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;

        // K extent covered by one split-K batch, rounded up to a whole number of K tiles.
        const ck_tile::index_t k_grain     = args.k_batch * K_Tile;
        const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * K_Tile;
        const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

        // Instantiates and launches the kernel for one (hot loop, tail number) combination.
        // Both parameters are integral-constant objects so their values are usable as
        // template arguments.
        const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
            constexpr bool has_hot_loop_v = has_hot_loop_.value;
            constexpr auto tail_number_v  = tail_number_.value;

            using CodegenPipelineProblem =
                ck_tile::GemmAQuantPipelineProblem<ADataType,
                                                   AQDataType,
                                                   BDataType,
                                                   AccDataType,
                                                   GemmShape,
                                                   CodegenGemmTraits,
                                                   QuantGroupSize,
                                                   TransposedWarpGemm,
                                                   BDataType,
                                                   ck_tile::GemmPipelineScheduler::Intrawave,
                                                   has_hot_loop_v,
                                                   tail_number_v>;

            using GemmPipeline =
                typename GemmPipelineTypeSelector<PipelineType, CodegenPipelineProblem>::pipeline;

            // True when exactly one of TransposedWarpGemm / TransposeC is set.
            constexpr bool transposeC_epilogue =
                (!TransposedWarpGemm && TransposeC) ||
                (TransposedWarpGemm && !TransposeC);

            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<ADataType,
                                                 BDataType,
                                                 AccDataType,
                                                 CDataType,
                                                 CLayout,
                                                 GemmPipeline::BlockSize,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 M_Warp,
                                                 N_Warp,
                                                 M_Warp_Tile,
                                                 N_Warp_Tile,
                                                 K_Warp_Tile,
                                                 transposeC_epilogue>>;

            using Kernel = ck_tile::AQuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
            auto kargs   = Kernel::MakeKernelArgs(args);

            const dim3 grids      = Kernel::GridSize(args.M, args.N, args.k_batch);
            constexpr dim3 blocks = Kernel::BlockSize();

            if(!Kernel::IsSupportedArgument(kargs))
            {
                throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
            }

            if(s.log_level_ > 0)
            {
                std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
                          << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
                          << blocks.y << ", " << blocks.z << "}" << std::endl;
            }

            ck_tile::launch_kernel(
                s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        };

        // Map the run-time (has_hot_loop, tail_num) pair onto compile-time constants.
        if(has_hot_loop)
        {
            if constexpr(PipelineType == GemmPipelineType::Aquant)
            {
                if(tail_num == ck_tile::TailNumber::Full)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber,
                                                   ck_tile::TailNumber::Full>{});
                }
                else if(tail_num == ck_tile::TailNumber::Odd)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber,
                                                   ck_tile::TailNumber::Odd>{});
                }
                else if(tail_num == ck_tile::TailNumber::Even)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber,
                                                   ck_tile::TailNumber::Even>{});
                }
                else
                {
                    std::ostringstream err;
                    err << "For Aquant compute pipeline tail number should always be Full, Odd, or "
                           "Even, "
                           "but have \""
                        << tail_num << "\" which is not supported! PrefetchStages: "
                        << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
                        << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }
            }
        }
        else
        {
            // Tail number always Full - #PrefetchStages
            if(tail_num == ck_tile::TailNumber::Full)
            {
                Run(ck_tile::bool_constant<false>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
            }
            else
            {
                std::ostringstream err;
                err << "When there's no hot loop, this tail number \"" << tail_num
                    << "\" is not supported! " << __FILE__ << ":" << __LINE__
                    << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }
        }
    }

    public:
    /// split-K batch counts exercised by Run(); populated in SetUp().
    std::vector<int> k_batches_;

    /// GoogleTest hook: resets the list of split-K batch counts before each test.
    void SetUp() override
    {
        // Only do k_batch = 1 when pipeline is Aquant
        k_batches_ = {1};
    }

    /// Run one GEMM problem for every entry of k_batches_.
    ///
    /// @tparam PadM,PadN,PadK padding flags forwarded to the kernel traits
    /// @param M,N,K     problem sizes
    /// @param StrideA   leading stride of A; 0 selects the packed stride for ALayout
    /// @param StrideB   leading stride of B; 0 selects the packed stride for BLayout
    /// @param StrideC   leading stride of C; 0 selects the packed stride for CLayout
    /// @param strideAQ  leading stride of AQ; 0 selects the packed stride for AQLayout
    template <bool PadM = false, bool PadN = false, bool PadK = false>
    void Run(const int M,
             const int N,
             const int K,
             const int StrideA  = 0,
             const int StrideB  = 0,
             const int StrideC  = 0,
             const int strideAQ = 0)
    {
        for(auto kb : k_batches_)
        {
            RunSingle<PadM, PadN, PadK>(M, N, K, StrideA, StrideB, StrideC, strideAQ, kb);
        }
    }

    /// Run and verify a single GEMM problem with the given split-K batch count.
    ///
    /// Allocates host and device tensors, fills the inputs with random data, launches the
    /// kernel through invoke_gemm(), computes the host reference and records the comparison
    /// result with EXPECT_TRUE. Input data is seeded from std::random_device, so it differs
    /// between runs.
    ///
    /// @param M,N,K    problem sizes; the AQ tensor is M x (K / QuantGroupSize)
    /// @param StrideA,StrideB,StrideC,strideAQ leading strides; 0 selects the packed stride
    /// @param kbatch   split-K batch count passed to the kernel and the tolerance computation
    template <bool PadM, bool PadN, bool PadK>
    void RunSingle(const int M,
                   const int N,
                   const int K,
                   const int StrideA,
                   const int StrideB,
                   const int StrideC,
                   const int strideAQ,
                   int kbatch = 1)
    {
        using namespace ck_tile::literals;

        // Describe a row x col host tensor with the given leading stride: row-major puts the
        // stride on the row index, any other layout puts it on the column index.
        auto f_host_tensor_descriptor = [](std::size_t row,
                                           std::size_t col,
                                           std::size_t stride,
                                           auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
            {
                return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

        // Replace a zero stride by the packed stride for the layout; pass others through.
        auto f_get_default_stride =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(stride == 0)
                {
                    // give a chance if stride is zero, return a default packed stride
                    if constexpr(std::is_same_v<decltype(layout),
                                                ck_tile::tensor_layout::gemm::RowMajor>)
                    {
                        return col;
                    }
                    else
                    {
                        return row;
                    }
                }
                else
                    return stride;
            };

        std::size_t stride_A  = f_get_default_stride(M, K, StrideA, ALayout{});
        std::size_t stride_B  = f_get_default_stride(K, N, StrideB, BLayout{});
        std::size_t stride_C  = f_get_default_stride(M, N, StrideC, CLayout{});
        std::size_t stride_AQ = f_get_default_stride(M, K / QuantGroupSize, strideAQ, AQLayout{});

        ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
        ck_tile::HostTensor<AQDataType> aq_m_aqk(
            f_host_tensor_descriptor(M, K / QuantGroupSize, stride_AQ, AQLayout{}));
        ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
        ck_tile::HostTensor<CDataType> c_m_n_dev_result(
            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));

        std::random_device rd;
        std::mt19937 gen(rd());
        std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);

        // The fill strategy depends only on the element type of A, so select it at compile
        // time; the branch that does not apply is never instantiated.
        if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
        {
            // Sixteen byte patterns, each holding the same 4-bit value in both nibbles.
            constexpr auto int4_array = std::array<uint8_t, 16>{0x77,
                                                                0x66,
                                                                0x55,
                                                                0x44,
                                                                0x33,
                                                                0x22,
                                                                0x11,
                                                                0x00,
                                                                0xff,
                                                                0xee,
                                                                0xdd,
                                                                0xcc,
                                                                0xbb,
                                                                0xaa,
                                                                0x99,
                                                                0x88};
            std::uniform_int_distribution<std::uint32_t> dis(0, 15);
            // One randomly chosen pattern per stored element.
            for(size_t i = 0; i < a_m_k.size(); i++)
            {
                int randomInt   = dis(gen);
                a_m_k.data()[i] = int4_array[randomInt];
            }
        }
        else
        {
            // Generate A's values in A's own element type (not B's), so the generated data
            // matches what the A tensor stores.
            ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
        }
        ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk);
        ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);

        ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
        ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes());
        ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
        ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

        // Upload the inputs and clear the output on both device and host.
        a_m_k_dev_buf.ToDevice(a_m_k.data());
        b_k_n_dev_buf.ToDevice(b_k_n.data());
        aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
        c_m_n_dev_buf.SetZero();
        c_m_n_dev_result.SetZero();

        ck_tile::AQuantGemmHostArgs args;
        args.a_ptr     = a_m_k_dev_buf.GetDeviceBuffer();
        args.aq_ptr    = aq_m_aqk_dev_buf.GetDeviceBuffer();
        args.b_ptr     = b_k_n_dev_buf.GetDeviceBuffer();
        args.c_ptr     = c_m_n_dev_buf.GetDeviceBuffer();
        args.k_batch   = kbatch;
        args.M         = M;
        args.N         = N;
        args.K         = K;
        args.QK        = K / QuantGroupSize;
        args.stride_A  = stride_A;
        args.stride_B  = stride_B;
        args.stride_C  = stride_C;
        args.stride_AQ = stride_AQ;

        invoke_gemm<PadM, PadN, PadK>(args, ck_tile::stream_config{nullptr, false});

        c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
        c_m_n_host_ref.SetZero();

        // NOTE(review): the meaning of the trailing `true` template argument is defined by
        // ck_tile::reference_gemm_quant and is not visible here - confirm against its
        // declaration.
        ck_tile::reference_gemm_quant<ADataType,
                                      AQDataType,
                                      BDataType,
                                      AccDataType,
                                      CDataType,
                                      QuantGroupSize,
                                      true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);

        // The largest reference element feeds the tolerance computation.
        const float max_accumulated_value =
            *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
            K, kbatch, max_accumulated_value);
        const bool pass = ck_tile::check_err(c_m_n_dev_result,
                                             c_m_n_host_ref,
                                             "Error: Incorrect results!",
                                             rtol_atol.at(ck_tile::number<0>{}),
                                             rtol_atol.at(ck_tile::number<1>{}));
        EXPECT_TRUE(pass);
    }
};
