import { ndjsonParse } from "@redotech/json/ndjson";
import { streamAsyncIterator } from "@redotech/streams-util/async-iterator";
import {
  dehydrate,
  hydrate,
  Serializable,
} from "@redotech/wire-protocol/wire-protocol";
import { ZodError } from "zod";
import {
  InferRpcDefinition,
  RpcClientDefinition,
  RpcDefinition,
  RpcSchema,
  UnaryRpcSchema,
  UnaryStreamRpcSchema,
} from "./definition";
import { ClientError, serializeError } from "./errors";

/**
 * Shape of the client object produced for an RPC definition.
 *
 * Every key of the definition becomes a function that accepts the RPC's input
 * plus optional per-call {@link RpcRequestOptions}. RPCs whose `method` is
 * `"unary-stream"` return an `AsyncGenerator` of outputs; all other RPCs
 * return a `Promise` of a single output.
 */
export type RpcClient<D extends RpcClientDefinition> = {
  [key in keyof D]: (
    input: D[key]["input"],
    options?: RpcRequestOptions,
  ) => D[key]["method"] extends "unary-stream"
    ? AsyncGenerator<D[key]["output"]>
    : Promise<D[key]["output"]>;
};

/** Client-wide settings shared by every RPC caller created for a definition. */
export type ClientOptions<D extends RpcDefinition> = {
  /** Base URL; each RPC is addressed at this URL's path followed by `/<rpcName>`. */
  baseURL: URL;
  /** HTTP method used when a call does not specify one. The callers fall back to `"POST"`. */
  defaultMethod?: "GET" | "POST";
  /** Headers sent with every request; per-call `requestHeaders` take precedence. */
  headers?: Record<string, string>;
  /**
   * Hook awaited with the RPC name and the raw failure before the failure is
   * normalized and rethrown. It is not invoked for aborted requests.
   */
  onError?: (rpcName: keyof D, error: unknown) => Promise<void> | void;
};

/** Per-call overrides accepted as the second argument of every RPC function. */
export type RpcRequestOptions = {
  /** Extra headers for this call only; merged over the client-level `headers`. */
  requestHeaders?: Record<string, string>;
  /** HTTP method for this call; overrides the client's `defaultMethod`. */
  method?: "GET" | "POST";
  /** Abort signal forwarded to `fetch` so the caller can cancel the request. */
  signal?: AbortSignal;
};

/**
 * Builds a typed client for an RPC definition.
 *
 * Each enumerable entry of `def` is turned into a callable: entries whose
 * `method` is `"unary-stream"` get a streaming caller, everything else gets a
 * unary (single response) caller. Falsy entries are ignored.
 *
 * @param def - The RPC definition to expose.
 * @param clientOptions - Settings shared by all generated callers.
 * @returns An object with one function per RPC in `def`.
 */
export function createRpcClient<D extends RpcDefinition>(
  def: D,
  clientOptions: ClientOptions<D>,
): RpcClient<InferRpcDefinition<D>> {
  const client = {} as RpcClient<InferRpcDefinition<D>>;

  for (const rpcName in def) {
    const schema = def[rpcName];
    if (schema) {
      const caller =
        schema.method === "unary-stream"
          ? unaryStreamRpcCaller(rpcName, schema as any, clientOptions)
          : unaryRpcCaller(rpcName, schema as any, clientOptions);
      client[rpcName] = caller as any;
    }
  }

  return client;
}

/**
 * Validates and serializes an RPC input, sends it over HTTP and returns the
 * successful response.
 *
 * - `POST` sends the dehydrated `{ input }` envelope as a JSON body.
 * - `GET` encodes the envelope into the URL's query string.
 *
 * @param url - Endpoint of the RPC.
 * @param name - RPC name, used only in error messages.
 * @param def - Schema whose `input` validator is applied to `params`.
 * @param params - Caller-supplied input.
 * @param options - HTTP method, headers and optional abort signal.
 * @returns The `Response` when its status is ok.
 * @throws Error when the input fails validation/serialization or the HTTP
 *   method is neither `GET` nor `POST`.
 * @throws The `Response` object itself when its status is not ok.
 */
async function doRequest<
  Input extends Serializable,
  Output extends Serializable,
