import { Draft } from "immer";
import {
  DebugMessage,
  DebugMessageQueryEnd,
  DebugMessageQueryStart,
  DebugMessageSolverAborted,
  DebugMessageSolverEnd,
  DebugMessageSolverLog,
  DebugMessageSolverScheduled,
  DebugMessageSolverStart,
  LlSolution,
  ResponseTopLevelAnswer,
  SolveResponse,
  SolverInfo,
} from "../reasoning-engine";
import { makeEmptyQueryState } from "./query";
import { QuestionState } from "./question";
import { makeEmptySolverNodeState } from "./solverNode";
import { StatusRunning } from "./status";

/**
 * Applies a single streamed {@link SolveResponse} to the question state by
 * routing it, based on its "@type" tag, to the matching handler.
 *
 * @param questionState immer draft of the question state; mutated in place
 * @param resp the response received from the reasoning engine
 * @param offline forwarded to the top-level-answer handler; when true that
 *   handler does not overwrite the answer's `timeTakenMs`
 */
export const updateQuestionState = (
  questionState: Draft<QuestionState>,
  resp: SolveResponse,
  offline = false
) => {
  if (resp["@type"] === "TOP_LEVEL_ANSWER") {
    handleResponseTopLevelAnswer(questionState, resp, offline);
  } else if (resp["@type"] === "SOLVER_INFO") {
    handleSolverInfo(questionState, resp.solverInfo);
  } else if (resp["@type"] === "DEBUG_MESSAGE") {
    handleDebugMessage(questionState, resp.message);
  } else if (resp["@type"] === "HEARTBEAT") {
    // Keep-alive only: nothing to record.
  } else if (resp["@type"] === "INTERNAL_ERROR") {
    handleInternalErrorMessage(
      questionState,
      resp.errorClass,
      resp.message,
      resp.stackTrace
    );
  } else if (resp["@type"] === "NETWORK_ERROR") {
    handleNetworkErrorMessage(questionState, resp.message);
  } else {
    console.warn("Received unknown response", resp);
  }
};

/**
 * Appends a top-level answer to the state and records every solution id the
 * answer invalidates.
 *
 * When not `offline`, `resp.timeTakenMs` is set to the time elapsed since
 * `questionState.queryStartTime`. Note that this writes to the caller's
 * `resp` object itself, and that same object is what gets stored.
 */
const handleResponseTopLevelAnswer = (
  questionState: Draft<QuestionState>,
  resp: ResponseTopLevelAnswer,
  offline: boolean
) => {
  const stampElapsedTime = !offline;
  if (stampElapsedTime) {
    resp.timeTakenMs = getNowUtc() - questionState.queryStartTime;
  }

  questionState.topLevelAnswers.push(resp);

  for (const invalidatedId of resp.answer.solutionInvalidations) {
    questionState.invalidatedTopLevelAnswerIds.push(invalidatedId);
  }
};

/** Stores (or replaces) a solver's metadata, keyed by its `solverId`. */
const handleSolverInfo = (
  questionState: Draft<QuestionState>,
  solverInfo: SolverInfo
) => {
  const { solverId } = solverInfo;
  questionState.solverInfo.set(solverId, solverInfo);
};

/**
 * Counts a debug message and routes it to the handler for its "@type".
 * Unrecognised types are still counted in `numDebugMessages`, then reported
 * via console.warn and otherwise ignored.
 */
