open Core
open ImageLib_unix

(** Norm used to measure the size of, or distance between, pixel colours
    (see [Image.Rgba.distance], which applies it to the r, g, b and a
    channel differences). *)
module Norm_kind = struct
  module T = struct
    (** [Euclidean] is the L2 norm, [One] the L1 (sum of absolute values)
        norm and [Max] the L-infinity (largest absolute value) norm. *)
    type t = Euclidean | One | Max [@@deriving sexp, enumerate]
  end

  include T
  include Sexpable.To_stringable (T)

  (** Command-line argument type; parses the sexp spelling of a constructor,
      e.g. [One]. *)
  let arg = Command.Arg_type.create of_string
end

(** Order in which the pixels of an image are visited. *)
module Traverse_mode = struct
  module T = struct
    (** [Rows]: the row index is the outer loop (row by row).
        [Columns]: the column index is the outer loop (column by column). *)
    type t = Rows | Columns [@@deriving sexp, enumerate]
  end

  include T
  include Sexpable.To_stringable (T)

  (** Command-line argument type; parses [Rows] or [Columns]. *)
  let arg = Command.Arg_type.create of_string
end

(** Corner of the image at which a traversal starts.

    NOTE(review): the east/west naming is mirrored relative to pixel
    coordinates as used by [main] and [Image.iteri]: [North_east] is (0, 0),
    [North_west] is (width - 1, 0), [South_east] is (0, height - 1) and
    [South_west] is (width - 1, height - 1). *)
module Corner = struct
  module T = struct
    type t = North_east | North_west | South_west | South_east
    [@@deriving sexp, enumerate]
  end

  include T
  include Sexpable.To_stringable (T)

  (** Command-line argument type; parses the sexp spelling of a constructor,
      e.g. [North_east]. *)
  let arg = Command.Arg_type.create of_string
end