>(
  url: URL,
  name: string,
  def: RpcSchema<Input, Output>,
  params: Input,
  options: {
    method: string;
    headers: Record<string, string>;
    signal?: AbortSignal;
  },
) {
  // Validate first so that malformed input never reaches the network.
  let payload;
  try {
    const parsed = def.input.safeParse(params);
    if (!parsed.success) {
      throw parsed.error;
    }
    payload = dehydrate({ input: parsed.data });
  } catch (cause) {
    throw new Error(
      `RPC Input Validation Error ${name}: ${serializeError(cause)}`,
    );
  }

  const { method, headers, signal } = options;

  let response: Response;
  switch (method) {
    case "POST":
      response = await fetch(url.toString(), {
        method: "POST",
        headers: { "Content-Type": "application/json", ...headers },
        body: JSON.stringify(payload),
        signal,
      });
      break;
    case "GET": {
      // Work on a copy so the shared endpoint URL is never mutated.
      const target = new URL(url);
      target.search = inputToSearchParams(payload).toString();
      response = await fetch(target.toString(), {
        method: "GET",
        headers: { ...headers },
        signal,
      });
      break;
    }
    default:
      throw new Error(`Unsupported HTTP method ${method}`);
  }

  // Non-2xx responses are thrown as-is; the callers turn them into errors.
  if (!response.ok) {
    throw response;
  }
  return response;
}

/**
 * Reports a failed RPC to `onError` and rethrows it in a normalized form.
 *
 * Outcomes by failure type:
 * - `Error` named `"AbortError"`: rethrown as a plain `Error`; `onError` is
 *   not invoked.
 * - `Response`: thrown as a `ClientError` carrying the HTTP status. The
 *   message comes from the JSON body's `error`/`message` fields, or from the
 *   text body for non-JSON responses.
 * - `ZodError`: thrown as an `Error` describing the validation failure.
 * - Any other non-null object: thrown as a `ClientError` with status `0`.
 * - Anything else (primitives, `null`, `undefined`): the function returns
 *   normally so that the caller can rethrow the original value.
 *
 * @param name - RPC name, used in messages and passed to `onError`.
 * @param error - The raw failure caught by the caller.
 * @param onError - Optional hook awaited before normalization.
 */
async function handleRpcError(
  name: string,
  error: any,
  onError: ClientOptions<any>["onError"],
) {
  if (error instanceof Error && error.name === "AbortError") {
    throw new Error(`Request aborted for RPC method ${name}`);
  }

  await onError?.(name, error);
  if (error instanceof Response) {
    let message: string | undefined;
    let code: string | undefined;
    try {
      if (error.headers.get("Content-Type")?.includes("application/json")) {
        const json: unknown = await error.json();

        if (json && typeof json === "object") {
          const e = "error" in json ? String(json.error) : undefined;
          const m = "message" in json ? String(json.message) : undefined;
          code = "code" in json ? String(json.code) : undefined;
          message = e || m;
        }
      } else {
        message = await error.text();
      }
    } catch {
      // The body could not be read (malformed JSON, or already consumed by
      // `onError`). Keep going so the HTTP status is still reported instead
      // of being masked by a body-parsing error.
    }
    throw new ClientError(error.status, message ?? "Unknown error", code);
  } else if (error instanceof ZodError) {
    throw new Error(`RPC Validation Error ${name}: ${serializeError(error)}`);
  } else if (typeof error === "object" && error !== null) {
    // `typeof null` is "object" too; without the null check the `in`
    // operator below would raise a TypeError and hide the real failure.
    const e = "error" in error ? String(error.error) : undefined;
    const m = "message" in error ? String(error.message) : undefined;
    const code = "code" in error ? String(error.code) : undefined;
    const message = e || m;
    throw new ClientError(0, message ?? "Unknown error", code);
  }
}

/**
 * Creates the function that performs a single-response (unary) RPC.
 *
 * The returned function sends the request, expects a JSON object with an
 * `output` property, hydrates that value and validates it against the RPC's
 * output schema. Every failure is passed through `handleRpcError` (which
 * normally rethrows it in normalized form) and otherwise rethrown unchanged.
 *
 * @param name - RPC name; appended to the base URL's path.
 * @param def - Schema providing the input and output validators.
 * @returns An async function `(params, options?) => Promise<Output>`.
 */
function unaryRpcCaller<
  Input extends Serializable,
  Output extends Serializable,