const handleDebugMessage = (
  questionState: Draft<QuestionState>,
  debugMessage: DebugMessage
) => {
  questionState.numDebugMessages++;

  if (debugMessage["@type"] === "SOLUTION") {
    handleSolution(questionState, debugMessage.solution);
  } else if (debugMessage["@type"] === "STATUS_QUERY_START") {
    handleDebugMessageQueryStart(questionState, debugMessage);
  } else if (debugMessage["@type"] === "STATUS_QUERY_END") {
    handleDebugMessageQueryEnd(questionState, debugMessage);
  } else if (debugMessage["@type"] === "STATUS_SOLVER_SCHEDULED") {
    handleDebugMessageSolverScheduled(questionState, debugMessage);
  } else if (debugMessage["@type"] === "STATUS_SOLVER_START") {
    handleDebugMessageSolverStart(questionState, debugMessage);
  } else if (debugMessage["@type"] === "STATUS_SOLVER_END") {
    handleDebugMessageSolverEnd(questionState, debugMessage);
  } else if (debugMessage["@type"] === "STATUS_SOLVER_LOG") {
    handleDebugMessageSolverLog(questionState, debugMessage);
  } else if (debugMessage["@type"] === "STATUS_SOLVER_ABORTED") {
    handleDebugMessageSolverAborted(questionState, debugMessage);
  } else {
    console.warn("Received unknown DebugMessage:", debugMessage);
  }
};

/**
 * Moves the question into the NETWORK_ERROR state, stamped with the current
 * time as `endedAt`.
 *
 * If the question has already failed, a warning is logged and the earlier
 * error is overwritten. `startedAt` is copied from the current status, which
 * is read as a RUNNING status — NOTE(review): if the status is not RUNNING,
 * `startedAt` may be missing; confirm this only fires while running.
 */
const handleNetworkErrorMessage = (
  questionState: Draft<QuestionState>,
  message: string
) => {
  const priorType = questionState.status["@type"];
  const isRepeatError =
    priorType === "INTERNAL_ERROR" || priorType === "NETWORK_ERROR";
  if (isRepeatError) {
    console.warn("Received multiple errors during processing.");
  }

  const { startedAt } = questionState.status as StatusRunning;
  questionState.status = {
    "@type": "NETWORK_ERROR",
    startedAt,
    endedAt: getNowUtc(),
    error: { message },
  };
};

/**
 * Moves the question into the INTERNAL_ERROR state, recording the server-side
 * error class, message and stack trace, with the current time as `endedAt`.
 *
 * If the question has already failed, a warning is logged and the earlier
 * error is overwritten. `startedAt` is copied from the current status, which
 * is read as a RUNNING status — NOTE(review): if the status is not RUNNING,
 * `startedAt` may be missing; confirm this only fires while running.
 */
const handleInternalErrorMessage = (
  questionState: Draft<QuestionState>,
  errorClass: string,
  message: string | null,
  stackTrace: string
) => {
  const priorType = questionState.status["@type"];
  const isRepeatError =
    priorType === "INTERNAL_ERROR" || priorType === "NETWORK_ERROR";
  if (isRepeatError) {
    console.warn("Received multiple errors during processing.");
  }

  const { startedAt } = questionState.status as StatusRunning;
  questionState.status = {
    "@type": "INTERNAL_ERROR",
    startedAt,
    endedAt: getNowUtc(),
    error: { errorClass, message, stackTrace },
  };
};

/**
 * Records a solution emitted by a solver node.
 *
 * `allSolutions` keeps every version received for a given `solutionId`,
 * newest first. The id is also linked onto the emitting solver node and onto
 * each of that node's parent queries. It is linked only once per list, so a
 * re-sent solution (a new version with an existing id) does not create
 * duplicate ids there.
 *
 * If the solver node is not known, the solution is still stored in
 * `allSolutions` but is not linked anywhere.
 */
const handleSolution = (
  questionState: Draft<QuestionState>,
  solution: LlSolution
) => {
  const { solutionId, solverNodeId } = solution;

  const previousVersions = questionState.allSolutions.get(solutionId) ?? [];
  questionState.allSolutions.set(solutionId, [solution, ...previousVersions]);

  const solverNode = questionState.solverNodeStates.get(solverNodeId);
  if (!solverNode) {
    return;
  }

  // A repeated solutionId is a newer version of an already-linked solution:
  // only link it the first time it is seen.
  if (!solverNode.solutionIds.includes(solutionId)) {
    solverNode.solutionIds.push(solutionId);
  }

  solverNode.parentQueryIds.forEach((parentQueryId) => {
    const parentQueryState = questionState.queryStates.get(parentQueryId);
    if (
      parentQueryState &&
      !parentQueryState.solutionIds.includes(solutionId)
    ) {
      parentQueryState.solutionIds.push(solutionId);
    }
  });
};