module Image = struct
  include Image

  (** A single pixel with explicit red, green, blue and alpha channels.
      [max_val] is copied from the source image; the channel arithmetic below
      does not use it. *)
  module Rgba = struct
    type t = {r: int; g: int; b: int; a: int; max_val: int}

    (** The all-zero pixel; [norm] measures distance from it. *)
    let zero = {r= 0; g= 0; b= 0; a= 0; max_val= 0}

    (** [distance t t' kind] is the [kind] norm of the per-channel absolute
        differences between [t] and [t'] over r, g, b and a ([max_val] is
        ignored). The Euclidean value is truncated toward zero to an [int]. *)
    let distance t t' kind =
      let dr = Int.abs (t.r - t'.r)
      and dg = Int.abs (t.g - t'.g)
      and db = Int.abs (t.b - t'.b)
      and da = Int.abs (t.a - t'.a) in
      match kind with
      | Norm_kind.Euclidean ->
          let sq d =
            let f = Float.of_int d in
            f *. f
          in
          Float.to_int (Float.sqrt (sq dr +. sq dg +. sq db +. sq da))
      | Norm_kind.One -> dr + dg + db + da
      | Norm_kind.Max -> Int.max (Int.max dr dg) (Int.max db da)

    (** [norm t kind] is [distance t zero kind]. *)
    let norm t kind = distance t zero kind

    (** [combine t ~new_] blends [new_] into [t] with integer weights 100:1
        per channel (rounded down by integer division), keeping [t]'s
        [max_val]. *)
    let combine t ~new_:t' =
      let blend x x' = ((100 * x) + x') / 101 in
      { r= blend t.r t'.r
      ; g= blend t.g t'.g
      ; b= blend t.b t'.b
      ; a= blend t.a t'.a
      ; max_val= t.max_val }
  end

  (** Reads pixel ([i], [j]) as an [Rgba.t]. Grey images are expanded to
      r = g = b; images without an alpha channel report [a = 0]. *)
  let get_rgba im i j =
    match im.pixels with
    | Grey p ->
        let v = Pixmap.get p i j in
        {Rgba.r= v; g= v; b= v; a= 0; max_val= im.max_val}
    | GreyA (p, a) ->
        let v, a = (Pixmap.get p i j, Pixmap.get a i j) in
        {Rgba.r= v; g= v; b= v; a; max_val= im.max_val}
    | RGB (r, g, b) ->
        { Rgba.r= Pixmap.get r i j
        ; g= Pixmap.get g i j
        ; b= Pixmap.get b i j
        ; a= 0
        ; max_val= im.max_val }
    | RGBA (r, g, b, a) ->
        { Rgba.r= Pixmap.get r i j
        ; g= Pixmap.get g i j
        ; b= Pixmap.get b i j
        ; a= Pixmap.get a i j
        ; max_val= im.max_val }

  (** Writes an [Rgba.t] to pixel ([i], [j]). Grey images store the integer
      mean of r, g and b; alpha is dropped for images without an alpha
      channel; [max_val] is ignored. *)
  let set_rgba im i j {Rgba.r; g; b; a; max_val= _} =
    match im.pixels with
    | Grey mp ->
        let v = (r + g + b) / 3 in
        Pixmap.set mp i j v
    | GreyA (mp, ma) ->
        let v = (r + g + b) / 3 in
        Pixmap.set mp i j v ; Pixmap.set ma i j a
    | RGB (mr, mg, mb) ->
        Pixmap.set mr i j r ; Pixmap.set mg i j g ; Pixmap.set mb i j b
    | RGBA (mr, mg, mb, ma) ->
        Pixmap.set mr i j r ;
        Pixmap.set mg i j g ;
        Pixmap.set mb i j b ;
        Pixmap.set ma i j a

  (** [iteri traverse_mode corner im ~f] calls [f i j pixel] once for every
      pixel, with [i] in [0 .. width - 1] and [j] in [0 .. height - 1].

      [traverse_mode] selects the nesting: [Columns] makes [i] the outer loop,
      [Rows] makes [j] the outer loop. [corner] selects the first pixel and
      thereby the direction of each axis: [North_east] = (0, 0),
      [North_west] = (width - 1, 0), [South_east] = (0, height - 1),
      [South_west] = (width - 1, height - 1).

      Each pixel is read right before [f] is called on it, so [f] observes
      writes made by earlier calls. *)
  let iteri traverse_mode corner im ~f =
    let i_descending, j_descending =
      match corner with
      | Corner.North_east -> (false, false)
      | Corner.North_west -> (true, false)
      | Corner.South_east -> (false, true)
      | Corner.South_west -> (true, true)
    in
    (* Calls [visit] on 0 .. n - 1, in reverse when [descending]. *)
    let sweep n ~descending ~visit =
      if descending then
        for k = n - 1 downto 0 do
          visit k
        done
      else
        for k = 0 to n - 1 do
          visit k
        done
    in
    let visit_pixel i j = f i j (get_rgba im i j) in
    match traverse_mode with
    | Traverse_mode.Columns ->
        sweep im.width ~descending:i_descending ~visit:(fun i ->
            sweep im.height ~descending:j_descending ~visit:(fun j ->
                visit_pixel i j ) )
    | Traverse_mode.Rows ->
        sweep im.height ~descending:j_descending ~visit:(fun j ->
            sweep im.width ~descending:i_descending ~visit:(fun i ->
                visit_pixel i j ) )

  (** Replaces every pixel with [f i j pixel], visiting pixels in the order of
      [iteri traverse_mode corner]. Since each pixel is read just before it is
      rewritten, [f] sees the results of earlier rewrites. *)
  let mapi_inplace traverse_mode corner im ~f =
    iteri traverse_mode corner im ~f:(fun i j rgba ->
        set_rgba im i j (f i j rgba) )

  (** For every row ([Rows]) or column ([Columns]), reads its pixels in
      increasing index order into a list, applies [f], and writes the result
      back starting at index 0. [f] is expected to return a list of the same
      length: missing trailing elements leave pixels unchanged, and extra
      elements would write past the image edge. *)
  let map_slices_in_place dir im ~f =
    match dir with
    | Traverse_mode.Rows ->
        for r = 0 to im.height - 1 do
          let slice = List.init im.width ~f:(fun i -> get_rgba im i r) in
          List.iteri (f slice) ~f:(fun i rgba -> set_rgba im i r rgba)
        done
    | Traverse_mode.Columns ->
        for c = 0 to im.width - 1 do
          let slice = List.init im.height ~f:(fun j -> get_rgba im c j) in
          List.iteri (f slice) ~f:(fun j rgba -> set_rgba im c j rgba)
        done

  (** The in-bounds 4-neighbours of ([i], [j]), in the order
      (i - 1, j), (i, j - 1), (i + 1, j), (i, j + 1). *)
  let neighbors im i j =
    List.filter_opt
      [ (if i > 0 then Some (i - 1, j) else None)
      ; (if j > 0 then Some (i, j - 1) else None)
      ; (if i < im.width - 1 then Some (i + 1, j) else None)
      ; (if j < im.height - 1 then Some (i, j + 1) else None) ]
end

(** [weigert] subcommand: grows colour regions over the image at [file] and
    writes the image repainted with region colours to [out_path].

    The pixel at [pivot_corner] seeds region 0. Pixels are then visited in
    [Image.iteri traverse_mode pivot_corner] order. For each pixel, its
    already-labelled 4-neighbours are sorted by the
    [neighbor_precedence_norm] of their region colour (ascending, then
    reversed when [neighbor_precedence_rev]) and processed in that order.
    For each neighbour:
    - if the [join_norm] distance between the pixel colour (its region colour
      once it has one) and the neighbour region colour is below [distance],
      the pixel joins that region and the region colour is blended toward
      the pixel colour with [Image.Rgba.combine];
    - otherwise a new region is created for the pixel.
    Later neighbours override earlier ones.

    NOTE(review): a new region is allocated for every neighbour that is too
    far away, even if a later neighbour relabels the pixel, so the printed
    region count includes such abandoned regions. *)
let main file distance out_path neighbor_precedence_norm
    neighbor_precedence_rev pivot_corner join_norm traverse_mode () =
  let im = openfile file in
  let width = im.Image.width and height = im.Image.height in
  (* Region label per pixel, stored at [i * height + j]. *)
  let unlabelled = -1 in
  let region = Array.create ~len:(width * height) unlabelled in
  let index i j = (i * height) + j in
  let max_region = ref 0 in
  let start_i, start_j =
    match pivot_corner with
    | Corner.North_east -> (0, 0)
    | Corner.North_west -> (width - 1, 0)
    | Corner.South_east -> (0, height - 1)
    | Corner.South_west -> (width - 1, height - 1)
  in
  region.(index start_i start_j) <- 0 ;
  let region_color = Int.Table.create () in
  Hashtbl.add_exn region_color ~key:0
    ~data:(Image.get_rgba im start_i start_j) ;
  Image.iteri traverse_mode pivot_corner im ~f:(fun i j pixel ->
      (* Neighbours that already carry a label, by precedence norm. *)
      let labelled =
        Image.neighbors im i j
        |> List.filter_map ~f:(fun (ni, nj) ->
               let nregion = region.(index ni nj) in
               if nregion = unlabelled then None
               else
                 Some
                   ( ni
                   , nj
                   , Image.Rgba.norm
                       (Hashtbl.find_exn region_color nregion)
                       neighbor_precedence_norm ) )
        |> List.sort ~compare:(fun (_, _, d) (_, _, d') -> Int.compare d d')
        |> List.map ~f:(fun (ni, nj, _) -> (ni, nj))
      in
      let ordered =
        if neighbor_precedence_rev then List.rev labelled else labelled
      in
      List.iter ordered ~f:(fun (ni, nj) ->
          (* Once the pixel has a region, compare using the region colour. *)
          let col =
            let cregion = region.(index i j) in
            if cregion = unlabelled then pixel
            else Hashtbl.find_exn region_color cregion
          in
          (* [ordered] only holds labelled neighbours, and only (i, j) is
             relabelled inside this loop, so [nregion] is a real label. *)
          let nregion = region.(index ni nj) in
          let ncol =
            match Hashtbl.find region_color nregion with
            | Some c -> c
            | None -> Image.get_rgba im ni nj
          in
          if Image.Rgba.distance col ncol join_norm < distance then (
            region.(index i j) <- nregion ;
            Hashtbl.update region_color nregion ~f:(function
              | Some rcol -> Image.Rgba.combine rcol ~new_:col
              | None -> col ) )
          else (
            incr max_region ;
            region.(index i j) <- !max_region ;
            Hashtbl.add_exn region_color ~key:!max_region ~data:col ) ) ) ;
  printf "%dx%d size, %d regions\n%!" width height
    (Hashtbl.length region_color) ;
  Image.mapi_inplace traverse_mode pivot_corner im ~f:(fun i j c ->
      Option.value ~default:c
        (Hashtbl.find region_color region.(index i j)) ) ;
  ImageLib_unix.writefile out_path im

(** [double] subcommand: pixelates the image at [file] into [n] x [n]
    blocks and writes it to [out_path]. Each pixel receives the value of the
    pixel with the smallest i and j in its block; no averaging is done.

    Rewriting in place is safe: with the [Rows]/[North_east] order, which
    starts at (0, 0), every source pixel is visited before the rest of its
    block, and it is rewritten with its own value.

    @raise Failure if [n < 1] (avoids [Division_by_zero] in [mod]). *)
let double n file out_path () =
  if n < 1 then failwithf "double: n must be >= 1, got %d" n () ;
  let im = openfile file in
  Image.mapi_inplace Traverse_mode.Rows Corner.North_east im
    ~f:(fun i j _pixel -> Image.get_rgba im (i - (i mod n)) (j - (j mod n))) ;
  ImageLib_unix.writefile out_path im

(** Orderings on pixels used by the [sort] subcommand. *)
module Pixel_compare = struct
  module T = struct
    (** [Norm k]: by the [k] norm of the pixel.
        [Lexi]: lexicographically by r, then g, then b (alpha ignored).
        [Simple_mean]: by the integer mean of r, g and b. *)
    type t = Norm of Norm_kind.t | Lexi | Simple_mean [@@deriving sexp]
  end

  include T
  include Sexpable.To_stringable (T)

  (** Command-line argument type; parses the sexp spelling, e.g. [Lexi] or
      [(Norm Max)]. *)
  let arg = Command.Arg_type.create of_string

  (** [cmp t p p'] compares pixels [p] and [p'] according to [t]; usable as
      the [~compare] argument of [List.sort]. *)
  let cmp t p p' =
    match t with
    | Simple_mean ->
        let mean {Image.Rgba.r; g; b; _} = (r + g + b) / 3 in
        Int.compare (mean p) (mean p')
    | Norm n -> Int.compare (Image.Rgba.norm p n) (Image.Rgba.norm p' n)
    | Lexi -> (
        let {Image.Rgba.r; g; b; _} = p in
        let {Image.Rgba.r= r'; g= g'; b= b'; _} = p' in
        (* Later channels only break ties on earlier ones. *)
        match Int.compare r r' with
        | 0 -> (
          match Int.compare g g' with 0 -> Int.compare b b' | c -> c )
        | c -> c )
end

(** [sort] subcommand: sorts the pixels of every row ([Rows]) or every
    column ([Columns]) of the image at [file] with the [compare] ordering,
    then writes the result to [out_path]. *)
let sort file out_path dir compare () =
  let image = openfile file in
  let by = Pixel_compare.cmp compare in
  let sort_slice slice = List.sort ~compare:by slice in
  Image.map_slices_in_place dir image ~f:sort_slice ;
  ImageLib_unix.writefile out_path image

(* Command-line entry point: registers the [weigert], [double] and [sort]
   subcommands under one [Command.group].

   NOTE(review): as it appears here this definition is not well-formed and
   needs repair before it compiles. The group is never handed to
   [Command.run]. The [weigert] [Command.basic] has no [~summary:] label, and
   its summary string runs on into the following lines. Several
   [let ... =] bindings inside the [%map_open] bodies lack [in]. The [sort]
   entry lacks its opening parenthesis, and the list is not closed in view.
   Several [flag] calls (precedence, join-norm, pivot-corner, traverse, dir,
   compare) pass no [~doc], which [Command.Param.flag] appears to require;
   TODO confirm.

   NOTE(review): [weigert] passes [not neighbor_precedence_rev] to [main], so
   neighbour lists are reversed by default and [-neighbor-rev] turns the
   reversal off, while the "inv_" part of the output name records the raw
   flag value. The summary of [double] says it averages N x N pixels, but
   [double] copies one pixel per block. The default output name of [sort]
   encodes only [dir], not [compare]. *)
let () =
  @@ Command.group ~summary:"Pixel sorting and glitch toolkit."
       [ ( "weigert"
         , Command.basic
               "Variants of dripping edge detect algo written by \
             (let open Command.Let_syntax in
             let%map_open file = anon ("FILE" %: string)
             and distance =
               flag "distance" ~doc:"int triggering norm threshold" (optional_with_default 200 int)
             and out_path = flag "out" (optional string) ~doc:"path output path"
             and neighbor_precedence_norm =
               flag "precedence"
                 (optional_with_default Norm_kind.One Norm_kind.arg)
             and join_norm =
               flag "join-norm"
                 (optional_with_default Norm_kind.Max Norm_kind.arg)
             and neighbor_precedence_rev = flag "neighbor-rev" no_arg ~doc:" flag"
             and pivot_corner =
               flag "pivot-corner"
                 (optional_with_default Corner.North_east Corner.arg)
             and traverse_mode =
               flag "traverse"
                 (optional_with_default Traverse_mode.Columns Traverse_mode.arg)
             let notation =
               String.concat ~sep:"-"
                 [ "neigh_" ^ Norm_kind.to_string neighbor_precedence_norm
                 ; "join_" ^ Norm_kind.to_string join_norm
                 ; "trig_" ^ sprintf "%04d" distance
                 ; "inv_" ^ Bool.to_string neighbor_precedence_rev
                 ; "by_" ^ Traverse_mode.to_string traverse_mode
                 ; "from_" ^ Corner.to_string pivot_corner ]
             let out_path =
               Option.value ~default:(file ^ "." ^ notation ^ ".png") out_path
             main file distance out_path neighbor_precedence_norm
               (not neighbor_precedence_rev)
               pivot_corner join_norm traverse_mode) )
       ; ( "double"
         , Command.basic ~summary:"Averages NxN pixels."
             (let open Command.Let_syntax in
             let%map_open file = anon ("FILE" %: string)
             and n = flag "n" ~doc:"x" (optional_with_default 2 int)
             and out_path = flag "out" (optional string) ~doc:"X" in
             let notation = String.concat ~sep:"-" ["n_" ^ Int.to_string n] in
             let out_path =
               Option.value ~default:(file ^ "." ^ notation ^ ".png") out_path
             double n file out_path) )
             ; "sort"
         , Command.basic ~summary:"Actual sorting."
             (let open Command.Let_syntax in
             let%map_open file = anon ("FILE" %: string)
             and dir =
               flag "dir"
                 (optional_with_default Traverse_mode.Columns Traverse_mode.arg)
             and compare =
               flag "compare"
                 (optional_with_default Pixel_compare.Simple_mean Pixel_compare.arg)
             and out_path = flag "out" (optional string) ~doc:"X" in
             let notation = String.concat ~sep:"-" ["dir_" ^ Traverse_mode.to_string dir] in
             let out_path =
               Option.value ~default:(file ^ "." ^ notation ^ ".png") out_path
             sort file out_path dir compare)