(*
 * Copyright (c) 2018-2019
 *	The President and Fellows of Harvard College.
 * Written by David A. Holland.
 *
 * Copyright (c) 2020 David A. Holland.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of the University nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *)

open Symbolic
module MS = Mipsstate
module M = Mips
module MT = Mipstools

(*
 * A named assertion: a label (used in typechecking diagnostics and when
 * printing) paired with a Mips.expr that is typechecked as BOOL.
 *)
type fact = string * Mips.expr

(*
 * A symbolically executed synthesis problem in penguin logic: everything
 * synthcegis needs to run the CEGIS loop for one instruction count.
 *)
type xproblem = {
   (* types of the control variables plus the symbolic-execution variables *)
   allenv: Mips.typename Mips.VarMap.t;
   (* precondition, retargeted to the sequence's first program point *)
   prefact: fact;
   (* postcondition, retargeted to the sequence's last program point *)
   postfact: fact;
   (* all program points of the sequence; the head is used as the prepoint *)
   allpps: Ppoint.t list;
   (* number of instructions being synthesized *)
   numinsns: int;
   (* well-formedness assertions about the control variables *)
   controlfacts: fact list;
   (* constraints produced by symbolically executing symcode *)
   symfacts: fact list;
   (* the symbolic instruction sequence *)
   symcode: Mips.code;
}


(*
 * Really synthesize with NUMINSNS instructions, after having constructed
 * a symbolic program, using a counterexample-guided (CEGIS) loop.
 *
 * STARTSTATES0 is the initial list of start-state assertions.
 * NUM_COUNTEREXAMPLES is how many counterexamples to ask for per
 * verification round (clamped to at least 1). DRYRUN sends the SMT to
 * Solver.dryrun instead of Solver.go.
 *
 * Each round does two things:
 *   - guess: find a candidate that meets the pre/postconditions on every
 *     known start state;
 *   - verify: with the postcondition negated, search for a start state on
 *     which the candidate fails.
 * Any such counterexamples become new start states and the loop repeats.
 *
 * Returns Some code when a candidate verifies. Returns None when no
 * candidate of this length exists (the guess is UNSAT). Logs per-run
 * solver timings either way.
 *)
let synthcegis startstates0 num_counterexamples dryrun (xproblem: xproblem) =
   let {
      allenv; prefact; postfact; allpps; numinsns;
      controlfacts; symfacts; symcode
   } = xproblem in

   (*
    * Each verification round must run the solver at least once. With a
    * count of 0 the candidate would be accepted without being verified.
    * With a negative count the round would never stop asking for more
    * counterexamples.
    *)
   let num_counterexamples = max 1 num_counterexamples in

   let prepoint = List.hd allpps in

   let (_, post) = postfact in
   let negpostfact = ("Negated postcondition", M.LOGNOT post) in

   (* (duration, "guess" | "verify") for each solver run, newest first *)
   let timings = ref [] in

   (* Run the solver (or its dry-run stand-in) and record the duration. *)
   let run_solver isguess what smt =
      let (duration, sat) =
         if dryrun then Solver.dryrun isguess smt
         else Solver.go smt
      in
      timings := (duration, what) :: !timings;
      sat
   in

   (*
    * Build the SMT for the control variables.
    *)
   let control_smt =
      let elabel = "Control variable environment:" in
      let clabel = "Control variable constraints:" in
      let env = Build.build'environment elabel allenv in
      let constraints = Build.build'envconstraints clabel allenv allpps in
      let facts = Build.build'facts None allpps controlfacts in
      env @ constraints @ facts
   in

   (* Environment and constraints for the state-dependent variables,
    * instantiated for each state name in STATES. *)
   let state_env_smt states =
      let elabel = "State-dependent variable environment:" in
      let clabel = "State-dependent variable constraints:" in
      let env = Build.build'more'environment elabel allenv states in
      let constraints =
         Build.build'more'envconstraints clabel allenv states allpps
      in
      env @ constraints
   in

   (* For one STATE: its execution environment, the precondition, the
    * symbolic-execution facts, and then POSTFACT'. *)
   let exec_smt state postfact' =
      let env = Build.build'execenv "State variables:" state allpps in
      let pre = Build.build'fact (Some state) allpps prefact in
      let exec = Build.build'facts (Some state) allpps symfacts in
      let post = Build.build'fact (Some state) allpps postfact' in
      env @ pre @ exec @ post
   in

   (*
    * CEGIS loop
    *)
   let rec cegis startstates =
      (* Guess SMT: the state-dependent environment for every start state,
       * then each start state's execution plus its defining assertion. *)
      let guess_smt =
         let numbered = Util.number startstates in
         let states = List.map (fun (i, _) -> Smt.STATE i) numbered in
         let state_smt = state_env_smt states in
         let onestate (i, startstate) =
            let state = Smt.STATE i in
            let body = exec_smt state postfact in
            let comment = Smt.COMMENT ("startstate " ^ string_of_int i) in
            body @ [comment; Smt.ASSERT startstate]
         in
         let main_smt = List.concat (List.map onestate numbered) in
         state_smt @ main_smt
      in

      (* Guess a concrete code sequence. *)
      Log.vsay ("*** Guessing ***");
      Log.csay ("Guessing  " ^ string_of_int numinsns ^ " instructions");
      match run_solver true(*isguess*) "guess" (control_smt @ guess_smt) with
      | Solver.UNSAT _ ->
           (* no viable candidate code blocks of this length *)
           None
      | Solver.SAT candidate ->
           (* convert assignments to assertions *)
           let candidate = Build.filter_candidate candidate in
           let candidate_smt = Build.assert_assignment candidate in

           (* only needed if verification comes back UNSAT; but print it *)
           let code =
              if dryrun then ([], prepoint)
              else Build.unguess candidate symcode
           in
           let codetxt = Printasm.print "   " code in
           Log.vsay ("Candidate:");
           Log.vsay (codetxt);
           Log.csay (codetxt);

           (*
            * Verify it for all start states: negate the postcondition
            * so as to search for an assignment to the start state
            * variables that violates the postcondition.
            *)
           let verify_smt =
              let state = Smt.VERIFY in (* "V" *)
              (* bind in sequence to keep the original Build call order *)
              let state_smt = state_env_smt [state] in
              let run_smt = exec_smt state negpostfact in
              state_smt @ run_smt
           in

           let ncstr = string_of_int num_counterexamples in
           Log.vsay ("*** Verifying (" ^ ncstr ^ " counterexamples) ***");
           Log.csay ("Verifying " ^ string_of_int numinsns ^
                     " instructions (getting " ^ ncstr ^ " counterexamples)");

           (*
            * Ask for a start state that violates the postcondition and is
            * also consistent with MORE_SMT (which excludes counterexamples
            * found earlier). Returns None if there is no such state.
            * Otherwise returns the counterexample in two forms:
            *   - renamed to start state number I + |startstates|;
            *   - over the verify state, for use in excluding it later.
            *)
           let doverify i more_smt =
              let smt = control_smt @ verify_smt @ candidate_smt @ more_smt in
              match run_solver false(*isguess*) "verify" smt with
              | Solver.UNSAT _ -> None
              | Solver.SAT counterexample ->
                   (* Doesn't work; prepare a new start state *)
                   let nextstate = i + List.length startstates in
                   let nstate = Smt.STATE nextstate in (* "S" ^ string_of_int nextstate *)
                   let vstate = Smt.VERIFY in (* "V" *)
                   let newstartstate =
                      Build.unverify vstate nstate prepoint counterexample
                   in
                   (* XXX should fix unverify to not require multiple calls *)
                   let newstartstate' =
                      Build.unverify vstate vstate prepoint counterexample
                   in
                   Some (newstartstate, newstartstate')
           in
           (* Collect up to num_counterexamples distinct counterexamples. *)
           let rec doverifies i more_smt results =
              if i >= num_counterexamples then
                 List.rev results
              else
                 match doverify i more_smt with
                 | None -> List.rev results
                 | Some (newstartstate, newstartstate') ->
                      (* get another counterexample, different from this one *)
                      let different_smt =
                         Smt.ASSERT (Smt.NOT newstartstate')
                      in
                      doverifies (i + 1) (different_smt :: more_smt)
                         (newstartstate :: results)
           in
           match doverifies 0 [] [] with
           | [] ->
                (* No such assignment; done; concretize the code *)
                Some code
           | newstartstates ->
                (* Try guessing again. Append rather than prepend: start
                 * states are numbered by position, and each new one was
                 * numbered after the existing ones. *)
                cegis (startstates @ newstartstates)
   in
   let result = cegis startstates0 in
   let resultstr =
      match result with
      | Some _ -> "Done"
      | None -> "Failed"
   in
   let runs = List.rev !timings in
   let runstr = string_of_int (List.length runs) ^ " solver runs" in
   let timingstr =
      let once (dur, str) = "   " ^ string_of_float dur ^ " (" ^ str ^ ")" in
      String.concat "\n" (List.map once runs)
   in
   Log.vcsay ("*** " ^ resultstr ^ " (" ^ runstr ^ ") ***");
   Log.csay timingstr;
   Log.timingsay ("# " ^ resultstr ^ " in " ^ runstr ^ "\n" ^ timingstr);
   result

(*
 * Try synthesizing NUMINSNS instructions that take PRE to POST, optionally
 * guided by the sketch SKETCH.
 *
 * Steps:
 *   1. Build the symbolic instruction sequence.
 *   2. Retarget PRE and POST to its first and last program points.
 *   3. Symbolically execute it.
 *   4. Typecheck every resulting fact as BOOL.
 *   5. Hand the packaged problem to synthcegis and return its result.
 * Exits the process with status 1 if Syminsns.make_symbolic_insns
 * returns None.
 *)
let synthesize' startstates0 pre numinsns sketch post num_counterexamples dryrun =
   Log.dcsay
      (Printf.sprintf "*** %s%d insns ***"
         (if dryrun then "Dry run with " else "Synthesizing with ")
         numinsns);

   (*
    * Syminsns.make_symbolic_insns yields, when it succeeds:
    *    controlenv   - map from control-variable names to their types
    *    controlfacts - named Mips.expr assertions that keep the control
    *                   variables describing well-formed instructions
    *    allpps       - every program point of the sequence
    *    symcode      - the symbolic instruction list itself
    *)
   match Syminsns.make_symbolic_insns numinsns sketch with
   | None ->
        Log.dcsay "*** Unsat ***";
        exit 1
   | Some (controlenv, controlfacts, allpps, symcode) ->
        (* Move PRE/POST from the generic Ppoint.pre/Ppoint.post onto this
         * sequence's own endpoints (the PRE move should be a no-op). *)
        let pre =
           Mipstools.changepp pre Ppoint.pre (Mipstools.prepoint symcode)
        in
        let post =
           Mipstools.changepp post Ppoint.post (Mipstools.postpoint symcode)
        in

        (* Symbolic execution: (name, type) bindings plus constraint facts. *)
        let (symenv, symfacts) = Symexec.go symcode in
        let prefact = ("precondition", pre)
        and postfact = ("postcondition", post) in
        let allfacts = prefact :: postfact :: (controlfacts @ symfacts) in
        let allenv =
           List.fold_left (fun env (v, t) -> M.VarMap.add v t env)
              controlenv symenv
        in
        (* Debug aid: to dump every fact, iterate over allfacts printing
         * each label followed by Printlogic.print of its expression. *)

        let tc'fact when_ (what, e) =
           Typecheck.go allenv e BOOL (what ^ when_)
        in
        List.iter (tc'fact "") allfacts;

        (* XXX: Simplify.go is meant to run here (followed by re-checking
         * the facts with the suffix " (after Simplify)"), but it does not
         * do anything yet, so it is left out. *)

        synthcegis startstates0 num_counterexamples dryrun {
           allenv; prefact; postfact; allpps; numinsns;
           controlfacts; symfacts; symcode;
        }

(*
 * Ask the solver for up to N start states that satisfy the precondition
 * PRE.
 *
 * Each new init state I<i> must differ from every previously found I<j>.
 * The exact meaning of "differ" is decided by
 * Build.assert_different_execenv; the log text suggests REQUIRE_ALL means
 * "in all variables" rather than "in one variable".
 *
 * Returns the found states in order. Each is a Build.unverify assertion
 * moved from I<i> onto start state S<i>.
 *
 * If the very first query is UNSAT, crashes with "Requested precondition
 * is unsatisfiable". Otherwise it stops at the first UNSAT and returns
 * fewer than N states; the final log line then reports how many were
 * actually found.
 *)
let make_init_states pre n require_all =
   let sstr =
      string_of_int n ^ " " ^
         "initial start states differing in " ^
         (if require_all then "all variables" else "one variable")
   in
   Log.dcsay ("*** Getting " ^ sstr ^ " ***");
   let pp = Ppoint.pre in
   let pps = [pp] in

   (* Declarations for the prestate variables of an earlier init state JN. *)
   let declare_prior jn =
      let nstr =
         match jn with
         | Smt.INIT j -> string_of_int j
         | _ -> Util.crash "synthcegis: make_init_states: bad state"
      in
      Build.build'execenv ("prestate " ^ nstr) jn pps
   in

   (*
    * Query for init state I. ALREADY' holds the assignments found for
    * earlier states, each expressed over its own I<j> name.
    * Returns None when no further distinct state exists. Otherwise returns
    * Some (startstate over S<i>, startstate' over I<i>).
    *)
   let find_one i already' =
      let i_name = Smt.INIT i in (* "I" ^ string_of_int i *)
      let s_name = Smt.STATE i in (* "S" ^ string_of_int i *)
      let j_names = List.map (fun j -> Smt.INIT j) (Util.count i) in
      let prestate_smt = Build.build'execenv "new prestate" i_name pps in
      let pre_smt =
         Build.build'fact (Some i_name) pps ("precondition", pre)
      in
      let j_decls = List.concat (List.map declare_prior j_names) in
      let j_different =
         let label = "state " ^ string_of_int i ^ " different" in
         Build.assert_different_execenv label i_name j_names pp require_all
      in
      let prior_smt =
         let assertone e = [
            Smt.COMMENT "prior prestate assignment";
            Smt.ASSERT e;
         ] in
         List.concat (List.map assertone already')
      in
      let smt = j_decls @ prestate_smt @ prior_smt @ j_different @ pre_smt in
      let (_iduration, isat) = Solver.go smt in
      match isat with
      | Solver.UNSAT _ ->
           if i = 0 then
              Util.crash "Requested precondition is unsatisfiable"
           else
              None
      | Solver.SAT example ->
           let startstate =
              Build.unverify i_name s_name Ppoint.pre example
           in
           let startstate' =
              Build.unverify i_name i_name Ppoint.pre example
           in
           Some (startstate, startstate')
   in

   (* Walk the indices in order, stopping at the first UNSAT. *)
   let rec loop indices already already' =
      match indices with
      | [] -> List.rev already
      | i :: rest ->
           match find_one i already' with
           | None -> List.rev already
           | Some (s, s') -> loop rest (s :: already) (s' :: already')
   in
   let startstates = loop (Util.count n) [] [] in
   let got = List.length startstates in
   (* Report the real count if the solver ran out of distinct states. *)
   if got = n then Log.csay ("Got " ^ sstr)
   else Log.csay ("Got only " ^ string_of_int got ^ " of " ^ sstr);
   startstates

(*
 * Top-level entry: synthesize code taking PRE to POST.
 *
 * Starts at NUMINSNS instructions and tries successively longer
 * sequences until one succeeds, then returns that code. It does not
 * return until some length succeeds. SKETCH optionally constrains the
 * shape; DRYRUN uses trivial (Smt.TRUE) start states and the solver's
 * dry-run mode.
 *)
let synthesize (pre, post) numinsns sketch dryrun =
   (* Search tuning knobs. *)
   let num_initstates = 1
   and initstates_require_all = false
   and num_counterexamples = 1 in

   let startstates0 =
      if not dryrun then
         make_init_states pre num_initstates initstates_require_all
      else
         List.init num_initstates (fun _ -> Smt.TRUE)
   in
   let attempt len =
      synthesize' startstates0 pre len sketch post num_counterexamples dryrun
   in
   let rec search len =
      match attempt len with
      | None -> search (len + 1)
      | Some code -> code
   in
   search numinsns
