use std::fs::File;
use std::io::BufReader;
use crate::io::byteio::*;
use crate::io::bitreader::*;
use super::super::*;

/// Number of entries in the LZW dictionary tables (`dict_sym` and `dict_prev`).
const DICT_SIZE: usize = 0x88CF;
/// NOTE(review): presumably the largest LZW code width in bits; confirm against the format.
const MAX_BITS:   u8 = 16;
/// NOTE(review): looks like a sentinel dictionary index; it is not referenced in the
/// visible part of this file, so confirm whether it is still needed.
const INVALID_POS: usize = 65536;

/// State of the LZW unpacker: the dictionary tables plus the current code width.
struct LZWState {
    /// Last byte of each dictionary entry.
    dict_sym:   [u8; DICT_SIZE],
    /// Index of the prefix entry (or literal) each dictionary entry extends.
    dict_prev:  [u16; DICT_SIZE],
    /// Index at which the next dictionary entry will be stored.
    dict_pos:   usize,
    /// Entries are only added while `dict_pos` is below this limit; it is doubled
    /// whenever the code width grows.
    dict_lim:   usize,
    /// Number of literal symbols; indices below this value denote plain bytes.
    nsyms:      usize,
    /// Current width, in bits, of the codes read from the bitstream.
    idx_bits:   u8,
}

