import log from 'loglevel'
import PRNG from './PRNG'
import { Crossfade, MixType, Transition, TransitionPrimary, TransitionSecondary } from './Crossfade'
import { GetCloselyRelatedKeys, GetKeyTransposed, GetVaguelyRelatedKeys, Key } from './Keys'
import Library from './Library'
import { SegmentType, Track } from './Track'
import { ScoreField, ScoreVector } from './ScoreVector'
import Vector from './Vector'

// Crossfade window lengths in bars, per mix type: `_IN` is the number of bars before the cue bar,
// `_OUT` the number of bars after it.
const LENGTH_CHILL_IN = 16
const LENGTH_CHILL_OUT = 16
const LENGTH_ROLLING_IN = 16
const LENGTH_ROLLING_OUT = 16
const LENGTH_DOUBLE_DROP_IN = 16
const LENGTH_DOUBLE_DROP_OUT = 32

// Weights used to blend theme descriptors into the target theme when choosing the next track:
// the session theme, the current song and the previous song. The previous-song weight is
// negative, and the current-song weight is chosen so that the three sum to 1.
const THEME_WEIGHT = 0.4
const PREV_SONG_WEIGHT = -0.1 * (1 - THEME_WEIGHT)
const CURRENT_SONG_WEIGHT = 1 - (THEME_WEIGHT + PREV_SONG_WEIGHT)

// Probability of the next mix type given the current one.
// Each row is keyed by the current mix type and holds [P_chill, P_rolling, P_ddrop].
const TRANSITION_PROBS = {
  Chill: [0.0, 0.7, 0.3],
  Rolling: [0.2, 0.8, 0.0],
  DoubleDrop: [0.2, 0.8, 0.0],
}

// How many of the best-scoring transitions into any single track are kept as candidates
const MAX_CANDIDATES_PER_TRACK = 3

/**
 * Probability of following a `curFadeType` mix with a `nextFadeType` mix, looked up in
 * TRANSITION_PROBS.
 *
 * @param curFadeType mix type of the current transition (selects the TRANSITION_PROBS row)
 * @param nextFadeType mix type of the candidate next transition (selects the column)
 * @returns the transition probability
 * @throws Error if either mix type has no entry in the table
 */
export function MixTypeScore(curFadeType: MixType, nextFadeType: MixType) {
  const probs = TRANSITION_PROBS[curFadeType]
  if (!probs) throw new Error(`Unknown fadeType ${curFadeType}`)
  const [P_chill, P_roll, P_ddrop] = probs
  switch (nextFadeType) {
    case MixType.Chill:
      return P_chill
    case MixType.Rolling:
      return P_roll
    case MixType.DoubleDrop:
      return P_ddrop
    default:
      // Previously an unknown type fell out of the switch and returned undefined, which then
      // leaked into the score vector as a non-number.
      throw new Error(`Unknown fadeType ${nextFadeType}`)
  }
}

/**
 * Average onset-detection-function (ODF) similarity over the crossfade window of a transition.
 *
 * The window spans `min(cueOut.fadeInBars, cueIn.fadeInBars)` bars before each cue bar plus
 * `cueOut.fadeOutBars` bars after it. It is compared in consecutive blocks of ODF_BARS bars with
 * OnsetSimilarity, which returns an accumulated difference (lower = more alike).
 *
 * @param curTrack outgoing track
 * @param track incoming track
 * @param cueOut cue in the outgoing track
 * @param cueIn cue in the incoming track
 * @returns the mean OnsetSimilarity over all blocks, or 0 if the window contains no bars
 * @throws Error if OnsetSimilarity returns NaN for any block (e.g. an empty onset window)
 */
