import * as PoseDetection from "@tensorflow-models/pose-detection";
import {
  DEFAULT_RADIUS,
  DEFAULT_LINE_WIDTH,
  BLAZEPOSE_MODEL,
  BLAZEPOSE_SCORE_THRESHOLD,
} from "@utils/Erised";

// Keypoint indices for the configured BlazePose model, computed once at module
// load. The `middle`, `left` and `right` index lists are iterated by
// drawKeypoints below.
const KeypointIndexBySide = PoseDetection.util.getKeypointIndexBySide(
  BLAZEPOSE_MODEL
);

/**
 * Paint the current video frame onto the canvas.
 *
 * The frame is anchored at the canvas origin (0, 0) and drawn with the given
 * destination size.
 *
 * @param context Canvas 2D context to draw on.
 * @param video Video element used as the image source.
 * @param width Destination width of the drawn frame.
 * @param height Destination height of the drawn frame.
 */
export const drawCtx = (
  context: CanvasRenderingContext2D,
  video: HTMLVideoElement,
  width: number,
  height: number
): void => {
  const originX = 0;
  const originY = 0;

  context.drawImage(video, originX, originY, width, height);
};

/**
 * Erase a rectangle of the canvas whose top-left corner is the origin (0, 0).
 *
 * @param context Canvas 2D context to clear.
 * @param width Width of the rectangle to clear.
 * @param height Height of the rectangle to clear.
 */
export const clearCtx = (
  context: CanvasRenderingContext2D,
  width: number,
  height: number
): void => {
  const left = 0;
  const top = 0;

  context.clearRect(left, top, width, height);
};

/**
 * Draw the keypoints and skeleton of a pose on the canvas.
 *
 * Does nothing when the pose has no keypoints. The drawing helpers modify the
 * context's `globalAlpha`, fill/stroke styles and line width, so the context
 * state is saved first and restored afterwards; later drawing on the same
 * context (for example the next video frame) is not affected by the pose
 * overlay's settings.
 *
 * @param context Canvas 2D context to draw on.
 * @param pose Detected pose whose keypoints are drawn.
 * @param keypointColor Color passed as both the fill and the stroke style to
 *   the keypoint and skeleton drawing.
 */
export const drawPose = (
  context: CanvasRenderingContext2D,
  pose: PoseDetection.Pose,
  keypointColor: string,
) => {
  if (pose.keypoints == null) {
    return;
  }

  context.save();
  try {
    drawKeypoints(context, pose.keypoints, keypointColor, keypointColor);
    drawSkeleton(context, pose.keypoints, keypointColor, keypointColor);
  } finally {
    // Always undo the style/alpha changes, even if drawing throws.
    context.restore();
  }
};

/**
 * Draw the pose keypoints as circles, grouped by body side.
 *
 * Keypoints in the `middle` group are filled with `fillStyle`; keypoints in
 * the `left` and `right` groups are always filled white. Every keypoint is
 * outlined with `strokeStyle`. Indices that have no entry in `keypoints` are
 * skipped instead of being passed on to the single-keypoint drawing.
 *
 * Sets `globalAlpha` to 0.75 on the context and, like the styles and line
 * width, leaves it set when the function returns.
 *
 * @param context Canvas 2D context to draw on.
 * @param keypoints A list of keypoints, indexed by the model's keypoint index.
 * @param fillStyle Fill style for the middle keypoints.
 * @param strokeStyle Stroke style for all keypoints.
 */
const drawKeypoints = (
  context: CanvasRenderingContext2D,
  keypoints: PoseDetection.Keypoint[],
  fillStyle: string = "white",
  strokeStyle: string = "white"
) => {
  context.globalAlpha = 0.75;

  context.strokeStyle = strokeStyle;
  context.lineWidth = DEFAULT_LINE_WIDTH;

  // Draw one side group with its own fill, ignoring indices the pose lacks.
  const drawGroup = (indices: readonly number[], groupFill: string) => {
    context.fillStyle = groupFill;
    for (const index of indices) {
      const keypoint = keypoints[index];
      if (keypoint != null) {
        drawKeypoint(context, keypoint);
      }
    }
  };

  drawGroup(KeypointIndexBySide.middle, fillStyle);
  drawGroup(KeypointIndexBySide.left, "white");
  drawGroup(KeypointIndexBySide.right, "white");
};

/**
 * Draw a single keypoint as a filled and outlined circle.
 *
 * The circle is centered on the keypoint, has radius DEFAULT_RADIUS, and uses
 * the fill and stroke styles currently set on the context. Nothing is drawn
 * when the keypoint's score is below the score threshold.
 *
 * @param context Canvas 2D context to draw on.
 * @param keypoint The keypoint to draw.
 */
const drawKeypoint = (
  context: CanvasRenderingContext2D,
  keypoint: PoseDetection.Keypoint
) => {
  const minScore = BLAZEPOSE_SCORE_THRESHOLD || 0;
  // A keypoint without a score is treated as having a score of 1.
  const confidence = keypoint.score != null ? keypoint.score : 1;

  if (!(confidence >= minScore)) {
    return;
  }

  const dot = new Path2D();
  dot.arc(keypoint.x, keypoint.y, DEFAULT_RADIUS, 0, 2 * Math.PI);
  context.fill(dot);
  context.stroke(dot);
};

/**
 * Draw the skeleton of a body as line segments between adjacent keypoints.
 *
 * A segment is drawn for each adjacent keypoint pair of the BlazePose model
 * when both endpoints meet the score threshold. Pairs that reference an index
 * with no entry in `keypoints` are skipped.
 *
 * Leaves the fill style, stroke style and line width set on the context.
 *
 * @param context Canvas 2D context to draw on.
 * @param keypoints A list of keypoints, indexed by the model's keypoint index.
 * @param fillStyle Fill style assigned to the context.
 * @param strokeStyle Stroke style used for the skeleton lines.
 */
const drawSkeleton = (
  context: CanvasRenderingContext2D,
  keypoints: PoseDetection.Keypoint[],
  fillStyle: string = "white",
  strokeStyle: string = "white"
) => {
  context.fillStyle = fillStyle;
  context.strokeStyle = strokeStyle;
  context.lineWidth = DEFAULT_LINE_WIDTH;

  // The threshold does not change between pairs, so compute it once.
  const scoreThreshold = BLAZEPOSE_SCORE_THRESHOLD || 0;

  for (const [i, j] of PoseDetection.util.getAdjacentPairs(BLAZEPOSE_MODEL)) {
    const kp1 = keypoints[i];
    const kp2 = keypoints[j];

    // The pose may not contain every index the model defines.
    if (kp1 == null || kp2 == null) {
      continue;
    }

    // If score is null, just show the keypoint.
    const score1 = kp1.score != null ? kp1.score : 1;
    const score2 = kp2.score != null ? kp2.score : 1;

    if (score1 >= scoreThreshold && score2 >= scoreThreshold) {
      context.beginPath();
      context.moveTo(kp1.x, kp1.y);
      context.lineTo(kp2.x, kp2.y);
      context.stroke();
    }
  }
};