/**
 * Creates the state for a newly started query and links it into the tree.
 *
 * With a `parentSolverNodeId`, the query id is appended to that solver node's
 * `childQueryIds` (if the node is known). Without one, the query is recorded
 * as a top-level query. The new query state is stored in `queryStates`.
 */
const handleDebugMessageQueryStart = (
  questionState: Draft<QuestionState>,
  debugMessage: DebugMessageQueryStart
) => {
  const {
    queryId,
    query,
    queryAsQuestion,
    forRulePassage,
    parentSolverNodeId,
    timestamp,
  } = debugMessage;

  const newQueryState = makeEmptyQueryState(
    queryId,
    query,
    queryAsQuestion,
    forRulePassage,
    parentSolverNodeId,
    timestamp
  );

  if (!parentSolverNodeId) {
    questionState.topLevelQueryIds.push(queryId);
  } else {
    questionState.solverNodeStates
      .get(parentSolverNodeId)
      ?.childQueryIds.push(queryId);
  }

  questionState.queryStates.set(queryId, newQueryState);
};

/**
 * Marks a query COMPLETED with `endedAt` set to the message timestamp. The
 * `startedAt` is taken from the query's current status, which is read as a
 * RUNNING status. Messages for unknown query ids are ignored.
 */
const handleDebugMessageQueryEnd = (
  questionState: Draft<QuestionState>,
  debugMessage: DebugMessageQueryEnd
) => {
  const target = questionState.queryStates.get(debugMessage.queryId);
  if (!target) {
    return;
  }

  const { startedAt } = target.status as StatusRunning;
  target.status = {
    "@type": "COMPLETED",
    startedAt,
    endedAt: debugMessage.timestamp,
  };
  // debugMessage.numSolutions is intentionally not stored (considered redundant).
};

/**
 * Creates and stores the state for a newly scheduled solver node.
 *
 * If the parent query is known, the parent's `ulQuery` is copied onto the new
 * node and the node id is appended to the parent's `solverNodeIds`. Otherwise
 * the missing parent id is logged.
 */
const handleDebugMessageSolverScheduled = (
  questionState: Draft<QuestionState>,
  debugMessage: DebugMessageSolverScheduled
) => {
  const { solverNodeId, solverId, timestamp, parentQueryId } = debugMessage;

  const newNode = makeEmptySolverNodeState(
    solverNodeId,
    solverId,
    timestamp,
    parentQueryId
  );
  questionState.solverNodeStates.set(solverNodeId, newNode);

  // Denormalise: mirror the parent's ulQuery on the node, and link node <- parent.
  const parentQuery = questionState.queryStates.get(parentQueryId);
  if (parentQuery === undefined) {
    console.log("Missing " + parentQueryId);
    return;
  }
  newNode.parentUlQuery = parentQuery.ulQuery;
  parentQuery.solverNodeIds.push(solverNodeId);
};

/**
 * Marks a solver node RUNNING from the message timestamp and records whether
 * it was served from cache.
 *
 * When the message names an `executingSolverNodeId`, that id is stored on the
 * node as `executingSolverId`. This links the node to the node that actually
 * executes it. A missing node is logged only in that case; otherwise unknown
 * node ids are ignored silently.
 */
const handleDebugMessageSolverStart = (
  questionState: Draft<QuestionState>,
  debugMessage: DebugMessageSolverStart
) => {
  const { solverNodeId, executingSolverNodeId } = debugMessage;
  const node = questionState.solverNodeStates.get(solverNodeId);

  if (node) {
    node.status = { "@type": "RUNNING", startedAt: debugMessage.timestamp };
    node.cached = debugMessage.cached;
  }

  if (!executingSolverNodeId) {
    return;
  }

  if (node) {
    node.executingSolverId = executingSolverNodeId;
  } else {
    console.log("Missing solver with id " + solverNodeId);
  }
};

