import '@tensorflow/tfjs-backend-webgl'
import * as tf from '@tensorflow/tfjs-core'
import { load as loadFaceMeshModel, SupportedPackages } from '@tensorflow-models/face-landmarks-detection'
import { loadGraphModel } from '@tensorflow/tfjs-converter'

// inspired by https://github.com/mirrory-dev/eyeblink

/**
 * Per-eye openness (blink) prediction: a FaceMesh model locates the eyes, and
 * a small graph model scores cropped grayscale patches of each eye.
 *
 * `load()` must resolve before `getPredictionWithinBoundingBox` or
 * `predictEyeOpenness` is used, since it is what sets the two models.
 */
export default class EyeBlink {
  /**
   * Load the eye-blink graph model and the MediaPipe FaceMesh model.
   *
   * Safe to call repeatedly and concurrently: all callers await the same
   * in-flight load, and both models are set once the returned promise
   * resolves. If loading fails, the cached promise is cleared so a later call
   * retries instead of staying half-initialised.
   *
   * @returns {Promise<void>}
   */
  async load () {
    if (this.eyeblinkModel && this.facemeshModel) return
    if (!this.loadPromise) {
      // Share one in-flight load so a second caller cannot return early while
      // the FaceMesh model is still downloading.
      this.loadPromise = this.loadModels().catch((err) => {
        this.loadPromise = null
        throw err
      })
    }
    await this.loadPromise
  }

  /**
   * Fetch both models in parallel, and assign them only once both succeed.
   * Prefer `load()`, which de-duplicates concurrent calls.
   *
   * @returns {Promise<void>}
   */
  async loadModels () {
    const [eyeblinkModel, facemeshModel] = await Promise.all([
      loadGraphModel('/assets/models/model.json'),
      loadFaceMeshModel(SupportedPackages.mediapipeFacemesh, { maxFaces: 1 })
    ])
    this.eyeblinkModel = eyeblinkModel
    this.facemeshModel = facemeshModel
  }

  /**
   * Build an axis-aligned box around one eye from four face-mesh landmarks.
   *
   * @param {object} face - Face prediction whose `scaledMesh[i]` points are
   *   read as `[x, y, ...]` (presumably pixel coordinates of the analysed
   *   image — verify against the face model in use).
   * @param {number[]} landmarkIdx - Mesh indices ordered `[top, right, bottom, left]`.
   * @returns {{topLeft: number[], bottomRight: number[]}} Corners as `[x, y]`.
   */
  extractEyeBoundingBox (face, [top, right, bottom, left]) {
    const mesh = face.scaledMesh
    return {
      topLeft: [mesh[left][0], mesh[top][1]],
      bottomRight: [mesh[right][0], mesh[bottom][1]]
    }
  }

  /**
   * Crop each box out of `input`, turn the crops into 26x34 grayscale patches
   * divided by 255, and run them through the eye-blink model.
   *
   * @param {tf.Tensor3D} input - Image tensor shaped `[height, width, channels]`
   *   (e.g. the output of `tf.browser.fromPixels`).
   * @param {{topLeft: number[], bottomRight: number[]}[]} boundingBoxes - Boxes
   *   in pixel coordinates, corners given as `[x, y]`.
   * @returns {Promise<TypedArray>} The model output values, in box order
   *   (presumably one score per box).
   */
  async getPredictionWithinBoundingBox (input, boundingBoxes) {
    const [height, width] = input.shape
    // cropAndResize maps a normalized coordinate c to pixel c * (size - 1),
    // so normalize by (size - 1); guard against 1-pixel-wide/high images.
    const yScale = Math.max(height - 1, 1)
    const xScale = Math.max(width - 1, 1)
    // cropAndResize expects each box as [y1, x1, y2, x2].
    const boxes = boundingBoxes.map(({ topLeft, bottomRight }) => [
      topLeft[1] / yScale,
      topLeft[0] / xScale,
      bottomRight[1] / yScale,
      bottomRight[0] / xScale
    ])
    // Every box is taken from the single image in the batch (index 0).
    const boxIndices = boxes.map(() => 0)

    const prediction = tf.tidy(() => {
      const batched = input.expandDims(0)
      // cropSize is [cropHeight, cropWidth]; the result is already float32.
      const crops = tf.image.cropAndResize(batched, boxes, boxIndices, [26, 34])
      // Average the channels into one grayscale channel (kept as a trailing
      // axis) and map 0-255 pixel values into [0, 1].
      const grayscale = crops.mean(3).expandDims(3)
      return this.eyeblinkModel.predict(grayscale.div(255))
    })
    try {
      return await prediction.data()
    } finally {
      // Release the output tensor even if reading its data fails.
      prediction.dispose()
    }
  }