impl LZWState {
    /// Creates a zeroed state. `reset()` has to be called before any decoding
    /// (`unpack()` does that itself).
    fn new() -> Self {
        Self {
            dict_sym:   [0; DICT_SIZE],
            dict_prev:  [0; DICT_SIZE],
            dict_pos:   0,
            dict_lim:   0,
            idx_bits:   0,
            nsyms:      0,
        }
    }
    /// Restarts the dictionary for `bits`-bit literals.
    ///
    /// The three indices right after the literals are reserved for control codes,
    /// so the first free entry is `nsyms + 3`; codes start one bit wider than literals.
    fn reset(&mut self, bits: u8) {
        self.nsyms    = 1 << bits;
        self.dict_pos = self.nsyms + 3;
        self.dict_lim = 1 << (bits + 1);
        self.idx_bits = bits + 1;
    }
    /// Appends the entry "`prev` followed by `sym`" to the dictionary.
    ///
    /// Nothing is stored when the dictionary is full, either for the current code
    /// width (`dict_lim`) or for the backing tables (`DICT_SIZE`).
    fn add(&mut self, prev: usize, sym: u8) {
        // dict_lim may legitimately exceed DICT_SIZE (e.g. 65536 for 16-bit codes),
        // so the table size has to be checked too or the writes below would panic.
        if self.dict_pos < self.dict_lim.min(DICT_SIZE) {
            self.dict_sym [self.dict_pos] = sym;
            self.dict_prev[self.dict_pos] = prev as u16;
            self.dict_pos += 1;
        }
    }
    /// Expands dictionary entry `idx` into `dst` starting at offset `pos`.
    ///
    /// `dst` is grown by the length of the expanded string (callers pass
    /// `pos == dst.len()`). `idx` must be below `DICT_SIZE`. Returns the number
    /// of bytes produced.
    fn decode_idx(&self, dst: &mut Vec<u8>, pos: usize, idx: usize) -> DecoderResult<usize> {
        // First pass: walk the prefix chain just to learn the string length.
        let mut tot_len = 1;
        let mut tidx = idx;
        while tidx >= self.nsyms {
            tidx = self.dict_prev[tidx] as usize;
            tot_len += 1;
        }
        dst.resize(dst.len() + tot_len, 0);

        // Second pass: the chain yields the string back to front, so fill from the end.
        let mut end = pos + tot_len - 1;
        let mut tidx = idx;
        while tidx >= self.nsyms {
            dst[end] = self.dict_sym[tidx];
            end -= 1;
            tidx = self.dict_prev[tidx] as usize;
        }
        // The chain ends in a literal, which is the first byte of the string.
        dst[end] = tidx as u8;

        Ok(tot_len)
    }
    /// Converts a raw bitstream code into the internal index space.
    ///
    /// Raw codes 0, 1 and 2 are the control codes (mapped to 0x101, 0x102 and 0x100),
    /// raw codes 3..=0x102 are the literals 0..=0xFF, anything larger is a
    /// dictionary index used as is.
    fn remap_idx(idx: u32) -> usize {
        match idx {
            0 => 0x101,
            1 => 0x102,
            2 => 0x100,
            3..=0x102 => (idx - 3) as usize,
            _ => idx as usize,
        }
    }
    /// Unpacks the LZW stream `src` into `dst` (which is cleared first).
    ///
    /// Decoding stops successfully at the end code (0x100) or when the input runs
    /// out in the middle of a run. An error is returned for input shorter than two
    /// bytes, an invalid first code of a run, a code referring to a non-existent
    /// dictionary entry, or an attempt to grow codes beyond `MAX_BITS`.
    fn unpack(&mut self, src: &[u8], dst: &mut Vec<u8>) -> DecoderResult<()> {
        validate!(src.len() > 1);
        let mut br = BitReader::new(src, BitReaderMode::BE);

        dst.clear();

        self.reset(8);
        'restart: loop {
            // The first code after a (re)start has to be a literal or the end code.
            let mut lastidx = Self::remap_idx(br.read(self.idx_bits)
                .map_err(|_| DecoderError::InvalidData)?);
            if lastidx == 0x100 {
                return Ok(());
            }
            validate!(lastidx < 0x100);
            dst.push(lastidx as u8);
            loop {
                // Running out of input here is treated as a normal end of data.
                let idx = match br.read(self.idx_bits) {
                    Ok(code) => Self::remap_idx(code),
                    Err(_) => return Ok(()),
                };
                match idx {
                    // end of data
                    0x100 => return Ok(()),
                    // grow the code width by one bit
                    0x101 => {
                        // Without this bound a crafted stream could widen codes
                        // (and dict_lim) indefinitely.
                        validate!(self.idx_bits < MAX_BITS);
                        self.dict_lim <<= 1;
                        self.idx_bits += 1;
                        br.align();
                        continue;
                    },
                    // drop the dictionary and start over
                    0x102 => {
                        self.reset(8);
                        continue 'restart;
                    },
                    _ => {},
                }
                // A code may refer to an existing entry or to the one about to be
                // created, but never to something outside the tables.
                validate!(idx <= self.dict_pos && idx < DICT_SIZE);
                let pos = dst.len();
                if idx != self.dict_pos {
                    self.decode_idx(dst, pos, idx)?;
                    self.add(lastidx, dst[pos]);
                } else {
                    // The code refers to the entry being defined right now:
                    // it is the previous string plus its own first byte.
                    self.decode_idx(dst, pos, lastidx)?;
                    let lastsym = dst[pos];
                    dst.push(lastsym);
                    self.add(lastidx, lastsym);
                }
                lastidx = idx;
            }
        }
    }
}

/// One entry of the frame table stored in the file header.
struct FrameRecord {
    /// Chunk type, validated to be in 1..=5 on read (1 - intra video, 2 - inter video,
    /// 3 and 4 - audio, 5 - palette).
    ftype:      u32,
    /// Chunk size in bytes (never zero).
    size:       usize,
    /// Chunk position relative to the start of the frame data; zero means the
    /// position is not checked during decoding.
    offset:     u32,
}

impl FrameRecord {
    /// Parses a single frame table entry made of four little-endian 32-bit words:
    /// chunk type, an ignored word, chunk size and chunk offset.
    ///
    /// Fails on read errors, on a chunk type outside 1..=5 and on a zero size.
    fn read(br: &mut dyn ByteIO) -> DecoderResult<Self> {
        let ftype = br.read_u32le()?;
        validate!(matches!(ftype, 1..=5));
        // the second word is not used by the decoder
        let _ignored = br.read_u32le()?;
        let size = br.read_u32le()? as usize;
        validate!(size != 0);
        let offset = br.read_u32le()?;
        let record = Self { ftype, size, offset };
        Ok(record)
    }
}

