// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#if defined(LIB_JPEGLI_DCT_INL_H_) == defined(HWY_TARGET_TOGGLE)
#ifdef LIB_JPEGLI_DCT_INL_H_
#undef LIB_JPEGLI_DCT_INL_H_
#else
#define LIB_JPEGLI_DCT_INL_H_
#endif

#include "lib/jpegli/transpose-inl.h"
#include "lib/jxl/base/compiler_specific.h"

HWY_BEFORE_NAMESPACE();
namespace jpegli {
namespace HWY_NAMESPACE {
namespace {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::Abs;
using hwy::HWY_NAMESPACE::Add;
using hwy::HWY_NAMESPACE::DemoteTo;
using hwy::HWY_NAMESPACE::Ge;
using hwy::HWY_NAMESPACE::IfThenElseZero;
using hwy::HWY_NAMESPACE::Mul;
using hwy::HWY_NAMESPACE::MulAdd;
using hwy::HWY_NAMESPACE::Rebind;
using hwy::HWY_NAMESPACE::Round;
using hwy::HWY_NAMESPACE::Sub;
using hwy::HWY_NAMESPACE::Vec;

using D = HWY_FULL(float);
using DI = HWY_FULL(int32_t);

template <size_t N>
void AddReverse(const float* JXL_RESTRICT a_in1,
                const float* JXL_RESTRICT a_in2, float* JXL_RESTRICT a_out) {
  HWY_CAPPED(float, 8) d8;
  for (size_t i = 0; i < N; i++) {
    auto in1 = Load(d8, a_in1 + i * 8);
    auto in2 = Load(d8, a_in2 + (N - i - 1) * 8);
    Store(Add(in1, in2), d8, a_out + i * 8);
  }
}

template <size_t N>
void SubReverse(const float* JXL_RESTRICT a_in1,
                const float* JXL_RESTRICT a_in2, float* JXL_RESTRICT a_out) {
  HWY_CAPPED(float, 8) d8;
  for (size_t i = 0; i < N; i++) {
    auto in1 = Load(d8, a_in1 + i * 8);
    auto in2 = Load(d8, a_in2 + (N - i - 1) * 8);
    Store(Sub(in1, in2), d8, a_out + i * 8);
  }
}

// In-place post-processing of N rows (spaced 8 floats apart):
//   row[0]            = sqrt(2) * row[0] + row[1]   (via MulAdd)
//   row[r], 1<=r<N-1  = row[r] + row[r + 1]
// The last row is left unchanged. Rows are updated in ascending order, so
// row[r + 1] still holds its original value when it is read.
template <size_t N>
void B(float* JXL_RESTRICT coeff) {
  HWY_CAPPED(float, 8) d8;
  constexpr float kSqrt2 = 1.41421356237f;
  const auto head = Load(d8, coeff);
  const auto next = Load(d8, coeff + 8);
  Store(MulAdd(head, Set(d8, kSqrt2), next), d8, coeff);
  for (size_t row = 1; row + 1 < N; ++row) {
    float* current = coeff + row * 8;
    const auto sum = Add(Load(d8, current), Load(d8, current + 8));
    Store(sum, d8, current);
  }
}

// Interleaves the two halves of `a_in` into `a_out` (rows spaced 8 floats
// apart): the first N / 2 input rows go to the even output rows and the
// remaining rows go to the odd output rows. This is a pure permutation; ideally
// the compiler folds these copies away.
template <size_t N>
void InverseEvenOdd(const float* JXL_RESTRICT a_in, float* JXL_RESTRICT a_out) {
  HWY_CAPPED(float, 8) d8;
  constexpr size_t kHalf = N / 2;
  for (size_t i = 0; i < kHalf; ++i) {
    Store(Load(d8, a_in + i * 8), d8, a_out + (2 * i) * 8);
  }
  for (size_t i = 0; kHalf + i < N; ++i) {
    Store(Load(d8, a_in + (kHalf + i) * 8), d8, a_out + (2 * i + 1) * 8);
  }
}

// Per-size multiplier tables used by the DCT implementation. Entry i of
// WcMultipliers<N>::kMultipliers equals 1 / (2 * cos((i + 0.5) * pi / N)) for
// i in [0, N / 2). The values were generated with this Python snippet:
//   for i in range(N // 2):
//      print(1.0 / (2 * math.cos((i + 0.5) * math.pi / N)), end=", ")
template <size_t N>
struct WcMultipliers;

template <>
struct WcMultipliers<4> {
  static constexpr float kMultipliers[2] = {0.541196100146197,
                                            1.3065629648763764};
};

template <>
struct WcMultipliers<8> {
  static constexpr float kMultipliers[4] = {
      0.5097955791041592, 0.6013448869350453, 0.8999762231364156,
      2.5629154477415055};
};

// Before C++17, static constexpr data members are not implicitly inline, so an
// odr-use (e.g. indexing with a runtime index) needs these namespace-scope
// definitions.
#if JXL_CXX_LANG < JXL_CXX_17
constexpr float WcMultipliers<4>::kMultipliers[];
constexpr float WcMultipliers<8>::kMultipliers[];
#endif

// Scales the second half of the N rows in `coeff` (rows spaced 8 floats apart):
// row N / 2 + i is multiplied lane-wise by WcMultipliers<N>::kMultipliers[i].
// The first N / 2 rows are not touched.
template <size_t N>
void Multiply(float* JXL_RESTRICT coeff) {
  HWY_CAPPED(float, 8) d8;
  float* second_half = coeff + (N / 2) * 8;
  for (size_t i = 0; i < N / 2; ++i) {
    float* row = second_half + i * 8;
    const auto factor = Set(d8, WcMultipliers<N>::kMultipliers[i]);
    Store(Mul(Load(d8, row), factor), d8, row);
  }
}

// Gathers 8 rows of `pixels` (row r starts at pixels + r * pixels_stride),
// starting at column `off`, into the aligned buffer `coeff` with a row pitch of
// 8 floats. The source may be unaligned (LoadU); the destination must be
// aligned (Store).
void LoadFromBlock(const float* JXL_RESTRICT pixels, size_t pixels_stride,
                   size_t off, float* JXL_RESTRICT coeff) {
  HWY_CAPPED(float, 8) d8;
  for (size_t row = 0; row < 8; ++row) {
    const auto samples = LoadU(d8, pixels + row * pixels_stride + off);
    Store(samples, d8, coeff + row * 8);
  }
}

// Writes the 8 rows of `coeff` (row pitch 8 floats) into `output` (also row
// pitch 8) at column `off`, multiplying every value by 1/8. The destination may
// be unaligned (StoreU).
void StoreToBlockAndScale(const float* JXL_RESTRICT coeff, float* output,
                          size_t off) {
  HWY_CAPPED(float, 8) d8;
  const auto one_eighth = Set(d8, 1.0f / 8);
  for (size_t row = 0; row < 8; ++row) {
    const auto scaled = Mul(one_eighth, Load(d8, coeff + row * 8));
    StoreU(scaled, d8, output + row * 8 + off);
  }
}

// Recursive 1-D transform over N rows stored 8 floats apart in `mem`; every
// vector lane is processed independently, so each lane holds one column.
template <size_t N>
struct DCT1DImpl;

// Size 1: nothing to do.
template <>
struct DCT1DImpl<1> {
  JXL_INLINE void operator()(float* JXL_RESTRICT /* mem */) {}
};

// Size 2: a single butterfly, (r0, r1) -> (r0 + r1, r0 - r1).
template <>
struct DCT1DImpl<2> {
  JXL_INLINE void operator()(float* JXL_RESTRICT mem) {
    HWY_CAPPED(float, 8) d8;
    const auto top = Load(d8, mem);
    const auto bottom = Load(d8, mem + 8);
    Store(Add(top, bottom), d8, mem);
    Store(Sub(top, bottom), d8, mem + 8);
  }
};

// General size: fold the input into a "sum" half and a "difference" half,
// transform each half recursively (the difference half is pre-scaled by
// Multiply and post-processed by B), then interleave the halves back into
// `mem` as even/odd rows.
template <size_t N>
struct DCT1DImpl {
  void operator()(float* JXL_RESTRICT mem) {
    HWY_ALIGN float tmp[N * 8];
    float* sums = tmp;               // rows [0, N / 2)
    float* diffs = tmp + N * 4;      // rows [N / 2, N)
    const float* upper = mem + N * 4;

    AddReverse<N / 2>(mem, upper, sums);
    DCT1DImpl<N / 2>()(sums);

    SubReverse<N / 2>(mem, upper, diffs);
    Multiply<N>(tmp);  // Scales only the `diffs` half.
    DCT1DImpl<N / 2>()(diffs);
    B<N / 2>(diffs);

    // `sums` -> even rows of `mem`, `diffs` -> odd rows.
    InverseEvenOdd<N>(tmp, mem);
  }
};

// Applies the 8-point 1-D transform down each column of an 8x8 block read from
// `pixels` (row pitch `pixels_stride`), writing the result scaled by 1/8 to
// `output` (row pitch 8). Columns are handled Lanes(d8) at a time.
void DCT1D(const float* JXL_RESTRICT pixels, size_t pixels_stride,
           float* JXL_RESTRICT output) {
  HWY_CAPPED(float, 8) d8;
  const size_t columns_per_step = Lanes(d8);
  HWY_ALIGN float tmp[64];
  for (size_t col = 0; col < 8; col += columns_per_step) {
    // TODO(veluca): consider removing the temporary memory here (as is done in
    // IDCT), if it turns out that some compilers don't optimize away the loads
    // and this is performance-critical.
    LoadFromBlock(pixels, pixels_stride, col, tmp);
    DCT1DImpl<8>()(tmp);
    StoreToBlockAndScale(tmp, output, col);
  }
}

// 2-D forward transform of an 8x8 block: two identical passes, each a 1-D
// column transform (DCT1D, which also scales by 1/8) followed by an 8x8
// transpose. The first pass reads `pixels` with `pixels_stride`; the second
// re-reads the intermediate from `coefficients` with a stride of 8. The final
// result is left in `coefficients`. `scratch_space` must hold 64 floats.
JXL_INLINE JXL_MAYBE_UNUSED void TransformFromPixels(
    const float* JXL_RESTRICT pixels, size_t pixels_stride,
    float* JXL_RESTRICT coefficients, float* JXL_RESTRICT scratch_space) {
  const float* source = pixels;
  size_t source_stride = pixels_stride;
  for (int pass = 0; pass < 2; ++pass) {
    DCT1D(source, source_stride, scratch_space);
    Transpose8x8Block(scratch_space, coefficients);
    source = coefficients;
    source_stride = 8;
  }
}

// Narrows a vector of int32 quantized values to int16 with Highway's DemoteTo
// and stores them to `out`, which must be aligned for an aligned Store.
JXL_INLINE JXL_MAYBE_UNUSED void StoreQuantizedValue(const Vec<DI>& ival,
                                                     int16_t* out) {
  Rebind<int16_t, DI> d16;
  const auto narrowed = DemoteTo(d16, ival);
  Store(narrowed, d16, out);
}

// int32 destination: the lane type already matches, so the vector is stored
// as-is (aligned Store).
JXL_INLINE JXL_MAYBE_UNUSED void StoreQuantizedValue(const Vec<DI>& ival,
                                                     int32_t* out) {
  Store(ival, DI(), out);
}

template <typename T>
void QuantizeBlock(const float* dct, const float* qmc, float aq_strength,
                   const float* zero_bias_offset, const float* zero_bias_mul,
                   T* block) {
  D d;
  DI di;
  const auto aq_mul = Set(d, aq_strength);
  for (size_t k = 0; k < DCTSIZE2; k += Lanes(d)) {
    const auto val = Load(d, dct + k);
    const auto q = Load(d, qmc + k);
    const auto qval = Mul(val, q);
    const auto zb_offset = Load(d, zero_bias_offset + k);
    const auto zb_mul = Load(d, zero_bias_mul + k);
    const auto threshold = Add(zb_offset, Mul(zb_mul, aq_mul));
    const auto nzero_mask = Ge(Abs(qval), threshold);
    const auto ival = ConvertTo(di, IfThenElseZero(nzero_mask, Round(qval)));
    StoreQuantizedValue(ival, block + k);
  }
}

template <typename T>
void ComputeCoefficientBlock(const float* JXL_RESTRICT pixels, size_t stride,
                             const float* JXL_RESTRICT qmc,
                             int16_t last_dc_coeff, float aq_strength,
                             const float* zero_bias_offset,
                             const float* zero_bias_mul,
                             float* JXL_RESTRICT tmp, T* block) {
  float* JXL_RESTRICT dct = tmp;
  float* JXL_RESTRICT scratch_space = tmp + DCTSIZE2;
  TransformFromPixels(pixels, stride, dct, scratch_space);
  QuantizeBlock(dct, qmc, aq_strength, zero_bias_offset, zero_bias_mul, block);
  // Center DC values around zero.
  static constexpr float kDCBias = 128.0f;
  const float dc = (dct[0] - kDCBias) * qmc[0];
  float dc_threshold = zero_bias_offset[0] + aq_strength * zero_bias_mul[0];
  if (std::abs(dc - last_dc_coeff) < dc_threshold) {
    block[0] = last_dc_coeff;
  } else {
    block[0] = std::round(dc);
  }
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace
}  // namespace HWY_NAMESPACE
}  // namespace jpegli
HWY_AFTER_NAMESPACE();
#endif  // LIB_JPEGLI_DCT_INL_H_
