//! This crate provides types related to JPEG XL frames.
//!
//! A JPEG XL image contains one or more frames. A frame represents a single unit of image that
//! can be displayed or referenced by other frames.
//!
//! A frame consists of a few components:
//! - [Frame header][FrameHeader].
//! - [Table of contents (TOC)][data::Toc].
//! - Actual frame data, in the following order, potentially permuted as specified in the TOC:
//!   - one [`LfGlobal`],
//!   - [`num_lf_groups`] [`LfGroup`]s, in raster order,
//!   - one [`HfGlobal`], potentially empty for Modular frames, and
//!   - [`num_passes`] times [`num_groups`] [pass groups][data::decode_pass_group], in raster
//!     order.
//!
//! [`num_lf_groups`]: FrameHeader::num_lf_groups
//! [`num_groups`]: FrameHeader::num_groups
//! [`num_passes`]: header::Passes::num_passes
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use jxl_bitstream::{read_bits, Bitstream, Bundle};
use jxl_image::ImageHeader;

pub mod data;
mod error;
pub mod filter;
pub mod header;

pub use error::{Error, Result};
pub use header::FrameHeader;
use jxl_modular::{image::TransformedModularSubimage, MaConfig};
use jxl_threadpool::JxlThreadPool;

use crate::data::*;

/// JPEG XL frame.
///
/// A frame represents a single unit of image that can be displayed or referenced by other frames.
#[derive(Debug)]
pub struct Frame {
    /// Thread pool handed to the LfGroup and HfGlobal parsers.
    pool: JxlThreadPool,
    /// Header of the image this frame was parsed against.
    image_header: Arc<ImageHeader>,
    /// Parsed frame header.
    header: FrameHeader,
    /// Table of contents describing the frame's groups.
    toc: Toc,
    /// One byte buffer per TOC group, in bitstream order (`Toc::iter_bitstream_order`).
    data: Vec<GroupData>,
    /// Cached bit offsets of sections inside the only group of a single-entry TOC.
    all_group_offsets: AllGroupOffsets,
    /// Index into `data` of the group currently being filled by `feed_bytes`.
    reading_data_index: usize,
    /// Maps a pass index (each `passes.last_pass` entry, plus the final pass) to its
    /// `(minshift, maxshift)` pair.
    pass_shifts: BTreeMap<u32, (i32, i32)>,
}

/// Bit offsets of sections within the only group of a single-entry TOC, cached as they are
/// discovered. A value of `0` means "not computed yet". The values are atomics so that the
/// `&self` parsing methods can record them.
#[derive(Debug, Default)]
struct AllGroupOffsets {
    /// Where LfGroup begins; recorded after LfGlobal is parsed successfully.
    lf_group: AtomicUsize,
    /// Where HfGlobal begins; recorded after LfGroup is parsed.
    hf_global: AtomicUsize,
    /// Where pass group data begins; recorded after HfGlobal (or, for Modular frames, copied
    /// from the HfGlobal offset).
    pass_group: AtomicUsize,
}

/// Bytes of a single TOC group, accumulated as they are fed in.
#[derive(Debug)]
struct GroupData {
    /// TOC entry for this group; its `size` is the number of bytes expected.
    toc_group: TocGroup,
    /// Bytes received so far; the group is partial while `bytes.len() < toc_group.size`.
    bytes: Vec<u8>,
}

impl From<TocGroup> for GroupData {
    /// Creates an empty byte buffer for the group described by `value`, ready to be filled by
    /// `Frame::feed_bytes`.
    ///
    /// The group size is read from the frame's TOC, i.e. from file data, so only a bounded
    /// amount is reserved up front; the buffer grows as bytes actually arrive. This keeps a
    /// small file that merely *claims* huge groups from forcing huge allocations.
    fn from(value: TocGroup) -> Self {
        // Upper bound on the initial reservation per group (1 MiB).
        const MAX_INITIAL_CAPACITY: usize = 1 << 20;

        let cap = (value.size as usize).min(MAX_INITIAL_CAPACITY);
        Self {
            toc_group: value,
            bytes: Vec::with_capacity(cap),
        }
    }
}