export function ODFScore(
  curTrack: Track,
  track: Track,
  cueOut: TransitionPrimary,
  cueIn: TransitionSecondary
) {
  const ODF_BARS = 4
  // Bars are converted to beat indices assuming four beats per bar (4/4) — as the original did
  const BEATS_PER_BAR = 4

  const fadeInBars = Math.min(cueOut.fadeInBars, cueIn.fadeInBars)
  const totalBars = fadeInBars + cueOut.fadeOutBars
  // First bar of the crossfade window in each track
  const windowStartA = cueOut.bar - fadeInBars
  const windowStartB = cueIn.bar - fadeInBars

  // Compute the ODF similarity, averaged across blocks of ODF_BARS bars
  let sum = 0.0
  let count = 0
  for (let i = 0; i < totalBars; i += ODF_BARS) {
    const startBeatA = (windowStartA + i) * BEATS_PER_BAR
    const endBeatA = (windowStartA + i + ODF_BARS) * BEATS_PER_BAR
    const startBeatB = (windowStartB + i) * BEATS_PER_BAR
    const endBeatB = (windowStartB + i + ODF_BARS) * BEATS_PER_BAR
    const simil = OnsetSimilarity(curTrack, track, startBeatA, endBeatA, startBeatB, endBeatB)
    if (isNaN(simil)) {
      throw new Error(
        `OnsetSimilarity(${curTrack.name}, ${track.name}, ${startBeatA}, ${endBeatA}, ` +
          `${startBeatB}, ${endBeatB}) returned NaN`
      )
    }
    sum += simil
    ++count
  }

  // An empty window has nothing to compare. Return 0 (the value used when ODF scoring is
  // disabled) instead of letting 0/0 = NaN escape.
  if (count === 0) return 0
  return sum / count
}

/**
 * Decide whether two per-bar vocal sequences overlap during a crossfade window.
 *
 * A bar counts as vocal when its value is positive, or when both of its neighbours are positive
 * (this bridges a one-bar gap). The first and last bars of the window are not examined, since
 * each lacks a neighbour inside the window. A clash is reported once two examined bars are vocal
 * in both tracks.
 *
 * @param vocalsA per-bar vocal values of the first track
 * @param vocalsB per-bar vocal values of the second track
 * @param startBarA first bar of the window in `vocalsA`
 * @param startBarB first bar of the window in `vocalsB`
 * @param barCount number of bars in the window
 * @returns true when at least two bars overlap
 * @throws Error if the window runs past the end of either array
 */
export function IsVocalClash(
  vocalsA: number[],
  vocalsB: number[],
  startBarA: number,
  startBarB: number,
  barCount: number
) {
  if (barCount > vocalsA.length - startBarA)
    throw new Error(`barCount=${barCount}, vocalsALen=${vocalsA.length}, startBarA=${startBarA}`)
  if (barCount > vocalsB.length - startBarB)
    throw new Error(`barCount=${barCount}, vocalsBLen=${vocalsB.length}, startBarB=${startBarB}`)

  const isVocal = (vocals: number[], bar: number) =>
    vocals[bar] > 0 || (vocals[bar - 1] > 0 && vocals[bar + 1] > 0)

  let overlaps = 0
  for (let offset = 1; offset + 1 < barCount; offset++) {
    const vocalA = isVocal(vocalsA, startBarA + offset)
    const vocalB = isVocal(vocalsB, startBarB + offset)
    if (!vocalA || !vocalB) continue
    overlaps += 1
    if (overlaps >= 2) return true
  }
  return false
}

/**
 * Slice of an onset-detection-function array between two beats.
 *
 * Beat times are converted to ODF frame indices assuming 44.1 kHz audio analysed in 512-sample
 * windows (the constants below — presumably matching how `onsets` was produced; confirm).
 * Both beat indices are clamped into the beat grid.
 *
 * @param onsets per-frame onset strengths
 * @param beats beat times, in seconds (implied by the multiplication with the sample rate)
 * @param startBeat index of the first beat (inclusive)
 * @param endBeat index of the last beat (exclusive end of the slice)
 * @returns a Vector view over the frames between the two beats; empty if there are no beats
 *   or the range is empty
 */
