// Copyright 2019 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// For faster Arm builds, disable some Arm targets. We still test SVE2_128, SVE
// and HWY_NEON_WITHOUT_AES. We keep HWY_NEON_WITHOUT_AES because HWY_NEON may
// already be disabled, and disabling both could leave only EMU128.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS (HWY_SVE_256 | HWY_NEON | HWY_SVE2)
#endif

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "tests/fma_test.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
namespace {

#ifndef HWY_NATIVE_FMA
#error "Bug in set_macros-inl.h, did not set HWY_NATIVE_FMA"
#endif

struct TestMulAdd {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    RandomState rng;

    using TI = MakeSigned<T>;
    const Rebind<TI, D> di;
    const Vec<D> k0 = Zero(d);
    const Vec<D> v1 = Iota(d, 1);
    const Vec<D> v2 = Iota(d, 2);
    VFromD<decltype(di)> mask_i;
    Mask<D> mask;

    // Unlike RebindToSigned, we want to leave floating-point unchanged.
    // This allows Neg for unsigned types.
    const Rebind<If<IsFloat<T>(), T, MakeSigned<T>>, D> dif;
    const Vec<D> neg_v2 = BitCast(d, Neg(BitCast(dif, v2)));

    const size_t N = Lanes(d);
    auto bool_lanes = AllocateAligned<TI>(N);
    auto masked_expected = AllocateAligned<T>(N);
    auto expected = AllocateAligned<T>(N);
    HWY_ASSERT(bool_lanes && masked_expected && expected);
    HWY_ASSERT_VEC_EQ(d, k0, MulAdd(k0, k0, k0));
    HWY_ASSERT_VEC_EQ(d, v2, MulAdd(k0, v1, v2));
    HWY_ASSERT_VEC_EQ(d, v2, MulAdd(v1, k0, v2));
    HWY_ASSERT_VEC_EQ(d, k0, MaskedMulAdd(MaskTrue(d), k0, k0, k0));
    HWY_ASSERT_VEC_EQ(d, v2, MaskedMulAdd(MaskTrue(d), k0, v1, v2));
    HWY_ASSERT_VEC_EQ(d, v2, MaskedMulAdd(MaskTrue(d), v1, k0, v2));
    HWY_ASSERT_VEC_EQ(d, k0, NegMulAdd(k0, k0, k0));
    HWY_ASSERT_VEC_EQ(d, v2, NegMulAdd(k0, v1, v2));
    HWY_ASSERT_VEC_EQ(d, v2, NegMulAdd(v1, k0, v2));
    HWY_ASSERT_VEC_EQ(d, k0, MaskedNegMulAdd(MaskTrue(d), k0, k0, k0));
    HWY_ASSERT_VEC_EQ(d, v2, MaskedNegMulAdd(MaskTrue(d), k0, v1, v2));
    HWY_ASSERT_VEC_EQ(d, v2, MaskedNegMulAdd(MaskTrue(d), v1, k0, v2));

    for (size_t i = 0; i < N; ++i) {
      bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0);
      expected[i] = ConvertScalarTo<T>((i + 1) * (i + 2));
      if (bool_lanes[i]) {
        masked_expected[i] = expected[i];
      } else {
        masked_expected[i] = ConvertScalarTo<T>(0);
      }
    }
    mask_i = Load(di, bool_lanes.get());
    mask = RebindMask(d, Gt(mask_i, Zero(di)));

    HWY_ASSERT_VEC_EQ(d, expected.get(), MulAdd(v2, v1, k0));
    HWY_ASSERT_VEC_EQ(d, expected.get(), MulAdd(v1, v2, k0));
    HWY_ASSERT_VEC_EQ(d, masked_expected.get(), MaskedMulAdd(mask, v2, v1, k0));
    HWY_ASSERT_VEC_EQ(d, masked_expected.get(), MaskedMulAdd(mask, v1, v2, k0));
    HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(neg_v2, v1, k0));
    HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(v1, neg_v2, k0));
    HWY_ASSERT_VEC_EQ(d, masked_expected.get(),
                      MaskedNegMulAdd(mask, neg_v2, v1, k0));
    HWY_ASSERT_VEC_EQ(d, masked_expected.get(),
                      MaskedNegMulAdd(mask, v1, neg_v2, k0));

    for (size_t i = 0; i < N; ++i) {
      bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0);
      expected[i] = ConvertScalarTo<T>((i + 2) * (i + 2) + (i + 1));
      if (bool_lanes[i]) {
        masked_expected[i] = expected[i];
      } else {
        masked_expected[i] = ConvertScalarTo<T>(0);
      }
    }
    mask_i = Load(di, bool_lanes.get());
    mask = RebindMask(d, Gt(mask_i, Zero(di)));

    HWY_ASSERT_VEC_EQ(d, expected.get(), MulAdd(v2, v2, v1));
    HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(neg_v2, v2, v1));
    HWY_ASSERT_VEC_EQ(d, masked_expected.get(), MaskedMulAdd(mask, v2, v2, v1));
    HWY_ASSERT_VEC_EQ(d, masked_expected.get(),
                      MaskedNegMulAdd(mask, neg_v2, v2, v1));

    for (size_t i = 0; i < N; ++i) {
      const T nm = ConvertScalarTo<T>(-static_cast<int>(i + 2));
      const T f = ConvertScalarTo<T>(i + 2);
      const T a = ConvertScalarTo<T>(i + 1);
      bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0);
      expected[i] = ConvertScalarTo<T>(nm * f + a);
      if (bool_lanes[i]) {
        masked_expected[i] = expected[i];
      } else {
        masked_expected[i] = ConvertScalarTo<T>(0);
      }
    }
    mask_i = Load(di, bool_lanes.get());
    mask = RebindMask(d, Gt(mask_i, Zero(di)));

    HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(v2, v2, v1));
    HWY_ASSERT_VEC_EQ(d, masked_expected.get(),
                      MaskedNegMulAdd(mask, v2, v2, v1));
  }
};