/// Context required to parse a [`Frame`].
#[derive(Debug, Clone)]
pub struct FrameContext {
    /// Header of the image the frame belongs to; validation of the frame header uses it.
    pub image_header: Arc<ImageHeader>,
    /// Thread pool stored in the parsed frame and passed on to group parsers.
    pub pool: JxlThreadPool,
}

impl Bundle<FrameContext> for Frame {
    type Error = crate::Error;

    fn parse(bitstream: &mut Bitstream, ctx: FrameContext) -> Result<Self> {
        let FrameContext { image_header, pool } = ctx;

        bitstream.zero_pad_to_byte()?;
        let base_offset = bitstream.num_read_bits() / 8;
        let header = read_bits!(bitstream, Bundle(FrameHeader), &image_header)?;

        let width = header.width as u64;
        let height = header.height as u64;
        if width > (1 << 30) {
            tracing::error!(width, "Frame width too large; limit is 2^30");
            return Err(jxl_bitstream::Error::ProfileConformance("frame width too large").into());
        }
        if height > (1 << 30) {
            tracing::error!(width, "Frame height too large; limit is 2^30");
            return Err(jxl_bitstream::Error::ProfileConformance("frame height too large").into());
        }
        if (width * height) > (1 << 40) {
            tracing::error!(
                area = width * height,
                "Frame area (width * height) too large; limit is 2^40"
            );
            return Err(jxl_bitstream::Error::ProfileConformance("frame area too large").into());
        }

        for blending_info in std::iter::once(&header.blending_info).chain(&header.ec_blending_info)
        {
            if blending_info.mode.use_alpha()
                && blending_info.alpha_channel as usize >= image_header.metadata.ec_info.len()
            {
                return Err(jxl_bitstream::Error::ValidationFailed(
                    "blending_info.alpha_channel out of range",
                )
                .into());
            }
        }
        if header.flags.use_lf_frame() && header.lf_level >= 4 {
            return Err(jxl_bitstream::Error::ValidationFailed("lf_level out of range").into());
        }

        for ec_info in &image_header.metadata.ec_info {
            if ec_info.dim_shift > 7 + header.group_size_shift {
                return Err(jxl_bitstream::Error::ValidationFailed("dim_shift too large").into());
            }
        }

        if header.upsampling > 1 {
            for (ec_upsampling, ec_info) in header
                .ec_upsampling
                .iter()
                .zip(image_header.metadata.ec_info.iter())
            {
                if (ec_upsampling << ec_info.dim_shift) < header.upsampling {
                    return Err(jxl_bitstream::Error::ValidationFailed(
                        "EC upsampling < color upsampling, which is invalid",
                    )
                    .into());
                }
            }
        }

        if header.width == 0 || header.height == 0 {
            return Err(jxl_bitstream::Error::ValidationFailed(
                "Invalid crop dimensions for frame: zero width or height",
            )
            .into());
        }

        let mut toc = read_bits!(bitstream, Bundle(Toc), &header)?;
        toc.adjust_offsets(base_offset);
        let data = toc.iter_bitstream_order().map(GroupData::from).collect();

        let passes = &header.passes;
        let mut pass_shifts = BTreeMap::new();
        let mut maxshift = 3i32;
        for (&downsample, &last_pass) in passes.downsample.iter().zip(&passes.last_pass) {
            let minshift = downsample.trailing_zeros() as i32;
            pass_shifts.insert(last_pass, (minshift, maxshift));
            maxshift = minshift;
        }
        pass_shifts.insert(header.passes.num_passes - 1, (0i32, maxshift));

        Ok(Self {
            pool,
            image_header,
            header,
            toc,
            data,
            all_group_offsets: AllGroupOffsets::default(),
            reading_data_index: 0,
            pass_shifts,
        })
    }
}

impl Frame {
    /// Returns the header of the image this frame was parsed against.
    pub fn image_header(&self) -> &ImageHeader {
        self.image_header.as_ref()
    }

    /// Returns a new shared handle to the image header.
    pub fn clone_image_header(&self) -> Arc<ImageHeader> {
        Arc::clone(&self.image_header)
    }