export function Onsets(onsets: Float32Array, beats: number[], startBeat: number, endBeat: number) {
  const WINDOW_SAMPLES = 512
  const SAMPLE_RATE = 44100

  const lastBeat = beats.length - 1
  if (lastBeat < 0) return new Vector(onsets.subarray(0, 0))

  // Clamp BOTH indices into [0, lastBeat]. Previously a startBeat past the last beat read
  // beats[startBeat] === undefined; the resulting NaN frame index is converted to 0 by
  // subarray(), so a slice starting at the beginning of the track was returned instead of an
  // empty one.
  const first = Math.min(Math.max(0, startBeat), lastBeat)
  const last = Math.min(Math.max(0, endBeat), lastBeat)

  const toFrame = (seconds: number) =>
    Math.floor(Math.floor(SAMPLE_RATE * seconds) / WINDOW_SAMPLES)
  return new Vector(onsets.subarray(toFrame(beats[first]), toFrame(beats[last])))
}

/**
 * Banded dynamic-time-warping distance between two tracks' onset functions.
 *
 * Each onset window is normalised by its mean. The longer window (`odf1`) is then aligned
 * against the shorter one (`odf2`), exploring only a band of BEAM_WIDTH cells around the
 * straight diagonal. The cost of a cell is the absolute difference of the two normalised
 * values.
 *
 * @param a first track (must have `onsets`)
 * @param b second track (must have `onsets`)
 * @param startBeatA first beat of the window in `a`
 * @param endBeatA end beat of the window in `a`
 * @param startBeatB first beat of the window in `b`
 * @param endBeatB end beat of the window in `b`
 * @returns the accumulated cost on the diagonal cell after the last row (lower = more similar),
 *   or NaN if either window is empty
 * @throws Error if either track has no onsets
 */
export function OnsetSimilarity(
  a: Track,
  b: Track,
  startBeatA: number,
  endBeatA: number,
  startBeatB: number,
  endBeatB: number
) {
  if (!a.onsets) throw new Error(`Missing onsets for "${a.name}"`)
  if (!b.onsets) throw new Error(`Missing onsets for "${b.name}"`)

  // Half-width of the band around the diagonal, and the resulting band width
  const BEAM_RADIUS = 2
  const BEAM_WIDTH = 2 * BEAM_RADIUS + 1

  const odfA = Onsets(a.onsets, a.metadata.beats, startBeatA, endBeatA)
  const odfB = Onsets(b.onsets, b.metadata.beats, startBeatB, endBeatB)
  if (!odfA.length || !odfB.length) return NaN
  // odf1 is the longer (or equally long) sequence, so the diagonal slope below is <= 1
  let odf1 = odfA.length >= odfB.length ? odfA : odfB
  let odf2 = odfA.length >= odfB.length ? odfB : odfA

  // Scale the onset windows so they are comparable
  const invMean1 = 1 / (odf1.mean() + 1e-6)
  const invMean2 = 1 / (odf2.mean() + 1e-6)
  odf1 = odf1.mul(invMean1)
  odf2 = odf2.mul(invMean2)

  // scores[i] / prevScores[i]: accumulated cost of the band cell at offset (i - BEAM_RADIUS)
  // from the band centre, in the current / previous row
  const scores = new Array<number>(BEAM_WIDTH).fill(0)
  const prevScores = new Array<number>(BEAM_WIDTH).fill(0)
  const slope = odf2.length / odf1.length
  let prevI2Center = 0

  for (let i1 = 0; i1 < odf1.length; i1++) {
    const i2Center = Math.floor(i1 * slope + 0.5) // Centre of the band around the diagonal
    for (let i = 0; i < BEAM_WIDTH; i++) {
      const i2 = i2Center - BEAM_RADIUS + i
      // Before the start of odf2: skip this cell only. Cells further right in the band may
      // still be in range. The previous `break` dropped whole rows near the start.
      if (i2 < 0) continue
      // Past the end of odf2: every remaining cell in the band is out of range too
      if (i2 >= odf2.length) break
      const scoreIncrement = Math.abs(odf1.vec[i1] - odf2.vec[i2])

      let scoreNew = prevScores[i]
      if (prevI2Center === i2Center) {
        // Band did not move: check up and diagonal
        if (i > 0) scoreNew = Math.min(scoreNew, scores[i - 1], prevScores[i - 1])
      } else {
        // Band moved by one: the previous row's same-i2 cell is one slot to the right
        if (i < BEAM_WIDTH - 1) scoreNew = Math.min(scoreNew, prevScores[i + 1])
        // Check up
        if (i > 0) scoreNew = Math.min(scoreNew, scores[i - 1])
      }

      scores[i] = scoreNew + scoreIncrement
    }

    prevI2Center = i2Center
    for (let j = 0; j < BEAM_WIDTH; j++) prevScores[j] = scores[j]
  }

  // Best score is on the diagonal
  return scores[BEAM_RADIUS]
}

