(* Congruence closure.
   Copyright (C) 2008 Jean Goubault-Larrecq and LSV, CNRS UMR 8643 & ENS Cachan.

   This file is part of h1.

   h1 is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   h1 is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with h1; see the file COPYING.  If not, write to
   the Free Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.
*)

open "cc_h";

(*
open "term_h"; (* DEBUG *)
*)

(* cc () builds a fresh congruence-closure engine.  All of its state
   (the parents, color, invcolor, sigtable and pending tables below)
   lives in refs created by this call, so distinct calls to cc ()
   yield independent engines.  The result is a record of closures:
     find t             the color (class representative) of t, after
                        adding t and its subterms to the universe if
                        they are not there yet;
     equate (u, v)      asserts u = v, then propagates every equation
                        that follows from it by congruence;
     class c            the set of terms of color c; raises NotColor
                        if c is not a current color;
     select (fs, c)     the application terms of color c whose head
                        symbol is in fs;
     equate_all (u, v, vars)
                        equates u.sigma and v.sigma for every sigma
                        mapping the variables in vars to colors;
     universe ()        the map from each term of the universe to its
                        color;
     colors ()          the map from each color to its class;
     reset ()           empties every table.
   The record type cc and the exception NotColor are presumably
   declared in cc_h, opened above. *)