>(
  name: string,
  def: UnaryRpcSchema<Input, Output>,
  { baseURL, defaultMethod = "POST", headers, onError }: ClientOptions<any>,
) {
  // Drop a single trailing slash so the endpoint never contains "//".
  const basePath = baseURL.pathname.replace(/\/$/, "");
  const url = new URL(`${basePath}/${name}`, baseURL);

  return async (
    params: Input,
    options?: RpcRequestOptions,
  ): Promise<Output> => {
    try {
      const response = await doRequest(url, name, def, params, {
        method: options?.method ?? defaultMethod,
        headers: { ...headers, ...options?.requestHeaders },
        signal: options?.signal,
      });

      const payload: unknown = await response.json();
      if (
        typeof payload !== "object" ||
        payload === null ||
        !("output" in payload)
      ) {
        throw new TypeError("Network response did not contain RPC output");
      }

      try {
        const parsed = def.output.safeParse(hydrate(payload.output));
        if (parsed.success) {
          return parsed.data;
        }
        throw parsed.error;
      } catch (cause) {
        throw new Error(
          `RPC Output Validation Error ${name}: ${serializeError(cause)}`,
        );
      }
    } catch (failure) {
      await handleRpcError(name, failure, onError);
      throw failure;
    }
  };
}

/**
 * Creates the async generator function that performs a server-streaming RPC.
 *
 * The response body is decoded as text and parsed as newline-delimited JSON.
 * Each frame must be a non-array object and is handled by its first matching
 * key:
 * - `output`: hydrated, validated against the output schema, then yielded.
 * - `heartbeat`: ignored.
 * - `error`: ends the stream by raising the frame as a failure.
 * Frames with none of these keys are ignored. A response without a body
 * yields nothing.
 *
 * Every failure is reported through `handleRpcError` exactly once (which
 * normally rethrows it in normalized form) and otherwise rethrown unchanged.
 *
 * @param name - RPC name; appended to the base URL's path.
 * @param def - Schema providing the input and output validators.
 * @returns An async generator function `(params, options?)` yielding outputs.
 */
function unaryStreamRpcCaller<
  Input extends Serializable,
  Output extends Serializable,
>(
  name: string,
  def: UnaryStreamRpcSchema<Input, Output>,
  { baseURL, defaultMethod = "POST", headers, onError }: ClientOptions<any>,
) {
  const url = new URL(
    `${baseURL.pathname.replace(/\/$/, "")}/${name}`,
    baseURL,
  );
  return async function* (
    params: Input,
    options?: RpcRequestOptions,
  ): AsyncIterableIterator<Output> {
    try {
      const response = await doRequest(url, name, def, params, {
        method: options?.method ?? defaultMethod,
        headers: { ...headers, ...options?.requestHeaders },
        signal: options?.signal,
      });
      const data: ReadableStream<any> | null = response.body;
      if (!data) {
        return;
      }

      const reader = data.pipeThrough(new TextDecoderStream());
      for await (const result of ndjsonParse(streamAsyncIterator(reader))) {
        if (!result || typeof result !== "object" || Array.isArray(result)) {
          throw new TypeError("Network response did not contain RPC output");
        }
        if ("output" in result) {
          let item: Output;
          try {
            const hydratedOutput = hydrate(result.output);
            const zod = def.output.safeParse(hydratedOutput);
            if (!zod.success) {
              throw zod.error;
            }
            item = zod.data;
          } catch (error) {
            throw new Error(
              `RPC Output Validation Error ${name}: ${serializeError(error)}`,
            );
          }
          // Yield outside the validation try/catch so that an error injected
          // by the consumer is not mislabelled as an output validation error.
          yield item;
        } else if ("heartbeat" in result) {
          continue;
        } else if ("error" in result) {
          // Raise the frame and let the single catch below report it. Calling
          // handleRpcError here as well would invoke `onError` twice and
          // re-wrap the already normalized error.
          throw result;
        }
      }
    } catch (error) {
      await handleRpcError(name, error, onError);
      throw error;
    }
  };
}

/**
 * Encodes the `input` member of an RPC envelope as URL query parameters.
 *
 * Each enumerable property becomes one parameter: `undefined` values are
 * omitted, `null` becomes an empty string, and every other value is
 * JSON-stringified.
 *
 * @param input - Envelope of the form `{ input: ... }`.
 * @returns The populated `URLSearchParams`.
 * @throws Error when `input.input` is missing or `typeof` reports something
 *   other than `"object"`.
 */
function inputToSearchParams(input: any): URLSearchParams {
  const payload = input?.input;
  if (typeof payload !== "object") {
    throw new Error("GET RPC input must be an object");
  }

  const query = new URLSearchParams();
  for (const field in payload) {
    const fieldValue = payload[field];
    if (fieldValue !== undefined) {
      query.append(
        field,
        fieldValue === null ? "" : JSON.stringify(fieldValue),
      );
    }
  }
  return query;
}