/**
 * First segment of `track` (in segment-list order) that has the given type and starts strictly
 * after `bar`.
 *
 * @returns the matching segment, or undefined if there is none
 */
export function SegmentAfter(track: Track, bar: number, type: SegmentType) {
  return track.metadata.segments.find((seg) => seg.type === type && seg.bar > bar)
}

/**
 * Semitone shift (-1, 0 or +1) to apply to `nextKey` so that it becomes related to `curKey`.
 *
 * No shift is needed when `nextKey` is already closely or vaguely related to `curKey`. Otherwise
 * a shift of +1 and then -1 is tried, first against the closely related keys and then against
 * the vaguely related ones. If nothing matches, a warning is logged and 0 is returned.
 */
export function GetPitchShift(curKey: Key, nextKey: Key) {
  const related = new Set(GetCloselyRelatedKeys(curKey))
  if (related.has(nextKey)) return 0

  const looselyRelated = new Set(GetVaguelyRelatedKeys(curKey))
  if (looselyRelated.has(nextKey)) return 0

  // Closely related keys take priority over vaguely related ones; within each set, +1 before -1
  for (const keySet of [related, looselyRelated]) {
    if (keySet.has(GetKeyTransposed(nextKey, 1))) return 1
    if (keySet.has(GetKeyTransposed(nextKey, -1))) return -1
  }

  log.warn(`Cannot pitch shift from ${curKey} to ${nextKey}`)
  return 0
}

/**
 * Pick one element of `array` at random, with linearly decreasing or increasing weights.
 *
 * With `weightEarly`, element i (0-based) of n has weight (n - i); otherwise it has weight
 * (i + 1). Weights are normalised by n(n + 1) / 2. A single draw from `rng.next()` is compared
 * with the running total of normalised weights. If the total never reaches the draw, the last
 * element is returned, which is `undefined` for an empty array.
 */
export function GetWeightedRandomEntry(array: any[], rng: PRNG, weightEarly: boolean) {
  const n = array.length
  const normalizer = (n * (n + 1)) / 2
  const draw = rng.next()

  let cumulative = 0
  let idx = 0
  while (idx < n) {
    const rank = weightEarly ? n - idx : idx + 1
    cumulative += rank / normalizer
    if (cumulative >= draw) return array[idx]
    idx++
  }
  return array[n - 1]
}

/**
 * Automatic DJ that chains library tracks together with crossfades.
 *
 * Whenever a crossfade becomes current (setMix), the DJ builds a best-first list of candidate
 * crossfades from the incoming track into unplayed tracks. nextMix() then makes the candidate
 * at `candidateIndex` current.
 */
export class DJ {
  rng: PRNG
  library: Library
  /** Tracks not yet played in the current pass through the library. */
  unplayed: Set<Track>
  /** Whether candidate transitions are additionally scored by onset (ODF) similarity. */
  odfScoring: boolean
  /** The current crossfade. */
  mix?: Crossfade
  /** Candidate crossfades out of the current track, best first. */
  mixCandidates: Crossfade[] = []
  /** Index of the entry in `mixCandidates` that nextMix() will use. */
  candidateIndex = 0
  /** Semitone offset of the current crossfade; the current key is transposed by it when matching keys. */
  semitoneOffset = 0
  /** Session-wide theme anchor; defaults to the theme of the first track that becomes current. */
  themeCentroid?: Vector
  /** Theme of the track played before the current one (the current track's, if there was none). */
  prevTheme = new Vector([0, 0, 0])

