import { Landmark } from "@mediapipe/pose";
import { Vector3 } from "three";

//TODO: move to separate file
/**
 * Plain `{ x, y, z }` coordinate triple.
 *
 * Used both for filter state (estimate, covariance, gain — one value per axis)
 * and for output positions. three.js `Vector3` instances are structurally
 * compatible and are stored in fields of this type by the low-pass stage.
 */
export interface Keypoint3D {
  /** Value along the X axis. */ x: number;
  /** Value along the Y axis. */ y: number;
  /** Value along the Z axis. */ z: number;
}
//TODO: move to separate file
/**
 * Mutable per-joint state carried between frames by `CustomKalmanFilter`.
 *
 * Every `Keypoint3D` field stores one scalar per axis; the axes are filtered
 * independently of one another.
 */
export interface Measurement3D {
  /** Error covariance of the estimate, per axis. */
  P: Keypoint3D;
  /** Kalman gain computed on the most recent update, per axis. */
  K: Keypoint3D;
  /** Current Kalman state estimate. */
  X: Keypoint3D;
  /** Latest output position: the Kalman estimate, replaced by its low-pass smoothed value when that stage runs. */
  pos3d: Keypoint3D;
  /** Low-pass buffer: slot 0 receives the latest Kalman output, each later slot is one further smoothing stage. */
  prevPos3d: Keypoint3D[];
}

/**
 * Per-joint landmark smoother.
 *
 * Stage 1 is a scalar Kalman filter (random-walk model) run independently on
 * each axis. Stage 2 is a cascade of exponential low-pass filters.
 *
 * All mutable state lives in the `Measurement3D` objects handed in, so a
 * single instance can serve any number of joints.
 */
export class CustomKalmanFilter {
  /** Axes processed independently by the Kalman update. */
  private static readonly axes = ["x", "y", "z"] as const;

  /**
   * Length of each joint's `prevPos3d` buffer.
   *
   * Slot 0 holds the Kalman output; slots 1..5 are the cascaded low-pass stages.
   */
  private static readonly historyLength = 6;

  /** Default joint count for `getInitialMeasurements` (the MediaPipe Pose landmark count). */
  private static readonly defaultJointsCount = 33;

  /** Q: process-noise variance, added to the covariance before every update. */
  private readonly kalmanParamQ: number;
  /** R: measurement-noise variance. */
  private readonly kalmanParamR: number;
  /** Weight each low-pass stage keeps from its own previous value; `1 - lowPassParam` goes to the stage before it. */
  private readonly lowPassParam: number;

  /**
   * @param q Process-noise variance (Q).
   * @param r Measurement-noise variance (R). Smaller values give a larger gain,
   *   so the filter follows measurements more closely and the result is less smooth.
   * @param lowPassParam Blend weight for the low-pass cascade; expected in [0, 1]
   *   so each stage is a convex blend. Higher values smooth more. Default 0.1.
   */
  constructor(q: number = 0.001, r: number = 0.0015, lowPassParam: number = 0.1) {
    this.kalmanParamQ = q;
    this.kalmanParamR = r;
    this.lowPassParam = lowPassParam;
  }

  /**
   * Runs one Kalman step for a single joint, mutating `measurement` in place.
   *
   * First refreshes the gain K and covariance P (`measurementUpdate3d`). Then
   * moves the estimate X toward the observation by a fraction K on each axis.
   * The new estimate object is stored as both `X` and `pos3d`.
   *
   * @param measurement Filter state for this joint; mutated.
   * @param now2d Current observation. Despite the name, x, y and z are all used.
   * @returns The same `measurement` object.
   */
  filter3d = (measurement: Measurement3D, now2d: Landmark): Measurement3D => {
    this.measurementUpdate3d(measurement);

    // K is updated in place by measurementUpdate3d, so these references are current.
    const { X, K } = measurement;
    const pose: Keypoint3D = {
      x: X.x + (now2d.x - X.x) * K.x,
      y: X.y + (now2d.y - X.y) * K.y,
      z: X.z + (now2d.z - X.z) * K.z,
    };

    measurement.X = pose;
    measurement.pos3d = pose;
    return measurement;
  };

