// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/extras/metrics.h"

#include <math.h>
#include <stdlib.h>

#include <atomic>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "lib/extras/metrics.cc"
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

#include "lib/jxl/base/compiler_specific.h"
#include "lib/jxl/base/rect.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/color_encoding_internal.h"
HWY_BEFORE_NAMESPACE();
namespace jxl {
namespace HWY_NAMESPACE {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::Add;
using hwy::HWY_NAMESPACE::GetLane;
using hwy::HWY_NAMESPACE::Mul;
using hwy::HWY_NAMESPACE::Rebind;

// Collapses a per-pixel distance map into a single score.
//
// For every pixel value v the function accumulates v^p, v^(2p) and v^(4p).
// Each of the three sums is divided by the pixel count and raised to the
// power 1/p, 1/(2p) and 1/(4p) respectively; the result is the mean of those
// three norms.
//
// When p is within 1e-6 of 3, a SIMD fast path forms the powers by repeated
// multiplication (v^3, v^6, v^12). Any other p goes through a scalar loop
// around std::pow and logs a warning the first time that branch is reached.
//
// `params` is part of the dispatch signature but is not read here.
// An empty distmap yields 0.0.
double ComputeDistanceP(const ImageF& distmap, const ButteraugliParams& params,
                        double p) {
  // An empty map would make onePerPixels 1.0 / 0 (undefined in ISO C++, +inf
  // under IEEE) and turn every norm below into NaN via inf * 0. There is no
  // distance to aggregate, so report zero instead.
  if (distmap.xsize() == 0 || distmap.ysize() == 0) return 0.0;
  const double onePerPixels = 1.0 / (distmap.ysize() * distmap.xsize());
  if (std::abs(p - 3.0) < 1E-6) {
    // Scalar (double) accumulators for the row tails that do not fill a
    // whole vector: sums of v^3, v^6 and v^12.
    double sum1[3] = {0.0};

// Prefer double if possible, but otherwise use float rather than scalar.
#if HWY_CAP_FLOAT64
    using T = double;
    // Float descriptor with the same lane count as `d`, so float pixels can be
    // loaded and promoted lane-for-lane to double.
    const Rebind<float, HWY_FULL(double)> df;
#else
    using T = float;
#endif
    const HWY_FULL(T) d;
    constexpr size_t N = MaxLanes(d);
    // Manually aligned storage to avoid asan crash on clang-7 due to
    // unaligned spill.
    // Per-lane running totals of v^3, v^6 and v^12 over all rows.
    HWY_ALIGN T sum_totals0[N] = {0};
    HWY_ALIGN T sum_totals1[N] = {0};
    HWY_ALIGN T sum_totals2[N] = {0};

    for (size_t y = 0; y < distmap.ysize(); ++y) {
      const float* JXL_RESTRICT row = distmap.ConstRow(y);

      // Per-row vector accumulators, folded into sum_totals* after the row.
      auto sums0 = Zero(d);
      auto sums1 = Zero(d);
      auto sums2 = Zero(d);

      size_t x = 0;
      // Whole vectors only; the remainder is handled by the scalar loop below.
      for (; x + Lanes(d) <= distmap.xsize(); x += Lanes(d)) {
#if HWY_CAP_FLOAT64
        const auto d1 = PromoteTo(d, Load(df, row + x));
#else
        const auto d1 = Load(d, row + x);
#endif
        const auto d2 = Mul(d1, Mul(d1, d1));  // v^3
        sums0 = Add(sums0, d2);
        const auto d3 = Mul(d2, d2);  // v^6
        sums1 = Add(sums1, d3);
        const auto d4 = Mul(d3, d3);  // v^12
        sums2 = Add(sums2, d4);
      }

      Store(Add(sums0, Load(d, sum_totals0)), d, sum_totals0);
      Store(Add(sums1, Load(d, sum_totals1)), d, sum_totals1);
      Store(Add(sums2, Load(d, sum_totals2)), d, sum_totals2);

      // Row tail, accumulated in double.
      for (; x < distmap.xsize(); ++x) {
        const double d1 = row[x];
        double d2 = d1 * d1 * d1;
        sum1[0] += d2;
        d2 *= d2;
        sum1[1] += d2;
        d2 *= d2;
        sum1[2] += d2;
      }
    }
    // Combine scalar tails with the horizontally-reduced vector totals, then
    // take the matching root of each mean and average the three norms.
    double v = 0;
    v += pow(
        onePerPixels * (sum1[0] + GetLane(SumOfLanes(d, Load(d, sum_totals0)))),
        1.0 / (p * 1.0));
    v += pow(
        onePerPixels * (sum1[1] + GetLane(SumOfLanes(d, Load(d, sum_totals1)))),
        1.0 / (p * 2.0));
    v += pow(
        onePerPixels * (sum1[2] + GetLane(SumOfLanes(d, Load(d, sum_totals2)))),
        1.0 / (p * 4.0));
    v /= 3.0;
    return v;
  } else {
    // Warn only the first time this branch is reached (per compiled target).
    static std::atomic<int> once{0};
    if (once.fetch_add(1, std::memory_order_relaxed) == 0) {
      JXL_WARNING("WARNING: using slow ComputeDistanceP");
    }
    double sum1[3] = {0.0};
    for (size_t y = 0; y < distmap.ysize(); ++y) {
      const float* JXL_RESTRICT row = distmap.ConstRow(y);
      for (size_t x = 0; x < distmap.xsize(); ++x) {
        double d2 = std::pow(row[x], p);  // v^p
        sum1[0] += d2;
        d2 *= d2;  // v^(2p)
        sum1[1] += d2;
        d2 *= d2;  // v^(4p)
        sum1[2] += d2;
      }
    }
    double v = 0;
    for (int i = 0; i < 3; ++i) {
      v += pow(onePerPixels * (sum1[i]), 1.0 / (p * (1 << i)));
    }
    v /= 3.0;
    return v;
  }
}

void ComputeSumOfSquares(const ImageBundle& ib1, const ImageBundle& ib2,
                         const JxlCmsInterface& cms, double sum_of_squares[3]) {
  sum_of_squares[0] = sum_of_squares[1] = sum_of_squares[2] =
      std::numeric_limits<double>::max();
  // Convert to sRGB - closer to perception than linear.
  const Image3F* srgb1 = &ib1.color();
  Image3F copy1;
  if (!ib1.IsSRGB()) {
    if (!ib1.CopyTo(Rect(ib1), ColorEncoding::SRGB(ib1.IsGray()), cms, &copy1))
      return;
    srgb1 = &copy1;
  }
  const Image3F* srgb2 = &ib2.color();
  Image3F copy2;
  if (!ib2.IsSRGB()) {
    if (!ib2.CopyTo(Rect(ib2), ColorEncoding::SRGB(ib2.IsGray()), cms, &copy2))
      return;
    srgb2 = &copy2;
  }

  if (!SameSize(*srgb1, *srgb2)) return;

  sum_of_squares[0] = sum_of_squares[1] = sum_of_squares[2] = 0.0;

  // TODO(veluca): SIMD.
  float yuvmatrix[3][3] = {{0.299, 0.587, 0.114},
                           {-0.14713, -0.28886, 0.436},
                           {0.615, -0.51499, -0.10001}};
  for (size_t y = 0; y < srgb1->ysize(); ++y) {
    const float* JXL_RESTRICT row1[3];
    const float* JXL_RESTRICT row2[3];
    for (size_t j = 0; j < 3; j++) {
      row1[j] = srgb1->ConstPlaneRow(j, y);
      row2[j] = srgb2->ConstPlaneRow(j, y);
    }
    for (size_t x = 0; x < srgb1->xsize(); ++x) {
      float cdiff[3] = {};
      // YUV conversion is linear, so we can run it on the difference.
      for (size_t j = 0; j < 3; j++) {
        cdiff[j] = row1[j][x] - row2[j][x];
      }
      float yuvdiff[3] = {};
      for (size_t j = 0; j < 3; j++) {
        for (size_t k = 0; k < 3; k++) {
          yuvdiff[j] += yuvmatrix[j][k] * cdiff[k];
        }
      }
      for (size_t j = 0; j < 3; j++) {
        sum_of_squares[j] += yuvdiff[j] * yuvdiff[j];
      }
    }
  }
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace jxl
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace jxl {
HWY_EXPORT(ComputeDistanceP);
// Public entry point. Forwards unchanged to the HWY_NAMESPACE implementation
// of ComputeDistanceP selected through Highway's dynamic dispatch.
double ComputeDistanceP(const ImageF& distmap, const ButteraugliParams& params,
                        double p) {
  const auto simd_impl = HWY_DYNAMIC_DISPATCH(ComputeDistanceP);
  return simd_impl(distmap, params, p);
}

HWY_EXPORT(ComputeSumOfSquares);

// Returns the *squared* weighted colour distance between ib1 and ib2.
//
// For each YUV channel, the root of the summed squared error (from
// ComputeSumOfSquares) is scaled by a channel weight. The weighted roots are
// added, and that total is squared. If ComputeSumOfSquares bails out, its
// max() sentinels flow through and the result is extremely large.
double ComputeDistance2(const ImageBundle& ib1, const ImageBundle& ib2,
                        const JxlCmsInterface& cms) {
  double channel_sse[3] = {};
  HWY_DYNAMIC_DISPATCH(ComputeSumOfSquares)(ib1, ib2, cms, channel_sse);
  // Weighted PSNR as in JPEG-XL: luma 6/8, each chroma channel 1/8. The
  // weights scale the root rather than the squared error, because squaring
  // them would reduce chroma to 1/64, which is too extreme.
  constexpr double kWeights[3] = {6.0 / 8, 1.0 / 8, 1.0 / 8};
  double weighted_root = 0;
  for (size_t c = 0; c < 3; ++c) {
    weighted_root += std::sqrt(channel_sse[c]) * kWeights[c];
  }
  return weighted_root * weighted_root;
}

// Weighted PSNR (dB) between ib1 and ib2, measured on the YUV error of their
// sRGB representations.
//
// Each channel's PSNR is 20*log10(1/RMSE) over all pixels. A channel with
// zero error contributes 99.99 instead of the infinite log. Channels are
// blended with weights 6/8 (Y), 1/8 (U), 1/8 (V). Returns 0.0 when the
// bundles differ in size.
double ComputePSNR(const ImageBundle& ib1, const ImageBundle& ib2,
                   const JxlCmsInterface& cms) {
  if (!SameSize(ib1, ib2)) return 0.0;
  double channel_sse[3] = {};
  HWY_DYNAMIC_DISPATCH(ComputeSumOfSquares)(ib1, ib2, cms, channel_sse);
  constexpr double kChannelWeights[3] = {6.0 / 8, 1.0 / 8, 1.0 / 8};
  const size_t pixel_count = ib1.xsize() * ib1.ysize();
  double weighted_psnr = 0;
  for (size_t c = 0; c < 3; ++c) {
    double channel_psnr = 99.99;
    if (channel_sse[c] != 0) {
      const double rmse = std::sqrt(channel_sse[c] / pixel_count);
      channel_psnr = 20 * std::log10(1 / rmse);
    }
    weighted_psnr += kChannelWeights[c] * channel_psnr;
  }
  return weighted_psnr;
}

}  // namespace jxl
#endif
