#!/usr/bin/env python3

# Copyright (c) 2022-2026, Sam Atkins <sam@ladybird.org>
# Copyright (c) 2026-present, the Ladybird developers.
#
# SPDX-License-Identifier: BSD-2-Clause

import argparse
import json
import sys

from pathlib import Path
from typing import TextIO

sys.path.append(str(Path(__file__).resolve().parent.parent))

from Utils.utils import title_casify

TYPE_CHECKS = {
    "<angle>": "matches_angle(percentages_resolve_as)",
    "<dimension>": "matches_dimension()",
    "<flex>": "matches_flex(percentages_resolve_as)",
    "<frequency>": "matches_frequency(percentages_resolve_as)",
    "<length>": "matches_length(percentages_resolve_as)",
    "<number>": "matches_number(percentages_resolve_as)",
    "<percentage>": "matches_percentage()",
    "<resolution>": "matches_resolution(percentages_resolve_as)",
    "<time>": "matches_time(percentages_resolve_as)",
}


def generate_calculation_type_check(calculation_variable_name: str, parameter_types: str) -> str:
    """Return a C++ boolean expression testing a NumericType against a set of CSS types.

    ``parameter_types`` is a "|"-separated list of CSS type names such as
    "<length>|<percentage>". Each name is mapped through TYPE_CHECKS to a member
    call on ``calculation_variable_name``, and the calls are OR-ed together with "||".
    The first unknown type name is reported on stderr and the generator exits with status 1.
    """
    requested_types = parameter_types.split("|")

    # Validate everything up front so the join below only sees known names.
    for type_name in requested_types:
        if type_name not in TYPE_CHECKS:
            print(f"I don't know what '{type_name}' is!", file=sys.stderr)
            sys.exit(1)

    return " || ".join(f"{calculation_variable_name}.{TYPE_CHECKS[type_name]}" for type_name in requested_types)


def write_header_file(out: TextIO, functions_data: dict) -> None:
    """Emit the generated MathFunctions header into ``out``.

    Declares ``enum class MathFunction`` inside ``namespace Web::CSS``. The enum
    lists ``Calc`` first, then one title-cased enumerator per key of
    ``functions_data`` in iteration order. The header also declares
    ``math_function_from_string(StringView)``.
    """
    prologue_lines = [
        "",
        "// This file is generated by GenerateCSSMathFunctions.cpp",
        "",
        "#pragma once",
        "",
        "#include <AK/Optional.h>",
        "#include <AK/StringView.h>",
        "",
        "namespace Web::CSS {",
        "",
        "enum class MathFunction {",
        "    Calc,",
    ]
    out.write("\n".join(prologue_lines) + "\n")

    for function_name in functions_data:
        out.write(f"    {title_casify(function_name)},\n")

    # The whitespace-only "    " line before the closing brace is kept so the
    # generated header stays byte-identical to previous output.
    epilogue_lines = [
        "",
        "};",
        "",
        "Optional<MathFunction> math_function_from_string(StringView);",
        "    ",
        "}",
    ]
    out.write("\n".join(epilogue_lines) + "\n")