  /**
   * Draw `videoEl` scaled to `width` x `height` on a fresh scratch canvas and
   * return its pixels.
   *
   * @param {CanvasImageSource} videoEl - Source to draw (e.g. a `<video>`).
   * @param {{width: number, height: number}} size - Output size in pixels.
   * @returns {ImageData}
   * @throws {Error} If the canvas cannot provide a 2D context.
   */
  getImageData (videoEl, { width, height }) {
    const procCanvas = this.buildCanvas(width, height)
    const ctx = procCanvas.getContext('2d')
    if (!ctx) throw new Error('EyeBlink: 2D canvas context is unavailable')
    ctx.drawImage(videoEl, 0, 0, procCanvas.width, procCanvas.height)
    return ctx.getImageData(0, 0, procCanvas.width, procCanvas.height)
  }

  /**
   * Create a `width` x `height` canvas, preferring `OffscreenCanvas` when the
   * environment provides one and otherwise falling back to a DOM `<canvas>`.
   */
  buildCanvas (width, height) {
    // `typeof` rather than `window.OffscreenCanvas`: `window` is undefined in
    // Web Workers, where OffscreenCanvas is typically available.
    if (typeof OffscreenCanvas !== 'undefined') {
      return new OffscreenCanvas(width, height)
    }
    const canvas = document.createElement('canvas')
    canvas.width = width
    canvas.height = height
    return canvas
  }

  /**
   * Run the eye-blink model on both eyes of the (first) face in `image`.
   *
   * @param {tf.Tensor|*} image - An image tensor, or anything accepted by
   *   `tf.browser.fromPixels` (converted here and disposed afterwards).
   * @param {object} [face] - A precomputed face-mesh prediction; when omitted
   *   the FaceMesh model is run on `image` and its first face is used.
   * @param {*} [rawImage] - Currently unused; kept for API compatibility.
   * @returns {Promise<{right: number, left: number, boxes: object[]}|null>}
   *   The model value for each eye plus the `[right, left]` boxes used, or
   *   `null` when no face is detected.
   */
  async predictEyeOpenness (image, face, rawImage) {
    if (!(image instanceof tf.Tensor)) {
      const tensor = tf.browser.fromPixels(image)
      try {
        return await this.predictEyeOpenness(tensor, face, image)
      } finally {
        // Dispose even when face detection or prediction throws.
        tensor.dispose()
      }
    }
    if (!face) {
      const facePredictions = await this.facemeshModel.estimateFaces({ input: image })
      if (facePredictions.length === 0) return null
      face = facePredictions[0]
    }
    // FaceMesh landmark indices per eye, ordered [top, right, bottom, left].
    const rightEyeMeshIdx = [27, 243, 23, 130]
    const leftEyeMeshIdx = [257, 359, 253, 362]

    const rightEyeBB = this.extractEyeBoundingBox(face, rightEyeMeshIdx)
    const leftEyeBB = this.extractEyeBoundingBox(face, leftEyeMeshIdx)

    const [rightEyePred, leftEyePred] = await this.getPredictionWithinBoundingBox(image, [
      rightEyeBB,
      leftEyeBB
    ])
    return { right: rightEyePred, left: leftEyePred, boxes: [rightEyeBB, leftEyeBB] }
  }
}