HWY_NOINLINE void TestAllMulAdd() {
  // Runs TestMulAdd for each lane type covered by ForAllTypes, over the
  // vector sizes enumerated by ForPartialVectors.
  ForPartialVectors<TestMulAdd> test;
  ForAllTypes(test);
}

struct TestMulSub {
  // Checks MulSub (mul * x - sub) and NegMulSub (-(mul * x) - sub) against
  // scalar expectations derived from the lane index.
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    const size_t N = Lanes(d);
    auto expected = AllocateAligned<T>(N);
    HWY_ASSERT(expected);

    // Rebind integer lanes to their signed counterpart but keep floating-point
    // lanes unchanged; this makes Neg usable for unsigned lane types as well.
    const Rebind<If<IsFloat<T>(), T, MakeSigned<T>>, D> dif;

    const Vec<D> zero = Zero(d);
    const Vec<D> minus_zero = Set(d, ConvertScalarTo<T>(-0.0));
    const Vec<D> iota1 = Iota(d, 1);
    const Vec<D> iota2 = Iota(d, 2);
    const Vec<D> zero_negated = BitCast(d, Neg(BitCast(dif, zero)));
    const Vec<D> iota1_negated = BitCast(d, Neg(BitCast(dif, iota1)));
    const Vec<D> iota2_negated = BitCast(d, Neg(BitCast(dif, iota2)));

    // All-zero inputs. The NegMulSub expectation is -0.0 converted to T,
    // which is plain 0 for integer lanes.
    HWY_ASSERT_VEC_EQ(d, zero, MulSub(zero, zero, zero));
    HWY_ASSERT_VEC_EQ(d, minus_zero, NegMulSub(zero, zero, zero));

    // Zero product: only the negated subtrahend remains, -(i + 2).
    for (size_t lane = 0; lane < N; ++lane) {
      expected[lane] = ConvertScalarTo<T>(-static_cast<int>(lane + 2));
    }
    HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(zero, iota1, iota2));
    HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(iota1, zero, iota2));
    HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(zero_negated, iota1, iota2));
    HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(iota1, zero_negated, iota2));

    // Zero subtrahend: the product (i + 1) * (i + 2). NegMulSub receives a
    // negated factor so that it yields the same positive product.
    for (size_t lane = 0; lane < N; ++lane) {
      expected[lane] = ConvertScalarTo<T>((lane + 1) * (lane + 2));
    }
    HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(iota1, iota2, zero));
    HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(iota2, iota1, zero));
    HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(iota1_negated, iota2, zero));
    HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(iota2, iota1_negated, zero));

    // Square minus a nonzero subtrahend: (i + 2) * (i + 2) - (i + 1).
    for (size_t lane = 0; lane < N; ++lane) {
      expected[lane] = ConvertScalarTo<T>((lane + 2) * (lane + 2) - (lane + 1));
    }
    HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(iota2, iota2, iota1));
    HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(iota2_negated, iota2, iota1));
    HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(iota2, iota2_negated, iota1));
  }
};

HWY_NOINLINE void TestAllMulSub() {
  // Runs TestMulSub for each lane type covered by ForAllTypes, over the
  // vector sizes enumerated by ForPartialVectors.
  ForPartialVectors<TestMulSub> test;
  ForAllTypes(test);
}

struct TestMulAddSub {
  // Checks MulAddSub: mul * x - sub_or_add in even lanes and
  // mul * x + sub_or_add in odd lanes.
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    const size_t N = Lanes(d);
    auto expected = AllocateAligned<T>(N);
    HWY_ASSERT(expected);

    const Vec<D> zero = Zero(d);
    const Vec<D> iota1 = Iota(d, 1);
    const Vec<D> iota2 = Iota(d, 2);