    /// Returns the parsed frame header.
    pub fn header(&self) -> &FrameHeader {
        &self.header
    }

    /// Returns the table of contents of this frame.
    ///
    /// See the documentation of [`Toc`] for details.
    pub fn toc(&self) -> &Toc {
        &self.toc
    }

    /// Returns the map from pass index to its `(minshift, maxshift)` pair.
    pub fn pass_shifts(&self) -> &BTreeMap<u32, (i32, i32)> {
        &self.pass_shifts
    }

    /// Returns the bytes loaded so far for `group`, or `None` if no buffer exists at the
    /// group's bitstream-order index.
    pub fn data(&self, group: TocGroupKind) -> Option<&[u8]> {
        let bitstream_idx = self.toc.group_index_bitstream_order(group);
        let group_data = self.data.get(bitstream_idx)?;
        Some(group_data.bytes.as_slice())
    }
}

impl Frame {
    /// Appends `buf` to the group buffers in bitstream order.
    ///
    /// Each group is filled up to the size recorded in its TOC entry before moving on to the
    /// next one. Returns the unconsumed tail of `buf`, which can only be non-empty once every
    /// group has been completely filled.
    pub fn feed_bytes<'buf>(&mut self, mut buf: &'buf [u8]) -> &'buf [u8] {
        loop {
            let Some(group_data) = self.data.get_mut(self.reading_data_index) else {
                return buf;
            };

            let remaining = group_data.toc_group.size as usize - group_data.bytes.len();
            if remaining > buf.len() {
                // The whole input fits in the current group without completing it.
                group_data.bytes.extend_from_slice(buf);
                return &[];
            }

            let (head, tail) = buf.split_at(remaining);
            group_data.bytes.extend_from_slice(head);
            buf = tail;
            self.reading_data_index += 1;
        }
    }

    /// Returns `true` once every group listed in the TOC has received all of its bytes.
    #[inline]
    pub fn is_loading_done(&self) -> bool {
        self.data.len() <= self.reading_data_index
    }
}

impl Frame {
    pub fn try_parse_lf_global(&self) -> Option<Result<LfGlobal>> {
        Some(if self.toc.is_single_entry() {
            let group = self.data.get(0)?;
            let mut bitstream = Bitstream::new(&group.bytes);
            let lf_global = LfGlobal::parse(
                &mut bitstream,
                LfGlobalParams::new(&self.image_header, &self.header, false),
            );
            if lf_global.is_ok() {
                tracing::trace!(num_read_bits = bitstream.num_read_bits(), "LfGlobal");
                self.all_group_offsets
                    .lf_group
                    .store(bitstream.num_read_bits(), Ordering::Relaxed);
            }
            lf_global
        } else {
            let idx = self.toc.group_index_bitstream_order(TocGroupKind::LfGlobal);
            let group = self.data.get(idx)?;
            let allow_partial = group.bytes.len() < group.toc_group.size as usize;

            let mut bitstream = Bitstream::new(&group.bytes);
            LfGlobal::parse(
                &mut bitstream,
                LfGlobalParams::new(&self.image_header, &self.header, allow_partial),
            )
        })
    }