  constructor(rng: PRNG, library: Library, options: { odfScoring?: boolean, themeCentroid?: Vector } = {}) {
    this.rng = rng
    this.library = library
    this.unplayed = new Set<Track>([...this.library.tracks])
    this.odfScoring = options.odfScoring ?? false
    this.themeCentroid = options.themeCentroid
  }

  /**
   * Make `mix` the current crossfade and precompute candidate crossfades out of its track.
   *
   * @throws Error if no candidate crossfades could be created
   */
  setMix(mix: Crossfade) {
    const track = mix.transition.track

    // Remember the theme of the track being left. Previously this value was immediately
    // overwritten with the incoming track's theme, so the "previous song" term never saw the
    // previous song. On the very first mix there is no previous track, so the incoming track
    // stands in for it.
    const prevTrack = this.mix?.transition.track ?? track
    this.prevTheme = new Vector(prevTrack.metadata.themeDescriptor)

    this.mix = mix
    this.semitoneOffset = mix.transition.semitoneOffset
    this.unplayed.delete(track)
    if (this.unplayed.size === 0) {
      // Every track has been played. Start a new pass (excluding the track now playing) BEFORE
      // building candidates; otherwise no candidates can exist and the check below throws.
      log.info('Replenishing unplayed track pool')
      this.unplayed = new Set<Track>([...this.library.tracks])
      this.unplayed.delete(track)
    }
    if (!this.themeCentroid) {
      this.themeCentroid = new Vector(track.metadata.themeDescriptor)
    }

    this.mixCandidates = this.createMixCandidates()
    this.candidateIndex = 0
    if (!this.mixCandidates.length) throw new Error(`No mix candidates created`)
  }

  /**
   * Start a session: reset the unplayed pool to the whole library, pick a random first track and
   * make a Chill crossfade into its bar 0 the current mix.
   *
   * @throws Error if the library is not loaded or is empty
   */
  firstMix(): Crossfade {
    if (!this.library.loaded) throw new Error(`firstMix() called before library is loaded`)

    log.info('Initializing unplayed track pool')
    this.unplayed = new Set<Track>([...this.library.tracks])
    if (!this.unplayed.size) throw new Error(`Cannot mix, library is empty`)

    const firstTrack = this._popRandom()
    if (!firstTrack) throw new Error(`Cannot select first track, library is empty`)
    const mix = Crossfade.Create({
      track: firstTrack,
      primaryBar: 0,
      secondaryBar: 0,
      fadeInBars: 0,
      fadeOutBars: 4,
      fadeType: MixType.Chill,
      semitoneOffset: 0,
      vocalClash: false,
      score: new ScoreVector(),
    })

    this.setMix(mix)
    return mix
  }

  /**
   * Make the candidate at `candidateIndex` the current crossfade and return it. An out-of-range
   * index is logged and reset to 0.
   *
   * @throws Error if the library is not loaded, the pool cannot be refilled, or there are no
   *   candidates
   */
  nextMix(): Crossfade {
    if (!this.library.loaded) throw new Error(`nextMix() called before library is loaded`)

    // setMix() refills the pool itself; this guards against the pool being emptied elsewhere
    if (this.unplayed.size === 0) {
      log.info('Replenishing unplayed track pool')
      this.unplayed = new Set<Track>([...this.library.tracks])
      if (this.mix) this.unplayed.delete(this.mix.transition.track)
    }
    if (!this.unplayed.size) throw new Error(`Cannot mix, library is empty`)

    if (!this.mixCandidates.length) throw new Error(`Cannot mix, no candidate mixes`)
    if (this.candidateIndex >= this.mixCandidates.length) {
      log.error(`candidateIndex ${this.candidateIndex} out of range (${this.mixCandidates.length})`)
      this.candidateIndex = 0
    }

    const mix = this.mixCandidates[this.candidateIndex]
    this.setMix(mix)
    return mix
  }