def write_implementation_file(out: TextIO, functions_data: dict) -> None:
    out.write("""
// This file is generated by GenerateCSSMathFunctions.cpp

#include <LibWeb/CSS/Enums.h>
#include <LibWeb/CSS/MathFunctions.h>
#include <LibWeb/CSS/Parser/ErrorReporter.h>
#include <LibWeb/CSS/Parser/Parser.h>
#include <LibWeb/CSS/StyleValues/CalculatedStyleValue.h>
#include <LibWeb/CSS/StyleValues/RandomValueSharingStyleValue.h>

namespace Web::CSS {

Optional<MathFunction> math_function_from_string(StringView name)
{
    if (name.equals_ignoring_ascii_case("calc"sv))
        return MathFunction::Calc;
""")

    for name in functions_data:
        out.write(f"""
    if (name.equals_ignoring_ascii_case("{name}"sv))
        return MathFunction::{title_casify(name)};
""")

    out.write("""
    return {};
}

}

namespace Web::CSS::Parser {

static Optional<RoundingStrategy> parse_rounding_strategy(TokenStream<ComponentValue>& stream)
{
    stream.discard_whitespace();
    if (!stream.has_next_token())
        return {};

    auto& ident = stream.consume_a_token();
    if (!ident.is(Token::Type::Ident))
        return {};

    stream.discard_whitespace();
    if (stream.has_next_token())
        return {};

    auto maybe_keyword = keyword_from_string(ident.token().ident());
    if (!maybe_keyword.has_value())
        return {};

    return keyword_to_rounding_strategy(maybe_keyword.value());
}

RefPtr<CalculationNode const> Parser::parse_math_function(Function const& function, CalculationContext const& context)
{
    TokenStream stream { function.value };
    auto arguments = parse_a_comma_separated_list_of_component_values(stream);
    auto const& percentages_resolve_as = context.percentages_resolve_as;
""")

    for name, function_data in functions_data.items():
        parameters = function_data["parameters"]
        parameter_validation_rule = function_data.get("parameter-validation")
        if parameter_validation_rule is not None:
            requires_same_parameters = parameter_validation_rule == "same"
        else:
            requires_same_parameters = True

        name_titlecase = title_casify(name)
        out.write(f'    if (function.name.equals_ignoring_ascii_case("{name}"sv)) {{\n')

        if function_data.get("is-variadic", False):
            # Variadic function
            out.write(f"""
        Optional<NumericType> determined_argument_type;
        Vector<NonnullRefPtr<CalculationNode const>> parsed_arguments;
        parsed_arguments.ensure_capacity(arguments.size());

        for (auto& argument : arguments) {{
            TokenStream<ComponentValue> tokens {{ argument }};
            auto calculation_node = parse_a_calculation(tokens, context);
            if (!calculation_node) {{
                ErrorReporter::the().report(InvalidValueError {{
                    .value_type = "{name}()"_fly_string,
                    .value_string = stream.dump_string(),
                    .description = MUST(String::formatted("Argument #{{}} is not a valid calculation.", parsed_arguments.size())),
                }});
                return nullptr;
            }}

            auto maybe_argument_type = calculation_node->numeric_type();
            if (!maybe_argument_type.has_value()) {{
                ErrorReporter::the().report(InvalidValueError {{
                    .value_type = "{name}()"_fly_string,
                    .value_string = stream.dump_string(),
                    .description = MUST(String::formatted("Argument #{{}} couldn't determine its type.", parsed_arguments.size())),
                }});
                return nullptr;
            }}
            auto argument_type = maybe_argument_type.release_value();

""")
            assert len(parameters) == 1
            parameter_type_string = parameters[0]["type"]
            type_check = generate_calculation_type_check("argument_type", parameter_type_string)
            out.write(f"""
            if (!({type_check})) {{
                ErrorReporter::the().report(InvalidValueError {{
                    .value_type = "{name}()"_fly_string,
                    .value_string = stream.dump_string(),
                    .description = MUST(String::formatted("Argument #{{}} type ({{}}) is not an accepted type.", parsed_arguments.size(), argument_type.dump())),
                }});
                return nullptr;
            }}

            if (!determined_argument_type.has_value()) {{
                determined_argument_type = move(argument_type);
            }} else {{
""")
            if requires_same_parameters:
                out.write(f"""
                if (determined_argument_type != argument_type) {{
                    ErrorReporter::the().report(InvalidValueError {{
                        .value_type = "{name}()"_fly_string,
                        .value_string = stream.dump_string(),
                        .description = MUST(String::formatted("Argument #{{}} type ({{}}) doesn't match type of previous arguments ({{}}).", parsed_arguments.size(), argument_type.dump(), determined_argument_type->dump())),
                    }});
                    return nullptr;
                }}
""")
            else:
                out.write(f"""
                if (auto consistent_type = determined_argument_type->consistent_type(argument_type); consistent_type.has_value()) {{
                    determined_argument_type = consistent_type.release_value();
                }} else {{
                    ErrorReporter::the().report(InvalidValueError {{
                        .value_type = "{name}()"_fly_string,
                        .value_string = stream.dump_string(),
                        .description = MUST(String::formatted("Argument #{{}} type ({{}}) is not consistent with type of previous arguments ({{}}).", parsed_arguments.size(), argument_type.dump(), determined_argument_type->dump())),
                    }});
                    return nullptr;
                }}
""")
            out.write(f"""
            }}

            parsed_arguments.append(calculation_node.release_nonnull());
        }}

        return {name_titlecase}CalculationNode::create(move(parsed_arguments));
    }}
""")

        else:
            # Function with specified parameters.
            min_argument_count = sum(1 for p in parameters if p.get("required") is True)
            max_argument_count = len(parameters)

            if name == "random":
                out.write("""
        if (!context_allows_random_functions())
            return nullptr;

        m_random_function_index++;
""")

            out.write(f"""
        if (arguments.size() < {min_argument_count} || arguments.size() > {max_argument_count}) {{
            ErrorReporter::the().report(InvalidValueError {{
                .value_type = "{name}()"_fly_string,
                .value_string = stream.dump_string(),
                .description = MUST(String::formatted("Wrong number of arguments {{}}, expected between {min_argument_count} and {max_argument_count} inclusive.", arguments.size())),
            }});
            return nullptr;
        }}
        size_t argument_index = 0;
        Optional<NumericType> determined_argument_type;
""")

            for parameter_index, parameter in enumerate(parameters):
                parameter_type_string = parameter["type"]
                parameter_required = parameter["required"]
                parameter_name = parameter["name"]

                if parameter_type_string == "<rounding-strategy>":
                    parameter_is_calculation = False
                    parameter_type = "RoundingStrategy"
                    parse_function = f"parse_rounding_strategy(tokens_{parameter_index})"
                    check_function = ".has_value()"
                    release_function = ".release_value()"
                    default_value = parameter.get("default")
                    if default_value is not None:
                        parameter_default = f" = RoundingStrategy::{title_casify(default_value)}"
                    else:
                        parameter_default = ""
                elif parameter_type_string == "<random-value-sharing>":
                    parameter_is_calculation = False
                    parameter_type = "RefPtr<RandomValueSharingStyleValue const>"
                    parse_function = f"parse_random_value_sharing(tokens_{parameter_index})"
                    check_function = " != nullptr"
                    release_function = ".release_nonnull()"
                    parameter_default = (
                        " = RandomValueSharingStyleValue::create_auto(random_value_sharing_auto_name(), false)"
                    )
                else:
                    # NOTE: This assumes everything not handled above is a calculation node of some kind.
                    parameter_is_calculation = True
                    parameter_type = "RefPtr<CalculationNode const>"
                    parse_function = f"parse_a_calculation(tokens_{parameter_index}, context)"
                    check_function = " != nullptr"
                    release_function = ".release_nonnull()"
                    default_value = parameter.get("default")
                    if default_value is not None:
                        parameter_default = (
                            f" = NumericCalculationNode::from_keyword(Keyword::{title_casify(default_value)}, context)"
                        )
                    else:
                        parameter_default = ""

                out.write(f"""
        {parameter_type} parameter_{parameter_index}{parameter_default};
""")

                if parameter_required:
                    out.write(f"""
        if (argument_index >= arguments.size()) {{
            ErrorReporter::the().report(InvalidValueError {{
                .value_type = "{name}()"_fly_string,
                .value_string = stream.dump_string(),
                .description = "Missing required argument '{parameter_name}'."_string,
            }});
            return nullptr;
        }} else {{
""")
                else:
                    out.write("""
        if (argument_index < arguments.size()) {
""")

                out.write(f"""
            TokenStream tokens_{parameter_index} {{ arguments[argument_index] }};
            auto maybe_parsed_argument_{parameter_index} = {parse_function};
            if (maybe_parsed_argument_{parameter_index}{check_function}) {{
                parameter_{parameter_index} = maybe_parsed_argument_{parameter_index}{release_function};
                argument_index++;
""")
                if parameter_required:
                    out.write(f"""
            }} else {{
                ErrorReporter::the().report(InvalidValueError {{
                    .value_type = "{name}()"_fly_string,
                    .value_string = stream.dump_string(),
                    .description = "Failed to parse required argument '{parameter_name}'."_string,
                }});
                return nullptr;
""")
                out.write("""
            }
        }
""")

                if parameter_is_calculation:
                    parameter_type_variable = f"argument_type_{parameter_index}"
                    type_check = generate_calculation_type_check(parameter_type_variable, parameter_type_string)
                    out.write(f"""
        if (parameter_{parameter_index}) {{
            auto maybe_argument_type_{parameter_index} = parameter_{parameter_index}->numeric_type();
            if (!maybe_argument_type_{parameter_index}.has_value()) {{
                ErrorReporter::the().report(InvalidValueError {{
                    .value_type = "{name}()"_fly_string,
                    .value_string = stream.dump_string(),
                    .description = "Argument '{parameter_name}' couldn't determine its type."_string,
                }});
                return nullptr;
            }}
            auto argument_type_{parameter_index} = maybe_argument_type_{parameter_index}.release_value();

            if (!({type_check})) {{
                ErrorReporter::the().report(InvalidValueError {{
                    .value_type = "{name}()"_fly_string,
                    .value_string = stream.dump_string(),
                    .description = MUST(String::formatted("Argument '{parameter_name}' type ({{}}) is not an accepted type.", argument_type_{parameter_index}.dump())),
                }});
                return nullptr;
            }}

            if (!determined_argument_type.has_value()) {{
                determined_argument_type = argument_type_{parameter_index};
            }} else {{
""")
                    if requires_same_parameters:
                        out.write(f"""
                if (determined_argument_type != argument_type_{parameter_index}) {{
                    ErrorReporter::the().report(InvalidValueError {{
                        .value_type = "{name}()"_fly_string,
                        .value_string = stream.dump_string(),
                        .description = MUST(String::formatted("Argument '{parameter_name}' type ({{}}) doesn't match type of previous arguments ({{}}).", argument_type_{parameter_index}.dump(), determined_argument_type->dump())),
                    }});
                    return nullptr;
                }}
""")
                    else:
                        out.write(f"""
                if (auto consistent_type = determined_argument_type->consistent_type(argument_type_{parameter_index}); consistent_type.has_value()) {{
                    determined_argument_type = consistent_type.release_value();
                }} else {{
                    ErrorReporter::the().report(InvalidValueError {{
                        .value_type = "{name}()"_fly_string,
                        .value_string = stream.dump_string(),
                        .description = MUST(String::formatted("Argument '{parameter_name}' type ({{}}) is not consistent with type of previous arguments ({{}}).", argument_type_{parameter_index}.dump(), determined_argument_type->dump())),
                    }});
                    return nullptr;
                }}
""")
                    out.write("""
            }
        }
""")

            out.write("""
        if (argument_index < arguments.size())
            return nullptr;
""")
            # Generate the call to the constructor
            out.write(f"        return {name_titlecase}CalculationNode::create(")
            for parameter_index, parameter in enumerate(parameters):
                parameter_type_string = parameter["type"]
                if parameter_type_string == "<rounding-strategy>":
                    release_value = ""
                else:
                    if parameter["required"] or parameter.get("default") is not None:
                        release_value = ".release_nonnull()"
                    else:
                        release_value = ""

                if parameter_index == 0:
                    out.write(f"parameter_{parameter_index}{release_value}")
                else:
                    out.write(f", parameter_{parameter_index}{release_value}")
            out.write(""");
    }
""")

    out.write("""
    return nullptr;
}

}
""")


def main():
    """Command-line entry point.

    Reads the math-function JSON description and writes the generated header
    and implementation files.
    """
    arg_parser = argparse.ArgumentParser(description="Generate CSS MathFunctions", add_help=False)
    # "-h" is taken by --header, so the automatic help flag is disabled and a long-only --help is added.
    arg_parser.add_argument("--help", action="help", help="Show this help message and exit")
    arg_parser.add_argument("-h", "--header", required=True, help="Path to the MathFunctions header file to generate")
    arg_parser.add_argument(
        "-c",
        "--implementation",
        required=True,
        help="Path to the MathFunctions implementation file to generate",
    )
    arg_parser.add_argument("-j", "--json", required=True, help="Path to the JSON file to read from")
    options = arg_parser.parse_args()

    with open(options.json, encoding="utf-8") as json_source:
        functions_data = json.load(json_source)

    # Written one after the other: the implementation file is only opened once the header is done.
    generation_targets = (
        (options.header, write_header_file),
        (options.implementation, write_implementation_file),
    )
    for output_path, write_output in generation_targets:
        with open(output_path, "w", encoding="utf-8") as output_file:
            write_output(output_file, functions_data)


if __name__ == "__main__":
    main()