/**
 * Marks a solver node COMPLETED with `endedAt` set to the message timestamp.
 * The `startedAt` is taken from the node's current status, which is read as a
 * RUNNING status. Messages for unknown node ids are ignored.
 */
const handleDebugMessageSolverEnd = (
  questionState: Draft<QuestionState>,
  debugMessage: DebugMessageSolverEnd
) => {
  const node = questionState.solverNodeStates.get(debugMessage.solverNodeId);
  if (!node) {
    return;
  }

  const runningStatus = node.status as StatusRunning;
  node.status = {
    "@type": "COMPLETED",
    startedAt: runningStatus.startedAt,
    endedAt: debugMessage.timestamp,
  };
  // debugMessage.numSolutions is intentionally not stored (considered redundant).
};

/**
 * Appends a log message to its solver node's `logMessages`. Messages for
 * unknown node ids are dropped.
 */
const handleDebugMessageSolverLog = (
  questionState: Draft<QuestionState>,
  debugMessage: DebugMessageSolverLog
) => {
  questionState.solverNodeStates
    .get(debugMessage.solverNodeId)
    ?.logMessages.push(debugMessage);
};

/**
 * Marks a solver node ABORTED with `endedAt` set to the message timestamp and
 * records the abort reason. The `startedAt` is taken from the node's current
 * status, which is read as a RUNNING status. Unknown node ids are ignored.
 */
const handleDebugMessageSolverAborted = (
  questionState: Draft<QuestionState>,
  debugMessage: DebugMessageSolverAborted
) => {
  const { solverNodeId, timestamp, reason } = debugMessage;
  const node = questionState.solverNodeStates.get(solverNodeId);
  if (!node) {
    return;
  }

  node.status = {
    "@type": "ABORTED",
    startedAt: (node.status as StatusRunning).startedAt,
    endedAt: timestamp,
    reason,
  };
};

/** Current time as milliseconds since the Unix epoch. */
export const getNowUtc = (): number => Date.now();

/**
 * Formats a duration in milliseconds for display. Values below one second are
 * shown as-is with an "ms" suffix; anything else is shown in seconds with two
 * decimal places (e.g. 1234 -> "1.23s").
 */
export const millisToString = (millis: number): string =>
  millis < 1000 ? `${millis}ms` : `${(millis / 1000).toFixed(2)}s`;

const bytesUnits = ["bytes", "KiB", "MiB", "GiB", "TiB"];

/**
 * Formats a number of bytes into a human-readable string using the units in {@link bytesUnits}.
 * It picks the greatest appropriate unit, i.e. the one giving a value between 1 and 1024.
 * Values of 1024 TiB or more stay in TiB instead of running past the end of the unit list.
 * For readability, a scaled value below 10 is shown to 1 decimal place; everything else,
 * including plain byte counts, is shown as an integer.
 *
 * E.g. 7 -> "7 bytes", 3398 -> "3.3 KiB", 490398 -> "479 KiB", 6544528 -> "6.2 MiB",
 * 23483023 -> "22 MiB", 1024 ** 5 -> "1024 TiB".
 *
 * @param numberOfBytes the byte count to format
 * @returns the scaled value, a space, then its unit
 */
export function formatByteString(numberOfBytes: number) {
  const lastUnitIndex = bytesUnits.length - 1;
  let value = numberOfBytes;
  let unitIndex = 0;

  // Scale down by 1024 per step, but never beyond the largest known unit.
  while (value >= 1024 && unitIndex < lastUnitIndex) {
    value /= 1024;
    unitIndex++;
  }

  const fractionDigits = value < 10 && unitIndex > 0 ? 1 : 0;
  return `${value.toFixed(fractionDigits)} ${bytesUnits[unitIndex]}`;
}