  /**
   * Cue points where `track` can be mixed out with the given mix type.
   *
   * Only candidates whose whole window [bar - fadeInBars, bar + fadeOutBars] lies within
   * [startBar, last downbeat] are returned.
   *
   * @throws Error for an unknown mix type
   */
  candidateTransitionsPrimary(
    track: Track,
    startBar: number,
    fadeType: MixType
  ): TransitionPrimary[] {
    const allCues = track.metadata.segments
    const lastBar = track.metadata.downbeats.length - 1

    const fits = (c: { bar: number; fadeInBars: number; fadeOutBars: number }) =>
      c.bar - c.fadeInBars >= startBar && c.bar + c.fadeOutBars <= lastBar
    // Segments of type `to` whose predecessor has type `from`
    const boundaries = (from: SegmentType, to: SegmentType) =>
      allCues.filter((c, i) => i > 0 && allCues[i - 1].type === from && c.type === to)

    if (fadeType === MixType.Chill) {
      // High -> Low transitions; fade out at most until the next High segment (or the end)
      const candidates = boundaries(SegmentType.High, SegmentType.Low)
        .map((cue) => {
          const bar = cue.bar
          const stopBar = SegmentAfter(track, bar, SegmentType.High)?.bar ?? lastBar
          return {
            bar,
            fadeInBars: Math.min(LENGTH_CHILL_IN, bar),
            fadeOutBars: Math.min(LENGTH_CHILL_OUT, stopBar - bar),
            fadeType,
          }
        })
        .filter(fits)
      // Keep only the longest chill fades. This runs AFTER the range filter, so long fades
      // outside [startBar, lastBar] cannot eliminate every usable cue.
      const maxBars = candidates.reduce((m, c) => Math.max(m, c.fadeInBars + c.fadeOutBars), 0)
      return candidates.filter((c) => c.fadeInBars + c.fadeOutBars === maxBars)
    }

    if (fadeType === MixType.DoubleDrop) {
      // Low -> High transitions with room for the full double-drop fade out
      return boundaries(SegmentType.Low, SegmentType.High)
        .filter((cue) => cue.bar <= lastBar - LENGTH_DOUBLE_DROP_OUT)
        .map((cue) => ({
          bar: cue.bar,
          fadeInBars: Math.min(LENGTH_DOUBLE_DROP_IN, cue.bar),
          fadeOutBars: LENGTH_DOUBLE_DROP_OUT,
          fadeType,
        }))
        .filter(fits)
    }

    if (fadeType === MixType.Rolling) {
      // High -> High -> Low: cue at the second High segment, which must last at least
      // LENGTH_ROLLING_OUT bars before the Low segment
      return allCues
        .filter(
          (c, i) =>
            i >= 1 &&
            i < allCues.length - 1 &&
            allCues[i - 1].type === SegmentType.High &&
            c.type === SegmentType.High &&
            allCues[i + 1].type === SegmentType.Low &&
            allCues[i + 1].bar - c.bar >= LENGTH_ROLLING_OUT
        )
        .map((cue) => ({
          bar: cue.bar,
          fadeInBars: Math.min(LENGTH_ROLLING_IN, cue.bar),
          fadeOutBars: Math.min(LENGTH_ROLLING_OUT, lastBar - cue.bar),
          fadeType,
        }))
        .filter(fits)
    }

    throw new Error(`Unknown fadeType ${fadeType}`)
  }