  /**
   * Advances the covariance and computes the Kalman gain for each axis,
   * mutating `measurement.K` and `measurement.P` in place.
   *
   * Per axis, with P⁻ = P + Q:
   *   K = P⁻ / (P⁻ + R)
   *   P = (1 - K)·P⁻, written as R·P⁻ / (R + P⁻)
   */
  measurementUpdate3d = (measurement: Measurement3D): void => {
    const q = this.kalmanParamQ;
    const r = this.kalmanParamR;
    const { P, K } = measurement;

    for (const axis of CustomKalmanFilter.axes) {
      const prevP = P[axis];
      const predictedP = prevP + q; // predict: inflate by process noise
      K[axis] = predictedP / (predictedP + r);
      // Same summation order as the original formula, so float results are unchanged.
      P[axis] = (r * predictedP) / (r + prevP + q);
    }
  };

  /**
   * Runs each joint's current `pos3d` through the cascaded low-pass filter.
   *
   * Slot 0 of `prevPos3d` receives the new input. Each later slot i becomes
   *   lowPassParam * (its previous value) + (1 - lowPassParam) * (new value of slot i-1).
   * The last slot becomes the joint's new `pos3d`.
   *
   * @param measurements Joint states; `prevPos3d` and `pos3d` are mutated.
   */
  lowPassFilter3d = (measurements: Measurement3D[]): void => {
    const alpha = this.lowPassParam;
    const beta = 1 - alpha;

    for (const jp of measurements) {
      const history = jp.prevPos3d;
      history[0] = jp.pos3d;

      let upstream: Keypoint3D = jp.pos3d;
      for (let i = 1; i < history.length; i++) {
        // A hole in the buffer falls back to the upstream value instead of crashing.
        const own = history[i] ?? upstream;
        upstream = new Vector3(
          own.x * alpha + upstream.x * beta,
          own.y * alpha + upstream.y * beta,
          own.z * alpha + upstream.z * beta
        );
        history[i] = upstream;
      }

      jp.pos3d = upstream;
    }
  };

  /**
   * Creates zeroed filter state for a pose skeleton.
   *
   * @param jointsCount Number of joints; defaults to 33.
   */
  getInitialMeasurements = (
    jointsCount: number = CustomKalmanFilter.defaultJointsCount
  ): Measurement3D[] => CustomKalmanFilter.createMeasurements(jointsCount);

  /**
   * Creates zeroed filter state for an arbitrary number of joints
   * (e.g. holistic pose + hands + face).
   *
   * @param jointsCount Number of joints.
   */
  getInitialMeasurementsHolistic = (jointsCount: number): Measurement3D[] =>
    CustomKalmanFilter.createMeasurements(jointsCount);

  /**
   * Builds `jointsCount` independent zeroed states.
   *
   * No objects are shared between joints or fields. Because estimates start at
   * the origin, the first frames converge from (0, 0, 0) toward the observed landmarks.
   */
  private static createMeasurements(jointsCount: number): Measurement3D[] {
    const zero = (): Keypoint3D => ({ x: 0, y: 0, z: 0 });
    return Array.from({ length: jointsCount }, (): Measurement3D => ({
      X: zero(),
      P: zero(),
      K: zero(),
      pos3d: zero(),
      prevPos3d: Array.from({ length: CustomKalmanFilter.historyLength }, zero),
    }));
  }
}

/**
 * Smooths one frame of landmarks.
 *
 * Each joint is Kalman-filtered against its own state, then the whole set goes
 * through the low-pass cascade.
 *
 * The measurement objects are updated in place. The returned array holds those
 * same objects, index-aligned with `joints`, so callers can keep passing
 * `measurements` back in on the next frame.
 *
 * @param measurements Per-joint filter state; needs at least `joints.length` entries.
 * @param joints Raw landmarks for the current frame.
 * @param customKalmanFilter Filter providing the Kalman and low-pass steps.
 * @returns The updated state for each joint, in joint order.
 * @throws RangeError if there are fewer measurements than joints. Nothing is mutated in that case.
 */
export const applyKalmanFilterUnity = (
  measurements: Measurement3D[],
  joints: Landmark[],
  customKalmanFilter: CustomKalmanFilter
): Measurement3D[] => {
  // Validate before touching any state, so a mismatch cannot leave the frame half-updated.
  if (joints.length > measurements.length) {
    throw new RangeError(
      `applyKalmanFilterUnity: ${joints.length} joints but only ${measurements.length} measurements`
    );
  }

  const newMeasurements = joints.map((joint: Landmark, idx: number) => {
    const measurement = measurements[idx];
    // Only reachable for sparse arrays; also narrows the type under strict index checks.
    if (measurement === undefined) {
      throw new RangeError(`applyKalmanFilterUnity: no measurement for joint ${idx}`);
    }
    return customKalmanFilter.filter3d(measurement, joint);
  });

  customKalmanFilter.lowPassFilter3d(newMeasurements);

  return newMeasurements;
};