/// Demuxer and decoder state for one opened MVI file.
struct MVIDecoder {
    /// Input file reader.
    fr:         FileReader<BufReader<File>>,
    /// File position where the frame data starts (right after the aligned header).
    base:       u64,
    /// Picture width in pixels.
    width:      usize,
    /// Picture height in pixels.
    height:     usize,
    /// Frames per second (used as the video timebase denominator).
    fps:        u32,
    /// Current paletted picture, `width * height` bytes.
    frame:      Vec<u8>,
    /// Current palette, 768 bytes.
    pal:        [u8; 768],
    /// Audio sample rate, zero when there is no audio.
    arate:      u32,
    /// Number of audio channels.
    channels:   u8,
    /// Bits per audio sample.
    abits:      u8,
    /// Frame table read from the header.
    frm_rec:    Vec<FrameRecord>,
    /// Index of the next frame table entry to process.
    cur_frm:    usize,
    /// Scratch buffer holding the raw contents of the current video chunk.
    data:       Vec<u8>,
    /// Scratch buffer for the pixel data of the current video chunk.
    pixels:     Vec<u8>,
    /// Scratch buffer for the update mask of an inter frame.
    mask:       Vec<u8>,
    /// LZW unpacker state.
    lzw:        LZWState,
    /// Audio from a type 3 chunk of a two-channel file, kept until the matching
    /// type 4 chunk arrives.
    left:       Option<Vec<u8>>,
}

/// Patches `dst` with bytes taken sequentially from `pixels`.
///
/// Every byte of `mask_bytes` controls one group of up to eight consecutive
/// bytes of `dst`, most significant bit first: a set bit replaces the
/// corresponding byte with the next unread byte of `pixels`, a clear bit leaves
/// it untouched. Processing ends when either `dst` or `mask_bytes` is exhausted;
/// an error is returned if `pixels` runs out first.
fn update_masked(dst: &mut [u8], mask_bytes: &[u8], pixels: &[u8]) -> DecoderResult<()> {
    let mut src = MemoryReader::new_read(pixels);
    let groups = dst.chunks_mut(8);
    for (group, &flags) in groups.zip(mask_bytes) {
        for (bit, out) in group.iter_mut().enumerate() {
            // bit 7 belongs to the first byte of the group, bit 6 to the second, ...
            let selected = flags & (0x80u8 >> bit) != 0;
            if selected {
                *out = src.read_byte()?;
            }
        }
    }
    Ok(())
}