  /**
   * Cue points where `track` can be mixed in with the given mix type.
   *
   * Chill always enters LENGTH_CHILL_IN bars after the first segment. The other types enter at
   * Low -> High boundaries that lie more than `minBars` bars before the last segment.
   *
   * @throws Error for an unknown mix type
   */
  candidateTransitionsSecondary(
    track: Track,
    fadeType: MixType,
    minBars: number
  ): TransitionSecondary[] {
    const allCues = track.metadata.segments
    // A track without segment metadata offers no cue points
    if (!allCues.length) return []
    const firstCue = allCues[0]
    const lastCue = allCues[allCues.length - 1]

    if (fadeType === MixType.Chill) {
      // Chill always transitions to the beginning of the song
      return [{ bar: firstCue.bar + LENGTH_CHILL_IN, fadeInBars: LENGTH_CHILL_IN }]
    }

    let maxFadeIn: number
    if (fadeType === MixType.DoubleDrop) maxFadeIn = LENGTH_DOUBLE_DROP_IN
    else if (fadeType === MixType.Rolling) maxFadeIn = LENGTH_ROLLING_IN
    else throw new Error(`Unknown fadeType ${fadeType}`)

    // Low -> High transitions
    return allCues
      .filter(
        (c, i) =>
          i > 0 &&
          allCues[i - 1].type === SegmentType.Low &&
          c.type === SegmentType.High &&
          lastCue.bar - c.bar > minBars
      )
      .map((cue) => ({ bar: cue.bar, fadeInBars: Math.min(cue.bar, maxFadeIn) }))
  }

  /**
   * Score every way of mixing out of `curTrack` at `cueOut` into an unplayed track.
   *
   * Tracks in the (transposed) current key get Key score 1. If there are fewer than five of
   * them, nearby-key tracks are added with 0.75. If there are still none, every unplayed track
   * is used with 0. Each (track, cue-in) pair is then scored on theme distance, mix-type
   * probability, onset similarity (optional), vocal overlap and how early the cue-in lies.
   */
  bestTransitions(curTrack: Track, cueOut: TransitionPrimary): Transition[] {
    type ScoredTrack = [ScoreVector, Track]
    type ScoredCueIn = [ScoreVector, Track, TransitionSecondary]

    const curKey = curTrack.metadata.key
    const inTotalBars = cueOut.fadeInBars + cueOut.fadeOutBars

    const curFadeType = this.mix?.transition.fadeType ?? MixType.Chill
    const mixTypeScore = MixTypeScore(curFadeType, cueOut.fadeType)

    // Select unplayed tracks that are compatible with our target key
    const key = GetKeyTransposed(curKey, this.semitoneOffset)
    const withKeyScore = (tracks: readonly Track[], keyScore: number): ScoredTrack[] =>
      tracks
        .filter((t) => this.unplayed.has(t))
        .map((t): ScoredTrack => [new ScoreVector().set(ScoreField.Key, keyScore), t])
    let scoredTracks = withKeyScore(this.library.findTracksInKey(key), 1)
    if (scoredTracks.length < 5) {
      scoredTracks = scoredTracks.concat(withKeyScore(this.library.findTracksNearKey(key), 0.75))
    }
    if (!scoredTracks.length) {
      scoredTracks = withKeyScore([...this.unplayed], 0)
    }

    // Theme distance scoring
    const curTheme = new Vector(curTrack.metadata.themeDescriptor)
    const x0 = (this.themeCentroid ?? curTheme).mul(THEME_WEIGHT) // Session theme
    const x1 = curTheme.mul(CURRENT_SONG_WEIGHT) // Current song theme
    const x2 = this.prevTheme.mul(PREV_SONG_WEIGHT) // Previous song theme
    const curCentroid = x0.add(x1).add(x2)

    // Expand each track into one entry per cue-in point, each with its own copy of the score
    const scoredCues: ScoredCueIn[] = []
    const minBars = inTotalBars + 16
    for (const [score, t] of scoredTracks) {
      score.set(ScoreField.Theme, curCentroid.dist(new Vector(t.metadata.themeDescriptor)))
      for (const cueIn of this.candidateTransitionsSecondary(t, cueOut.fadeType, minBars)) {
        scoredCues.push([new ScoreVector(score), t, cueIn])
      }
    }

    return scoredCues.map(([score, t, cueIn]): Transition => {
      // The fade-in can only be as long as both cues allow. This same value is used for
      // scoring AND returned, instead of returning cueIn.fadeInBars unclamped.
      const fadeInBars = Math.min(cueOut.fadeInBars, cueIn.fadeInBars)
      const totalBars = fadeInBars + cueOut.fadeOutBars

      // Check for overlapping vocals
      const vocalClash = IsVocalClash(
        curTrack.metadata.vocals,
        t.metadata.vocals,
        cueOut.bar - fadeInBars,
        cueIn.bar - fadeInBars,
        totalBars
      )

      // ODF scoring
      const odfScore = this.odfScoring ? ODFScore(curTrack, t, cueOut, cueIn) : 0

      // Prefer mixes that transition into the next track earlier, so there are
      // more options in the future
      const invSecondaryProgress = 1 - cueIn.bar / t.metadata.downbeats.length

      score.set(ScoreField.MixType, mixTypeScore)
      score.set(ScoreField.Onsets, odfScore)
      score.set(ScoreField.Vocals, vocalClash ? 0 : 1)
      score.set(ScoreField.NextSegments, invSecondaryProgress)

      return {
        track: t,
        prevTrack: curTrack,
        primaryBar: cueOut.bar,
        secondaryBar: cueIn.bar,
        fadeInBars,
        fadeOutBars: cueOut.fadeOutBars,
        semitoneOffset: this.semitoneOffset,
        fadeType: cueOut.fadeType,
        // Previously derived as `Vocals score >= 0.5`, which is true exactly when there is
        // NO clash; use the computed flag directly.
        vocalClash,
        score,
      }
    })
  }