    pub fn try_parse_lf_group(
        &self,
        lf_global_vardct: Option<&LfGlobalVarDct>,
        global_ma_config: Option<&MaConfig>,
        mlf_group: Option<TransformedModularSubimage>,
        lf_group_idx: u32,
    ) -> Option<Result<LfGroup>> {
        if self.toc.is_single_entry() {
            if lf_group_idx != 0 {
                return None;
            }

            let group = self.data.get(0)?;
            let mut bitstream = Bitstream::new(&group.bytes);
            let offset = self.all_group_offsets.lf_group.load(Ordering::Relaxed);
            if offset == 0 {
                let lf_global = self.try_parse_lf_global().unwrap();
                if let Err(e) = lf_global {
                    return Some(Err(e));
                }
            }
            let offset = self.all_group_offsets.lf_group.load(Ordering::Relaxed);
            bitstream.skip_bits(offset).unwrap();

            let allow_partial = group.bytes.len() < group.toc_group.size as usize;
            let result = LfGroup::parse(
                &mut bitstream,
                LfGroupParams {
                    frame_header: &self.header,
                    quantizer: lf_global_vardct.map(|x| &x.quantizer),
                    global_ma_config,
                    mlf_group,
                    lf_group_idx,
                    allow_partial,
                    pool: &self.pool,
                },
            );
            if allow_partial && result.is_err() {
                return None;
            }
            tracing::trace!(num_read_bits = bitstream.num_read_bits(), "LfGroup");
            self.all_group_offsets
                .hf_global
                .store(bitstream.num_read_bits(), Ordering::Relaxed);
            Some(result)
        } else {
            let idx = self
                .toc
                .group_index_bitstream_order(TocGroupKind::LfGroup(lf_group_idx));
            let group = self.data.get(idx)?;
            let allow_partial = group.bytes.len() < group.toc_group.size as usize;

            let mut bitstream = Bitstream::new(&group.bytes);
            let result = LfGroup::parse(
                &mut bitstream,
                LfGroupParams {
                    frame_header: &self.header,
                    quantizer: lf_global_vardct.map(|x| &x.quantizer),
                    global_ma_config,
                    mlf_group,
                    lf_group_idx,
                    allow_partial,
                    pool: &self.pool,
                },
            );
            if allow_partial && result.is_err() {
                return None;
            }
            Some(result)
        }
    }

    pub fn try_parse_hf_global(
        &self,
        cached_lf_global: Option<&LfGlobal>,
    ) -> Option<Result<HfGlobal>> {
        let is_modular = self.header.encoding == header::Encoding::Modular;

        if self.toc.is_single_entry() {
            let group = self.data.get(0)?;
            let mut bitstream = Bitstream::new(&group.bytes);
            let offset = self.all_group_offsets.hf_global.load(Ordering::Relaxed);
            let lf_global = if cached_lf_global.is_none() && (offset == 0 || !is_modular) {
                match self.try_parse_lf_global()? {
                    Ok(lf_global) => Some(lf_global),
                    Err(e) => return Some(Err(e)),
                }
            } else {
                None
            };
            let lf_global = cached_lf_global.or(lf_global.as_ref());

            if offset == 0 {
                let lf_global = lf_global.unwrap();
                let mut gmodular = lf_global.gmodular.clone();
                let groups = gmodular
                    .modular
                    .image_mut()
                    .map(|x| x.prepare_groups(&self.pass_shifts))
                    .transpose();
                let groups = match groups {
                    Ok(groups) => groups,
                    Err(e) => return Some(Err(e.into())),
                };
                let mlf_group = groups.and_then(|mut x| x.lf_groups.pop());
                let lf_group = self
                    .try_parse_lf_group(
                        lf_global.vardct.as_ref(),
                        lf_global.gmodular.ma_config(),
                        mlf_group,
                        0,
                    )
                    .unwrap();
                if let Err(e) = lf_group {
                    return Some(Err(e));
                }
            }
            let offset = self.all_group_offsets.hf_global.load(Ordering::Relaxed);

            if self.header.encoding == header::Encoding::Modular {
                self.all_group_offsets
                    .pass_group
                    .store(offset, Ordering::Relaxed);
                return None;
            }

            bitstream.skip_bits(offset).unwrap();
            let lf_global = lf_global.unwrap();
            let result = HfGlobal::parse(
                &mut bitstream,
                HfGlobalParams::new(
                    &self.image_header.metadata,
                    &self.header,
                    lf_global,
                    &self.pool,
                ),
            );
            self.all_group_offsets
                .pass_group
                .store(bitstream.num_read_bits(), Ordering::Relaxed);
            Some(result)
        } else {
            if self.header.encoding == header::Encoding::Modular {
                return None;
            }

            let idx = self.toc.group_index_bitstream_order(TocGroupKind::HfGlobal);
            let group = self.data.get(idx)?;
            if group.bytes.len() < group.toc_group.size as usize {
                return None;
            }

            let mut bitstream = Bitstream::new(&group.bytes);
            let lf_global = if cached_lf_global.is_none() {
                match self.try_parse_lf_global()? {
                    Ok(lf_global) => Some(lf_global),
                    Err(e) => return Some(Err(e)),
                }
            } else {
                None
            };
            let lf_global = cached_lf_global.or(lf_global.as_ref()).unwrap();
            let params = HfGlobalParams::new(
                &self.image_header.metadata,
                &self.header,
                lf_global,
                &self.pool,
            );
            Some(HfGlobal::parse(&mut bitstream, params))
        }
    }