    // Rebind integer lanes to their signed counterpart but keep floating-point
    // lanes unchanged; this makes Neg usable for unsigned lane types as well.
    const Rebind<If<IsFloat<T>(), T, MakeSigned<T>>, D> dif;
    const Vec<D> iota2_negated = BitCast(d, Neg(BitCast(dif, iota2)));

    HWY_ASSERT_VEC_EQ(d, zero, MulAddSub(zero, zero, zero));

    // Zero product: only +/- the addend remains. OddEven takes odd lanes from
    // its first argument and even lanes from its second.
    const Vec<D> addend_only = OddEven(iota2, iota2_negated);
    HWY_ASSERT_VEC_EQ(d, addend_only, MulAddSub(zero, iota1, iota2));
    HWY_ASSERT_VEC_EQ(d, addend_only, MulAddSub(iota1, zero, iota2));

    // (i + 2)^2 minus (even lanes) or plus (odd lanes) (i + 1).
    for (size_t lane = 0; lane < N; ++lane) {
      const size_t square = (lane + 2) * (lane + 2);
      const size_t addend = lane + 1;
      if (lane & 1) {
        expected[lane] = ConvertScalarTo<T>(square + addend);
      } else {
        expected[lane] = ConvertScalarTo<T>(square - addend);
      }
    }
    HWY_ASSERT_VEC_EQ(d, expected.get(), MulAddSub(iota2, iota2, iota1));
  }
};

HWY_NOINLINE void TestAllMulAddSub() {
  // Runs TestMulAddSub for each lane type covered by ForAllTypes, over the
  // vector sizes enumerated by ForPartialVectors.
  ForPartialVectors<TestMulAddSub> test;
  ForAllTypes(test);
}

struct TestMulSubAdd {
  // Checks MulSubAdd: mul * x + sub_or_add in even lanes and
  // mul * x - sub_or_add in odd lanes.
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    const size_t N = Lanes(d);
    auto expected = AllocateAligned<T>(N);
    HWY_ASSERT(expected);

    const Vec<D> zero = Zero(d);
    const Vec<D> iota1 = Iota(d, 1);
    const Vec<D> iota2 = Iota(d, 2);

    // Rebind integer lanes to their signed counterpart but keep floating-point
    // lanes unchanged; this makes Neg usable for unsigned lane types as well.
    const Rebind<If<IsFloat<T>(), T, MakeSigned<T>>, D> dif;
    const Vec<D> iota2_negated = BitCast(d, Neg(BitCast(dif, iota2)));

    HWY_ASSERT_VEC_EQ(d, zero, MulSubAdd(zero, zero, zero));

    // Zero product: only +/- the addend remains. OddEven takes odd lanes from
    // its first argument and even lanes from its second.
    const Vec<D> addend_only = OddEven(iota2_negated, iota2);
    HWY_ASSERT_VEC_EQ(d, addend_only, MulSubAdd(zero, iota1, iota2));
    HWY_ASSERT_VEC_EQ(d, addend_only, MulSubAdd(iota1, zero, iota2));

    // (i + 2)^2 plus (even lanes) or minus (odd lanes) (i + 1).
    for (size_t lane = 0; lane < N; ++lane) {
      const size_t square = (lane + 2) * (lane + 2);
      const size_t addend = lane + 1;
      if (lane & 1) {
        expected[lane] = ConvertScalarTo<T>(square - addend);
      } else {
        expected[lane] = ConvertScalarTo<T>(square + addend);
      }
    }
    HWY_ASSERT_VEC_EQ(d, expected.get(), MulSubAdd(iota2, iota2, iota1));
  }
};

HWY_NOINLINE void TestAllMulSubAdd() {
  // Runs TestMulSubAdd for each lane type covered by ForAllTypes, over the
  // vector sizes enumerated by ForPartialVectors.
  ForPartialVectors<TestMulSubAdd> test;
  ForAllTypes(test);
}

}  // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
// Test registration. This section is compiled only when HWY_ONCE is set,
// i.e. once rather than once per target that foreach_target.h re-includes
// this file for.
namespace hwy {
namespace {
HWY_BEFORE_TEST(HwyFmaTest);
// Each TestAll* function defined above is exported and registered as a
// parameterized test in the HwyFmaTest suite (presumably instantiated once
// per compiled target; see the Highway test_util documentation).
HWY_EXPORT_AND_TEST_P(HwyFmaTest, TestAllMulAdd);
HWY_EXPORT_AND_TEST_P(HwyFmaTest, TestAllMulSub);
HWY_EXPORT_AND_TEST_P(HwyFmaTest, TestAllMulAddSub);
HWY_EXPORT_AND_TEST_P(HwyFmaTest, TestAllMulSubAdd);
HWY_AFTER_TEST();
}  // namespace
}  // namespace hwy
HWY_TEST_MAIN();
#endif  // HWY_ONCE