  /**
   * Build the best-first list of crossfades out of the current track. At most
   * MAX_CANDIDATES_PER_TRACK are kept per target track.
   */
  private createMixCandidates(): Crossfade[] {
    if (!this.mix) return []

    // Stay in the current track for at least this many bars after its secondaryBar
    const MIN_BARS_PLAYED = 48
    // Cue-outs start on phrase boundaries of this many bars, counted from the first segment
    const PHRASE_BARS = 8

    const curTransition = this.mix.transition
    const curTrack = curTransition.track
    const phraseOrigin = curTrack.metadata.segments[0]?.bar ?? 0
    const minBar = curTransition.secondaryBar + MIN_BARS_PLAYED
    // Round up to the next phrase boundary
    const earliestBar =
      minBar + ((PHRASE_BARS - ((minBar - phraseOrigin) % PHRASE_BARS)) % PHRASE_BARS)

    // Get candidate cue out points and mix types in the current track
    const cuesOut = [
      ...this.candidateTransitionsPrimary(curTrack, earliestBar, MixType.DoubleDrop),
      ...this.candidateTransitionsPrimary(curTrack, earliestBar, MixType.Rolling),
      ...this.candidateTransitionsPrimary(curTrack, earliestBar, MixType.Chill),
    ]

    // Group candidate transitions by target track
    const byTrack = new Map<Track, Transition[]>()
    for (const cueOut of cuesOut) {
      for (const transition of this.bestTransitions(curTrack, cueOut)) {
        let group = byTrack.get(transition.track)
        if (!group) {
          group = []
          byTrack.set(transition.track, group)
        }
        group.push(transition)
      }
    }

    // Keep the top MAX_CANDIDATES_PER_TRACK transitions per track, then rank them all
    const byScoreDesc = (a: Transition, b: Transition) => b.score.compare(a.score)
    const transitions: Transition[] = []
    for (const group of byTrack.values()) {
      group.sort(byScoreDesc)
      transitions.push(...group.slice(0, MAX_CANDIDATES_PER_TRACK))
    }
    transitions.sort(byScoreDesc)

    // Generate a crossfade for every candidate transition
    return transitions.map((t) => Crossfade.Create(t))
  }

  /** Remove and return a random track from the unplayed pool. */
  private _popRandom() {
    const track = this.rng.nextArrayItem(Array.from(this.unplayed.values()))
    this.unplayed.delete(track)
    return track
  }
}
