#!/usr/bin/env python3

import argparse
import base64
import http.server
import json
import os
import socket
import socketserver
import sys
import time
import urllib.parse

from typing import Dict
from typing import Optional

"""
Description:
    This script starts a simple HTTP echo server on localhost for use in our in-tree tests.
    The port is assigned by the OS on startup and printed to stdout.

Endpoints:
    - POST /echo <json body>, Creates an echo response for later use. See "Echo" class below for body properties.
"""


class Echo:
    """An echo response registered through POST /echo.

    Fields (filled in from the registration JSON body):
        method: HTTP method the echo answers (required).
        path: Request path the echo answers (required); must start with "/echo/".
        status: Status code to respond with (required).
        headers: Response headers to send (default: none).
        body: Response body, or None for an empty body.
        body_encoding: "raw" (default) sends the body as UTF-8; "base64" decodes it before sending.
        delay_ms: Milliseconds to sleep before responding, or None.
        reason_phrase: Custom reason phrase for the status line, or None for the default.
        reflect_headers_in_body: Substitute the request headers (as JSON) for "$HEADERS" in the body, or use them
            as the whole body when there is no body.
        close_connection: Close the connection without sending any response.
    """

    method: str
    path: str
    status: int
    headers: Dict[str, str]
    body: Optional[str]
    body_encoding: str
    delay_ms: Optional[int]
    reason_phrase: Optional[str]
    reflect_headers_in_body: bool
    close_connection: bool

    # Attributes that must all match for two echoes to be considered the same definition, in comparison order.
    _COMPARED_FIELDS = (
        "method",
        "path",
        "status",
        "body",
        "body_encoding",
        "delay_ms",
        "headers",
        "reason_phrase",
        "reflect_headers_in_body",
        "close_connection",
    )

    def __eq__(self, other):
        if isinstance(other, Echo):
            return all(getattr(self, field) == getattr(other, field) for field in self._COMPARED_FIELDS)
        return NotImplemented


# Registered echo responses, keyed by "<METHOD> <path without query string>".
echo_store: Dict[str, Echo] = dict()

# Headers of the most recent request to each echo path (recorded under the full request path and, when it differs,
# the path without its query string), as header name -> list of values.
# Served as JSON by GET /recorded-request-headers<echo-path>.
recorded_request_headers: Dict[str, Dict[str, list]] = dict()


class TestHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
    """Handles echo registration/replay, recorded-header queries, and static file serving for tests.

    ``static_directory`` and ``wpt_directory`` are class attributes assigned by ``start_server`` before the server
    is created, so every handler instance sees them.
    """

    # Absolute root directory for /static/ requests.
    static_directory: str
    # Root directory for every other static request (the imported WPT tree).
    wpt_directory: str

    def __init__(self, *arguments, **kwargs):
        super().__init__(*arguments, directory=self.static_directory, **kwargs)

    def end_headers(self):
        """Send any headers queued from a ``.headers`` sidecar file, then terminate the header block."""
        if hasattr(self, "_extra_headers"):
            for key, value in self._extra_headers:
                self.send_header(key, value)
            del self._extra_headers
        super().end_headers()

    # Static files

    def _serve_static_request(self):
        """Serve a file from disk, adding headers from an adjacent ``<file>.headers`` file if one exists.

        Paths under /static/ are served from ``static_directory`` (with the "/static" prefix removed); every other
        path is served from ``wpt_directory`` so absolute WPT paths like /html/... resolve through the test server.
        """
        if self.path.startswith("/static/"):
            self.directory = self.static_directory
            # Drop "/static" but keep the following "/" so the path stays absolute.
            self.path = self.path[len("/static") :]
        else:
            self.directory = self.wpt_directory

        headers_path = self.translate_path(self.path) + ".headers"
        if os.path.isfile(headers_path):
            # Picked up by end_headers() when the file response is sent.
            self._extra_headers = self._read_headers_file(headers_path)

        super().do_GET()

    @staticmethod
    def _read_headers_file(headers_path):
        """Return (name, value) pairs from "Name: value" lines of a sidecar file; lines without ":" are ignored."""
        extra_headers = []
        with open(headers_path, encoding="utf-8") as headers_file:
            for line in headers_file:
                name, separator, value = line.strip().partition(":")
                if separator:
                    extra_headers.append((name.strip(), value.strip()))
        return extra_headers

    # Method dispatch

    def do_GET(self):
        if self.path.startswith("/echo"):
            self.handle_echo()
        elif self.path.startswith("/recorded-request-headers/"):
            self._serve_recorded_request_headers()
        else:
            self._serve_static_request()

    def do_POST(self):
        """POST /echo registers an echo; POSTs under /static/ are rejected; anything else is replayed as an echo."""
        if self.path == "/echo":
            self._register_echo()
        elif self.path.startswith("/static/"):
            self.send_error(405, "Method Not Allowed")
        else:
            self.handle_echo()

    def do_OPTIONS(self):
        """Answer preflight requests for echo paths; other paths get 405."""
        if not self.path.startswith("/echo"):
            self.do_other()
            return

        # Requests with "credentials=include" cannot have "Access-Control-Allow-Origin=*". If the test registered
        # an OPTIONS echo, return the headers that it specified. The lookup matches handle_echo's (full path first,
        # then the path without its query string), so the echo is also found when the URL has a query string.
        if self._echo_key("OPTIONS") in echo_store:
            self.handle_echo()
            return

        self.send_response(204)
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Methods", "*")
        self.send_header("Access-Control-Allow-Headers", "*")
        self.end_headers()

    def do_PUT(self):
        self.do_other()

    def do_HEAD(self):
        self.do_other()

    def do_DELETE(self):
        self.do_other()

    # Echo registration

    def _register_echo(self):
        """Handle a request to register an echo server handler.

        Responds 201 with a JSON ``{"method", "url"}`` fetch config on success; 400 if the body is missing, is not
        a JSON object, or describes an invalid echo; 409 if a different echo already exists for the method+path.
        """
        try:
            content_length = int(self.headers["Content-Length"])
            data = json.loads(self.rfile.read(content_length).decode("utf-8"))
        except (TypeError, ValueError):
            # TypeError: no Content-Length header. ValueError: bad length, bad UTF-8, or invalid JSON.
            self._send_registration_error(400)
            return

        if not isinstance(data, dict):
            self._send_registration_error(400)
            return

        echo = self._parse_echo(data)

        # Return 400: Bad Request if invalid params are given or a reserved path is given
        if not self._is_valid_echo(echo):
            self._send_registration_error(400)
            return

        # Return 409: Conflict if the method+path combination already exists with a different definition
        key = f"{echo.method} {urllib.parse.urlparse(echo.path).path}"
        if key in echo_store and echo_store[key] != echo:
            message = (
                "Echo already exists for method+path, but with a different definition.\n"
                f"key: {key}\n"
                "Hint: Use a unique path per test run (or keep the same definition).\n"
            )
            self._send_registration_error(409, message)
            return

        echo_store[key] = echo

        host = self.headers.get("host", "localhost")
        path = echo.path.lstrip("/")
        fetch_url = f"http://{host}/{path}"

        # The params to use on the client when making a request to the newly created echo endpoint
        fetch_config = {
            "method": echo.method,
            "url": fetch_url,
        }

        self.send_response(201)
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(json.dumps(fetch_config).encode("utf-8"))

    @staticmethod
    def _parse_echo(data):
        """Build an Echo from a registration payload, applying defaults for the optional fields."""
        echo = Echo()
        echo.method = data.get("method", None)
        echo.path = data.get("path", None)
        echo.status = data.get("status", None)
        echo.body = data.get("body", None)
        echo.body_encoding = data.get("body_encoding", "raw")
        echo.delay_ms = data.get("delay_ms", None)
        echo.headers = data.get("headers", {})
        echo.reason_phrase = data.get("reason_phrase", None)
        echo.reflect_headers_in_body = data.get("reflect_headers_in_body", False)
        echo.close_connection = data.get("close_connection", False)
        return echo

    @staticmethod
    def _is_valid_echo(echo):
        """Return True if the echo can be registered.

        Requires a method, a status, a string path under /echo/, and a supported body encoding. When headers are
        reflected into a custom body, that body must contain the "$HEADERS" placeholder.
        """
        if echo.method is None or echo.status is None:
            return False
        if not isinstance(echo.path, str) or not echo.path.startswith("/echo/"):
            return False
        if echo.body_encoding not in ("raw", "base64"):
            return False
        if echo.reflect_headers_in_body and echo.body is not None and "$HEADERS" not in echo.body:
            return False
        return True

    def _send_registration_error(self, status, message=None):
        """Send a plain-text registration error with an optional message body.

        Includes "Access-Control-Allow-Origin: *" (like the 201 success response) so cross-origin callers can read
        the status and message.
        """
        self.send_response(status)
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Content-Type", "text/plain")
        self.end_headers()
        if message:
            self.wfile.write(message.encode("utf-8"))

    def _serve_recorded_request_headers(self):
        """Return, as JSON, the headers recorded for the echo path that follows /recorded-request-headers."""
        echo_path = self.path[len("/recorded-request-headers") :]
        headers = recorded_request_headers.get(echo_path)
        if headers is None:
            self.send_error(404, f"No recorded request at {echo_path}")
            return
        self.send_response(200)
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(json.dumps(headers).encode("utf-8"))

    # Echo replay

    def _echo_key(self, method):
        """Return the echo_store key for this request.

        This is the full path (including the query string) if an echo is registered for it; otherwise it is the
        path without the query string.
        """
        key = f"{method} {self.path}"
        if key in echo_store:
            return key
        return f"{method} {urllib.parse.urlparse(self.path).path}"

    def _record_request_headers(self, path_without_query):
        """Record this request's headers under its full path and, if different, its path without the query."""
        headers_for_path: Dict[str, list] = {}
        for name, value in self.headers.items():
            headers_for_path.setdefault(name, []).append(value)
        recorded_request_headers[self.path] = headers_for_path
        if path_without_query != self.path:
            recorded_request_headers[path_without_query] = headers_for_path

    def handle_echo(self):
        """Replay the echo registered for this request's method and path (404 if there is none).

        Request headers that alter the response:
          - If-Modified-Since together with X-Ladybird-Respond-With-Not-Modified: respond 304 with no body.
          - X-Ladybird-Respond-With-Incomplete-Response: advertise 10KB, send 2KB, then close the connection.
          - X-Ladybird-Set-Invalid-Cookie: add a deliberately malformed Set-Cookie header.
          - Range: slice the body (see _apply_range).
        The "chunks" and "chunk_delay_ms" query parameters control how the body is written (see _write_body).
        """
        method = self.command.upper()
        parsed_url = urllib.parse.urlparse(self.path)
        query = urllib.parse.parse_qs(parsed_url.query)
        key = self._echo_key(method)

        self._record_request_headers(parsed_url.path)

        is_revalidation_request = "If-Modified-Since" in self.headers
        send_not_modified = is_revalidation_request and "X-Ladybird-Respond-With-Not-Modified" in self.headers
        send_incomplete_response = "X-Ladybird-Respond-With-Incomplete-Response" in self.headers
        set_invalid_cookie = "X-Ladybird-Set-Invalid-Cookie" in self.headers

        if key not in echo_store:
            self.send_error(404, f"Echo response not found for {key}")
            return

        echo = echo_store[key]

        if echo.close_connection:
            # Drop the connection without sending any response.
            self.connection.shutdown(socket.SHUT_WR)
            self.connection.close()
            return

        response_headers = echo.headers.copy()

        if echo.delay_ms is not None:
            time.sleep(echo.delay_ms / 1000)

        if send_not_modified:
            self.send_response(304)
        else:
            self.send_response_only(echo.status, echo.reason_phrase)

            if is_revalidation_request:
                # Override the Last-Modified header to prevent cURL from thinking the response is still fresh.
                response_headers["Last-Modified"] = "Thu, 01 Jan 1970 00:00:00 GMT"
            elif send_incomplete_response:
                # We emulate an incomplete response by advertising a 10KB file, but only sending 2KB.
                response_headers["Content-Length"] = str(10 * 1024)

        if set_invalid_cookie:
            response_headers["Set-Cookie"] = "invalid=foo; Domain=\xc3\xa9\x6c\xc3\xa8\x76\x65\xff"

        # Set only the headers defined in the echo definition (plus the overrides above). The header block must be
        # terminated even when there are no headers: the status line is only buffered until end_headers() runs.
        for header, value in response_headers.items():
            self.send_header(header, value)
        self.end_headers()

        if send_not_modified:
            return

        if send_incomplete_response:
            self.wfile.write(b"a" * (2 * 1024))
            self.wfile.flush()

            self.connection.shutdown(socket.SHUT_WR)
            self.connection.close()
            return

        response_body_bytes = self._build_response_body(echo)

        # Range offsets refer to the bytes on the wire, so slice after decoding/encoding the body.
        if "Range" in self.headers:
            response_body_bytes = self._apply_range(response_body_bytes, self.headers["Range"])

        self._write_body(response_body_bytes, query)

    def _build_response_body(self, echo):
        """Return the echo's body as bytes, substituting the request headers (as JSON) when reflecting them."""
        if echo.reflect_headers_in_body:
            reflected = json.dumps({name: self.headers.get_all(name) for name in self.headers.keys()})
            response_body = echo.body.replace("$HEADERS", reflected) if echo.body else reflected
        else:
            response_body = echo.body or ""

        if echo.body_encoding == "base64":
            return base64.b64decode(response_body)
        return response_body.encode("utf-8")

    @staticmethod
    def _apply_range(body, range_header):
        """Return the part of ``body`` selected by a "bytes=start-end" or "bytes=start-" Range header value.

        Raises ValueError for any other Range syntax.

        NOTE(review): ``end`` is treated as exclusive, whereas RFC 9110 defines last-byte-pos as inclusive. This is
        kept unchanged because existing tests may rely on it.
        FIXME: There are other formats to support if needed:
               https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Range#syntax
        """
        range_value = range_header.strip()
        if not range_value.startswith("bytes=") or range_value.count("-") != 1:
            raise ValueError(f"Unsupported Range header: {range_header!r}")

        start, end = range_value[len("bytes=") :].split("-")
        if end:
            return body[int(start) : int(end)]
        return body[int(start) :]

    def _write_body(self, body, query):
        """Write the response body, optionally in chunks.

        The "chunks" query parameter gives comma-separated chunk sizes. Each chunk is flushed and followed by a
        sleep of "chunk_delay_ms" milliseconds, and any bytes beyond the listed chunks are written last.
        """
        chunks = query.get("chunks", [])
        chunk_delay_ms = int(query.get("chunk_delay_ms", [0])[0] or 0)
        if not chunks:
            self.wfile.write(body)
            return

        chunk_sizes = [int(chunk_size) for chunk_size in chunks[0].split(",") if chunk_size]
        offset = 0
        for chunk_size in chunk_sizes:
            self.wfile.write(body[offset : offset + chunk_size])
            self.wfile.flush()
            offset += chunk_size
            if chunk_delay_ms > 0:
                time.sleep(chunk_delay_ms / 1000)
        if offset < len(body):
            self.wfile.write(body[offset:])

    def do_other(self):
        """Replay echoes for /echo paths; reject every other path with 405."""
        if self.path.startswith("/echo"):
            self.handle_echo()
        else:
            self.send_error(405, "Method Not Allowed")


def start_server(port, static_directory):
    """Serve test requests on 127.0.0.1:<port> until interrupted.

    The port the server actually bound to (chosen by the OS when ``port`` is 0) is printed to stdout and flushed
    right away so the launching process can read it.
    """
    handler_class = TestHTTPRequestHandler
    static_root = os.path.abspath(static_directory)
    handler_class.static_directory = static_root
    handler_class.wpt_directory = os.path.join(static_root, "Text", "input", "wpt-import")

    httpd = socketserver.TCPServer(("127.0.0.1", port), handler_class)
    bound_port = httpd.socket.getsockname()[1]
    print(bound_port, flush=True)

    # Leaving the with-block closes the listening socket, whether serving stops normally or via an exception.
    with httpd:
        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            pass


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a HTTP echo server")
    parser.add_argument(
        "-d",
        "--directory",
        type=str,
        default=".",
        help="Directory to serve static files from",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=0,
        help="Port to run the server on",
    )
    args = parser.parse_args()

    start_server(port=args.port, static_directory=args.directory)