impl InputSource for MVIDecoder {
    /// Reports one stream (video) or two (video and audio) depending on whether
    /// the file declares an audio sample rate.
    fn get_num_streams(&self) -> usize { if self.arate > 0 { 2 } else { 1 } }
    /// Describes stream 0 (8-bit paletted video) and, when audio is present,
    /// stream 1 (audio); any other stream number yields `StreamInfo::None`.
    fn get_stream_info(&self, stream_no: usize) -> StreamInfo {
        match stream_no {
            0 => StreamInfo::Video(VideoInfo{
                    width:  self.width,
                    height: self.height,
                    bpp:    8,
                    tb_num: 1,
                    tb_den: self.fps,
                }),
            1 if self.arate > 0 => StreamInfo::Audio(AudioInfo{
                    sample_rate: self.arate,
                    sample_type: if self.abits > 8 { AudioSample::S16 } else { AudioSample::U8 },
                    channels:    self.channels,
                }),
            _ => StreamInfo::None
        }
    }
    /// Processes frame table entries until one of them produces output.
    ///
    /// Returns the stream number together with the decoded frame. Palette chunks
    /// and the first half of a stereo audio pair only update internal state.
    /// `DecoderError::EOF` is returned once the frame table is exhausted; malformed
    /// chunks result in an error instead of a panic.
    fn decode_frame(&mut self) -> DecoderResult<(usize, Frame)> {
        let br = &mut self.fr;

        loop {
            if self.cur_frm >= self.frm_rec.len() {
                return Err(DecoderError::EOF);
            }
            let frm_info = &self.frm_rec[self.cur_frm];
            self.cur_frm += 1;

            // Chunks are read sequentially, so a recorded offset must match
            // the current position.
            if frm_info.offset > 0 {
                validate!(br.tell() == self.base + u64::from(frm_info.offset));
            }

            match frm_info.ftype {
                1 => { // intra
                    validate!(frm_info.size > 8);
                    self.data.resize(frm_info.size, 0);
                    br.read_buf(&mut self.data)?;
                    // 8-byte header: packed size and unpacked size
                    let packed_size = read_u32le(&self.data).unwrap_or_default() as usize;
                    let unpacked_size = read_u32le(&self.data[4..]).unwrap_or_default() as usize;
                    validate!(self.data.len() > packed_size + 8);
                    self.lzw.unpack(&self.data[8..], &mut self.pixels)
                        .map_err(|_| DecoderError::InvalidData)?;
                    validate!(self.pixels.len() == unpacked_size);
                    validate!(self.pixels.len() <= self.frame.len());
                    self.frame[..self.pixels.len()].copy_from_slice(&self.pixels);
                    return Ok((0, Frame::VideoPal(self.frame.clone(), self.pal)));
                },
                2 => { // inter
                    validate!(frm_info.size >= 16);
                    self.data.resize(frm_info.size, 0);
                    br.read_buf(&mut self.data)?;

                    // 16-byte header: packed/unpacked sizes of the pixel data (part 1)
                    // and of the update mask (part 2)
                    let part1_csize = read_u32le(&self.data).unwrap_or_default() as usize;
                    let part1_usize = read_u32le(&self.data[4..]).unwrap_or_default() as usize;
                    let part2_csize = read_u32le(&self.data[8..]).unwrap_or_default() as usize;
                    let part2_usize = read_u32le(&self.data[12..]).unwrap_or_default() as usize;

                    // No pixel data means the frame repeats the previous picture.
                    if part1_usize > 0 {
                        if part1_csize < part1_usize {
                            validate!(self.data.len() > 24);
                            self.lzw.unpack(&self.data[24..], &mut self.pixels)
                                .map_err(|_| DecoderError::InvalidData)?;
                            validate!(self.pixels.len() == part1_usize);
                        } else {
                            // Stored data: the size comes from the file and has to fit
                            // into the chunk (data.len() >= 16 was checked above).
                            validate!(part1_csize <= self.data.len() - 16);
                            self.pixels.clear();
                            self.pixels.extend_from_slice(&self.data[16..][..part1_csize]);
                        }
                        if part2_csize < part2_usize {
                            validate!(part2_csize + part1_csize + 16 <= self.data.len());
                            validate!(part2_csize > 8);
                            self.lzw.unpack(&self.data[part1_csize + 24..], &mut self.mask)
                                .map_err(|_| DecoderError::InvalidData)?;
                            validate!(self.mask.len() == part2_usize);
                        } else {
                            // Stored mask: both parts have to fit into the chunk.
                            validate!(part1_csize <= self.data.len() - 16);
                            validate!(part2_csize <= self.data.len() - 16 - part1_csize);
                            self.mask.clear();
                            self.mask.extend_from_slice(&self.data[part1_csize + 16..][..part2_csize]);
                        }
                        update_masked(&mut self.frame, &self.mask, &self.pixels)
                            .map_err(|_| DecoderError::InvalidData)?;
                    }
                    return Ok((0, Frame::VideoPal(self.frame.clone(), self.pal)));
                },
                3 | 4 => { // audio
                    validate!(self.arate > 0);
                    validate!(frm_info.size == (self.arate as usize));
                    let mut abuf = vec![0; frm_info.size];
                    br.read_buf(&mut abuf)?;
                    // flip the top bit of every sample before outputting it as unsigned 8-bit
                    for el in abuf.iter_mut() {
                        *el ^= 0x80;
                    }
                    match (frm_info.ftype, self.channels) {
                        (3, 1) => return Ok((1, Frame::AudioU8(abuf))),
                        (3, _) => {
                            // first half of a stereo pair, keep it until the type 4 chunk
                            validate!(self.left.is_none());
                            self.left = Some(abuf);
                        },
                        (4, _) => {
                            validate!(self.channels == 2);
                            let left = self.left.take();
                            validate!(left.is_some());
                            let left = left.unwrap_or_default();
                            // interleave the stored chunk with the current one
                            let mut stereo = vec![0; left.len() * 2];
                            for (pair, (&l, &r)) in stereo.chunks_exact_mut(2)
                                    .zip(left.iter().zip(abuf.iter())) {
                                pair[0] = l;
                                pair[1] = r;
                            }
                            return Ok((1, Frame::AudioU8(stereo)));
                        },
                        _ => unreachable!(),
                    }
                },
                5 => { // palette (plus some other optional data)
                    validate!(frm_info.size > 0);
                    br.read_buf(&mut self.pal[..frm_info.size.min(768)])?;
                    if frm_info.size > 768 {
                        br.read_skip(frm_info.size - 768)?;
                    }
                },
                _ => {
                    println!("Unsupported chunk {}", frm_info.ftype);
                    br.read_skip(frm_info.size)?;
                },
            }
        }
    }
}