fun cc () : (''_var, '_a) cc =
    let (* First, we assume a so-called universe: this is a finite
	 set of terms, closed under subterms.  Concretely, the universe
	 is the domain of the color table; find extends it on demand. *)

(*
val varnames = ref (fn _ => "") (* DEBUG *)
*)

	(* The parents of a color c are all terms f (..., t, ...)
	 in the universe with some argument t of color c. *)
	val parents = ref ({} : ''_var color -m> ''_var term set ref)
	(* the set of parents of the color.  There is one entry per
	 current color: find adds one for each new color, and
	 merge_colors removes the entry of the color it renames.
	 The set refs are updated in place. *)

	(* color maps each term in the universe to its color,
	 i.e., a unique term in its equivalence class.  New colors
	 are mapped to themselves by find, and merge_colors renames
	 a whole class at once, so every color c satisfies
	 ?(!color) c = c. *)
	val color = ref ({} : ''_var term -m> ''_var color)

	(* invcolor is the inverse map to color: it yields the set of
	 all terms of a given color.  Its keys are exactly the current
	 colors; the set refs are updated in place. *)
	val invcolor = ref ({} : ''_var color -m> ''_var term set ref)

	(* sigtable maps f (c1, ..., cn) where each ci is a color,
	 to some term in the universe of the color of f (c1, ..., cn).
	 This is provided there is t1 of color c1, ..., tn of color cn,
	 such that f (t1, ..., tn) is in the universe.
	 killsigs removes entries before a merge, and reinstallsigs
	 recomputes them afterwards. *)
	val sigtable = ref ({} : string * ''_var color list -m> ''_var term)

	(* Disabled debugging consistency checks: check_sigtable and
	 check_parents rebuild sigtable and parents from color and
	 compare them with the live tables.  They need varnames, above,
	 and print_term, which presumably comes from term_h. *)
(*
	exception ForgottenSigtable
	exception ExtraSigtable
	exception NoColorSigtable
	exception NoColorArgSigtable
	fun print_sigtable (name, sgt) =
	    (#put stderr name;
	     #put stderr " is:\n";
	     iterate
	       (#put stderr "\t";
		print_term (stderr, !varnames) (f $ cargs);
		#put stderr " -> ";
		print_term (stderr, !varnames) t;
		#put stderr "\n")
	     | (f, cargs) => t in map sgt
	     end;
	     #flush stderr ());

	fun check_sigtable () =
	    let val sigtable' = {(f, [?(!color) t | t in list tl]) => c
				| f $ tl => c in map !color}
	    in
		iterate
		  iterate
		    let val c = ?(!color) t
		    in
			if t<>c
			    then (#put stderr "!!! sigtable entry ";
				  print_term (stderr, !varnames) (f $ args);
				  #put stderr " has non-color argument:\n  t = ";
				  print_term (stderr, !varnames) t;
				  #put stderr "\n  color = ";
				  print_term (stderr, !varnames) c;
				  #put stderr "\n";
				  #flush stderr ();
				  raise NoColorArgSigtable)
			else ()
		    end
		  | t in list args
		  end
		| (f, args) in set !sigtable
		end;
		case !sigtable <-| sigtable' of
		    excess as {sg => t, ...} =>
		    (print_sigtable ("sigtable' - sigtable", excess);
		     raise ForgottenSigtable)
		  | _ =>
		    (case sigtable' <-| !sigtable of
			 missing as {sg => t, ...} =>
			 (print_sigtable ("sigtable - sigtable'", missing);
			  raise ExtraSigtable)
		       | _ =>
			 iterate
			   if c<>c'
			       then raise NoColorSigtable
			   else ()
			 | sg => c' in map sigtable'
			   val t = ?(!sigtable) sg
			   val c = ?(!color) t
			 end)
	    end

	exception ForgottenParents
	exception ExtraParents
	exception ForgottenParent
	exception ExtraParent
	fun check_parents () =
	    let val parents' = ref {c => ref {}
				   | c in set rng (!color)}
	    in
		iterate
		  iterate
		    let val acc = if ci inset !parents'
				      then ?(!parents') ci
				  else let val a = ref {}
				       in
					   parents' := !parents' ++ {ci => a};
					   a
				       end
		    in
			acc := !acc U {t}
		    end
		  | ti in list args
		  val ci = ?(!color) ti
		  end
		| t as f $ args => _ in map !color
		end;
		case !parents <-| !parents' of
		    excess as {c => ref ts, ...} =>
		    raise ForgottenParents
		  | _ =>
		    (case !parents' <-| !parents of
			 missing as {c => ref ts, ...} =>
			 (#put stderr "!!! extra parents: c = ";
			  print_term (stderr, !varnames) c;
			  #put stderr "\n  ts =";
			  iterate
			    (#put stderr " ";
			     print_term (stderr, !varnames) t)
			  | t in set ts
			  end;
			  #flush stderr ();
			  raise ExtraParents)
		       | _ =>
			 iterate
			   (case ts \ ts' of
				{t, ...} =>
				(#put stderr "!!! forgotten parent : c = ";
				 print_term (stderr, !varnames) c;
				 #put stderr "\n  ts - ts' =";
				 iterate
				   (#put stderr " ";
				    print_term (stderr, !varnames) t)
				 | t in set ts \ ts'
				 end;
				 #flush stderr ();
				 raise ForgottenParent)
			      | _ => (case ts' \ ts of
					  {t, ...} =>
					  (#put stderr "!!! extra parent : c = ";
					   print_term (stderr, !varnames) c;
					   #put stderr "\n  ts' - ts =";
					   iterate
					     (#put stderr " ";
					      print_term (stderr, !varnames) t)
					   | t in set ts' \ ts
					   end;
					   #flush stderr ();
					   raise ExtraParent)
					| _ => ()))
			 | c => ref ts' in map !parents'
			   val ts = !(?(!parents) c)
			 end)
	    end

	fun check () = (check_parents (); check_sigtable ())
*)

	(* pending holds a list of equations u=v
	 that need to be taken care of.  reinstallsigs pushes an
	 equation here whenever two parents become congruent after a
	 merge; process pops and applies them until none is left. *)
	val pending = ref (nil : (''_var term * ''_var term) list)

	(* update_parents (cl, t) records t as a parent of each color in
	 the list cl, which holds the colors of the arguments of t.
	 Assumes every element of cl is a key of !parents, i.e., a
	 current color. *)
	fun update_parents (cl, t) =
	    iterate
	      let val prts = ?(!parents) cu
	      in
		  prts := !prts U {t}
	      end
	    | cu in list cl
	    end

	(* find t returns the color of t.  If t is not yet in the
	 universe, it is added together with its subterms; the
	 arguments are found first, recursively:
	 - a new variable becomes its own, new color;
	 - a new application f $ l whose signature, f with the colors
	   of l, is already in sigtable is congruent to the term
	   recorded there, so it joins the color of that term;
	 - otherwise it becomes a new color and its signature is
	   added to sigtable.
	 find never merges classes, so it never adds to pending. *)
	fun find t =
	    let val clr = !color
	    in
		if t inset clr
		    then ?clr t
		else case t of
			 V _ => (color := clr ++ {t => t};
				 invcolor := !invcolor ++ {t => ref {t}};
				 parents := !parents ++ {t => ref {}};
				 t)
		       | f $ l =>
			 let val cl = [find u | u in list l]
			     val clr = !color
			     val sg = (f, cl)
			     val sigt = ((* check (); *)
					 !sigtable)
			 in
			     if sg inset sigt (* the new term is congruent
					       to one already in the universe,
					       so it takes the color of that term: *)
				 then let val tt = ?sigt sg
					  val c = ?clr tt
					  val invc = ?(!invcolor) c
				      in
					  color := clr ++ {t => c};
					  invc := !invc U {t};
					  update_parents (cl, t);
(*
#put stderr "Find [old] : ";
print_term (stderr, !varnames) t;
#put stderr "\n";
#flush stderr ();
check ();
*)
					  c
				      end
			     else (* otherwise create a new color:
				   this will be t itself. *)
				 (parents := !parents ++ {t => ref {}};
				  update_parents (cl, t);
				  color := clr ++ {t => t};
				  invcolor := !invcolor ++ {t => ref {t}};
				  sigtable := sigt ++ {sg => t};
(*
#put stderr "Find [new] : ";
print_term (stderr, !varnames) t;
#put stderr "\n";
#flush stderr ();
check ();
*)
				  t)
			 end
	    end

	fun killsigs paru =
	    (* kills every signature entry that is congruent to some
	     term in paru: for each parent f $ args in paru, the
	     signature, f with the colors of args, is computed with the
	     current, pre-merge colors and removed from sigtable.
	     Returns a list of equalities (parent, t), one per parent,
	     where t is the term the signature was mapped to before the
	     call.  These equalities should still hold, and
	     reinstallsigs uses them to rebuild the entries once the
	     colors have changed.
	     congruent is the relation defined so that
	     f (s1,...,sn) congruent f (t1,...,tn) iff
	     find si=find ti for each i.
	     paru should be a subset of the universe; the parents
	     recorded by update_parents are always applications. *)
	    let val sigt = !sigtable
		val clr = !color
		val vanishing = [let val cargs = [?clr t
						 | t in list args]
				     val sg = (f, cargs)
				 in
				     sigtable := {sg} <-| !sigtable;
				     (parent, ?sigt sg)
				 end
				| parent as f $ args in set paru]
	    in
(*
#put stderr "*** Killsigs: paru =";
iterate
(#put stderr "\n        "; print_term (stderr, !varnames) t)
| t in set paru
end;
#put stderr "\n  vanishing =";
iterate
(#put stderr "\n        "; print_term (stderr, !varnames) p;
#put stderr "="; print_term (stderr, !varnames) c;
#put stderr " [color = ";
print_term (stderr, !varnames) (?(!color) c);
#put stderr "]"
)
| (p,c) in list vanishing
end;
#put stderr "\n";
#flush stderr ();
*)
		vanishing
	    end

	fun reinstallsigs eqns = (* recomputes, with the current,
				  post-merge colors, the signature
				  entry of each parent in eqns, a list
				  of pairs (parent, tt) as returned by
				  killsigs.  If the new signature is
				  not in sigtable yet, it is mapped to
				  tt; otherwise parent is congruent to
				  the term already recorded there, and
				  the equation (that term, tt) is pushed
				  onto pending. *)
(
(*
#put stderr "*** Reinstall sigs: eqns =";
iterate
(#put stderr "\n        "; print_term (stderr, !varnames) p;
#put stderr "="; print_term (stderr, !varnames) c)
| (p,c) in list eqns
end;
#put stderr "\n";
#flush stderr ();
*)
	    iterate
	      let val cargs = [?(!color) t | t in list args]
		  val sg = (f, cargs)
	      in
(*
#put stderr "  * parent = ";
print_term (stderr, !varnames) (f $ cargs);
#put stderr "\n    tt = ";
print_term (stderr, !varnames) tt;
#put stderr " [color = ";
print_term (stderr, !varnames) (?(!color) tt);
#put stderr "]\n";
#flush stderr ();
*)
		  if sg inset !sigtable
		      then pending := (?(!sigtable) sg, tt) :: !pending
		  else sigtable := !sigtable ++ {sg => tt}
	      end
	    | (parent as f $ args, tt) in list eqns
	    end
)

	fun merge_colors (cu, cv) =
	    (* rewrite color cu to cv in color, invcolor, and parents tables:
	     every term of color cu gets color cv, the class of cu is
	     added to the class of cv, the parents of cu are added to
	     the parents of cv, and the invcolor and parents entries
	     of cu are removed.  sigtable is left alone; merge takes
	     care of it through killsigs and reinstallsigs. *)
	    let val invclr = !invcolor
		val us = !(?invclr cu)
	    in
		color := !color ++ {u => cv
				   | u in set us};
		invcolor := ({cu} <-| invclr);
		let val vsr = ?invclr cv
		in
		    vsr := !vsr U us
		end;
		let val prts = !parents
		    val vpr = ?prts cv
		in
		    vpr := !vpr U !(?prts cu);
		    parents := {cu} <-| prts
		end
	    end

	fun merge (cu, paru, cv) = (* rewrite color cu to cv everywhere;
				    paru is the set of application terms in
				    the universe that have some argument of
				    color cu.  The signatures of paru are
				    killed while the old colors are still in
				    place, then recomputed once cu has been
				    renamed to cv; any new congruences end
				    up on pending. *)
	    let val eqns = killsigs paru
	    in
(*
#put stderr "*** Merge: cu=";
print_term (stderr, !varnames) cu;
#put stderr "  cv=";
print_term (stderr, !varnames) cv;
#put stderr "\n";
#flush stderr ();
*)
		merge_colors (cu, cv);
		reinstallsigs eqns
	    end

	(* do_union (u, v) merges the classes of u and v, adding u and v
	 to the universe first if needed.  The color with fewer parents
	 is the one renamed, cv on ties, which keeps the killsigs and
	 reinstallsigs work proportional to the smaller parent set.
	 Congruences discovered by the merge are left on pending. *)
	fun do_union (u, v) =
	    let val cu = find u
		val cv = find v
	    in
		if cu=cv
		    then ()
		else let val paru = !(?(!parents) cu)
			 val parv = !(?(!parents) cv)
		     in
			 if card paru<card parv
			     then merge (cu, paru, cv)
			 else merge (cv, parv, cu)
		     end
	    end

	(* process () pops equations off pending and applies do_union to
	 each, until pending is empty; do_union may push new ones. *)
	fun process () =
	    case !pending of
		eqn::rest => (pending := rest;
			      do_union eqn;
			      process ())
	      | _ => ()

	(* equate (u, v) adds the equation u = v and closes the result
	 under congruence. *)
	fun equate eqn = (
(*
check (); #put stderr "[[equate "; print_term (stderr, !varnames) (#1 eqn);
#put stderr "="; print_term (stderr, !varnames) (#2 eqn); #put stderr "\n"; #flush stderr ();
*)
			  do_union eqn;
			  process ()
(*
before (#put stderr "]]\n"; #flush stderr (); check ())
*)
			  )

	(* class c returns the set of terms of color c;
	 raises NotColor if c is not a current color. *)
	fun class c =
	    let val invc = !invcolor
	    in
		if c inset invc
		    then !(?invc c)
		else raise NotColor
	    end

	(* select (fs, c) returns the application terms f $ _ of color c
	 whose head symbol f is in fs.  Raises NotColor, via class, if c
	 is not a current color. *)
	fun select (fs, c) =
	    {t
	    | t as f $ _ in set class c
	      such that f inset fs}

	(* all_sigmas (vars, sigma, do_subst) extends sigma in all
	 possible ways by binding each variable in the domain of the
	 map vars to a key of !invcolor, i.e., to a color, and calls
	 do_subst on every resulting substitution.
	 NOTE(review): do_subst may call equate, which can merge colors
	 while the enumeration is running; confirm which set of colors
	 is meant to be enumerated in that case. *)
	fun all_sigmas ({}, sigma, do_subst) = do_subst sigma
	  | all_sigmas ({x => _} U rest, sigma, do_subst) =
	    iterate
	      all_sigmas (rest, sigma', do_subst)
	    | c in set !invcolor
	    val sigma' = sigma ++ {x => c}
	    end

	fun equate_all (u, v, vars) =
	    (* adds all equations u.sigma = v.sigma,
	     where sigma only binds variables in vars,
	     and maps variables to elements in the current universe only,
	     more precisely to the colors enumerated by all_sigmas.
	     Assumes FV(v) subseteq FV(u), i.e.,
	     that u -> v is a proper rewrite rule. *)
	    all_sigmas (vars, {},
			fn sigma =>
			   let val ts = tsubst sigma
			   in
			       equate (ts u, ts v)
			   end)
	(* colors () returns a map from each current color to the set of
	 terms of that color, with the set refs dereferenced. *)
	fun colors () = {c => ts
			| c => ref ts in map !invcolor}
    in
	|[ find = find,
	   equate = equate,
	   class = class,
	   select = select,
	   equate_all = equate_all,
	   (* the domain of the color map is the universe *)
	   universe = (fn () => !color),
	   colors = colors,
(*
set_varnames = (fn vars => varnames := vars), (* DEBUG *)
*)
	   (* reset empties every table: the universe becomes empty
	    and all equations are forgotten. *)
	   reset = (fn () =>
		       (parents := {};
			color := {};
			invcolor := {};
			sigtable := {};
			pending := nil))
	   ]|
    end;