    pub fn pass_group_bitstream(
        &self,
        pass_idx: u32,
        group_idx: u32,
    ) -> Option<Result<PassGroupBitstream>> {
        Some(if self.toc.is_single_entry() {
            if pass_idx != 0 || group_idx != 0 {
                return None;
            }

            let group = self.data.get(0)?;
            let mut bitstream = Bitstream::new(&group.bytes);
            let mut offset = self.all_group_offsets.pass_group.load(Ordering::Relaxed);
            if offset == 0 {
                let hf_global = self.try_parse_hf_global(None)?;
                if let Err(e) = hf_global {
                    return Some(Err(e));
                }
                offset = self.all_group_offsets.pass_group.load(Ordering::Relaxed);
            }
            bitstream.skip_bits(offset).unwrap();

            Ok(PassGroupBitstream {
                bitstream,
                partial: group.bytes.len() < group.toc_group.size as usize,
            })
        } else {
            let idx = self
                .toc
                .group_index_bitstream_order(TocGroupKind::GroupPass {
                    pass_idx,
                    group_idx,
                });
            let group = self.data.get(idx)?;
            let partial = group.bytes.len() < group.toc_group.size as usize;

            Ok(PassGroupBitstream {
                bitstream: Bitstream::new(&group.bytes),
                partial,
            })
        })
    }
}

/// Bitstream positioned at the start of a pass group's data, as returned by
/// `Frame::pass_group_bitstream`.
#[derive(Debug)]
pub struct PassGroupBitstream<'buf> {
    /// Reader over the bytes loaded so far for the group containing the pass group.
    pub bitstream: Bitstream<'buf>,
    /// `true` if the containing group has not been fully loaded yet.
    pub partial: bool,
}

impl Frame {
    /// Adjusts the cropping region of the image to the actual decoding region of the frame.
    ///
    /// The cropping region of the *image* needs to be adjusted to be used in a *frame*, for a few
    /// reasons:
    /// - A frame may be blended to the canvas with offset, which makes the image and the frame
    ///   have different coordinates.
    /// - Some filters reference other samples, which requires padding to the region.
    ///
    /// This method takes care of those and adjusts the given region appropriately. The region
    /// is `(left, top, width, height)`; `width` and `height` saturate at `u32::MAX` instead of
    /// overflowing, so a very large (e.g. "whole image") region stays valid.
    pub fn adjust_region(&self, (left, top, width, height): &mut (u32, u32, u32, u32)) {
        if self.header.have_crop {
            // Translate from image coordinates into frame coordinates, clamping at 0.
            *left = left.saturating_add_signed(-self.header.x0);
            *top = top.saturating_add_signed(-self.header.y0);
        }

        // NOTE(review): when both filters are enabled only the EPF padding (3) is applied, not
        // the sum of both — confirm this covers the combined filter support.
        let mut padding = 0u32;
        if self.header.restoration_filter.gab.enabled() {
            tracing::debug!("Gabor-like filter requires padding of 1 pixel");
            padding = 1;
        }
        if self.header.restoration_filter.epf.enabled() {
            tracing::debug!("Edge-preserving filter requires padding of 3 pixels");
            padding = 3;
        }
        if padding > 0 {
            // Grow toward the top-left as far as the frame origin allows, and by `padding`
            // toward the bottom-right.
            let delta_w = (*left).min(padding);
            let delta_h = (*top).min(padding);
            *left -= delta_w;
            *top -= delta_h;
            *width = width.saturating_add(delta_w + padding);
            *height = height.saturating_add(delta_h + padding);
        }
    }
}