pub fn open(name: &str) -> DecoderResult<Box<dyn InputSource>> {
    let file = File::open(name).map_err(|_| DecoderError::InputNotFound(name.to_owned()))?;
    let mut fr = FileReader::new_read(BufReader::new(file));

    let one = fr.read_u32le()?;
    validate!(one == 1);
    fr.read_u32le()?;
    let mut width = fr.read_u32le()? as usize;
    let height = fr.read_u32le()? as usize;
    validate!((1..=1024).contains(&width) && (1..=1024).contains(&height));
    // no idea why
    if width == 640 { width = 320; }

    let nframes = fr.read_u32le()? as usize;
    validate!((1..=0x2000).contains(&nframes));
    fr.read_u32le()?;
    fr.read_u32le()?; // always 0x3D?
    let fps = fr.read_u32le()?;
    validate!((1..=30).contains(&fps));

    let has_audio = fr.read_u32le()?;
    let channels = fr.read_u32le()?;
    let arate = fr.read_u32le()?;
    validate!(channels <= 2 && (arate == 0 || (8000..=32000).contains(&arate)));
    let abits = fr.read_u32le()?;
    validate!(abits == 0 || abits == 8);
    if has_audio > 0 {
        validate!(arate > 0 && channels > 0 && abits > 0);
    } else {
        validate!(arate == 0 && channels == 0 && abits == 0);
    }

    let mut frm_rec = Vec::with_capacity(nframes);
    for _ in 0..nframes {
        let rec = FrameRecord::read(&mut fr)?;
        frm_rec.push(rec);
    }

    let pos = fr.tell();
    fr.seek(SeekFrom::Start((pos + 0x7FF) & !0x7FF))?;

    let base = fr.tell();

    Ok(Box::new(MVIDecoder {
        fr,
        width, height, fps,
        pal: [0; 768],
        frame: vec![0; width * height],
        frm_rec,
        cur_frm: 0,
        data: Vec::new(),
        pixels: Vec::new(),
        mask: Vec::new(),
        lzw: LZWState::new(),
        base,
        arate,
        channels: channels as u8,
        abits: abits as u8,
        left: None,
    }))
}
