(*
 * Copyright (c) 2018-2019
 *	The President and Fellows of Harvard College.
 *
 * Written by David A. Holland.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of the University nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *)

(* simple optimizer for smt *)

(*
 * The purpose of this is not to help the solver (the solver can
 * presumably do all these transforms itself) but to remove crap from
 * the input expression so it can be read by humans.
 *)

module I = Mintops
open Smt

(*
 * True if the big integer X is zero.
 *
 * Uses Big_int's own equality (the same one the constant folder in
 * this file uses) rather than polymorphic "=": structural equality
 * looks at the representation of the number instead of its value and
 * is not a supported way to compare big_int values.
 *)
let zerop x = Big_int.eq_big_int x Big_int.zero_big_int
(*
 * True if the big integer X is one.
 *
 * Uses Big_int's own equality (the same one the constant folder in
 * this file uses) rather than polymorphic "=": structural equality
 * looks at the representation of the number instead of its value and
 * is not a supported way to compare big_int values.
 *)
let onep x = Big_int.eq_big_int x Big_int.unit_big_int

(**************************************************************)
(* baseopt pass *)

(*
 * Decide whether an expression is small enough to be copied to every
 * use of a let-bound variable instead of being kept in its own LET:
 * the boolean constants, integer literals, a variable read, or the
 * negation of a variable read. Everything else is not trivial.
 *)
let trivial e =
   match e with
   | TRUE | FALSE -> true
   | XINTVAL _ | MINTVAL _ -> true
   | READVAR _ | NOT (READVAR _) -> true
   | _ -> false

(*
 * Replace the free occurrences of the variable X in E with REPL.
 *
 * A LET, FORALL or EXISTS that binds the same name as X hides the
 * outer X inside its body, so in that case the body is left untouched.
 * The bound expression of a LET is always rewritten, on the assumption
 * that LET is non-recursive (as in SMT-LIB), i.e. the new name is not
 * yet in scope there.
 *
 * Binders are not renamed, so a free variable of REPL can still be
 * captured by a binder of the same name inside E; this is only safe if
 * the names do not clash.
 *
 * For READARRAY and UPDATEARRAY only the expression operands are
 * rewritten; the array component is copied unchanged.
 *)
let rec subst x repl e =
   let sub e0 = subst x repl e0 in
   (* body of a binder for x0: untouched when x0 shadows x *)
   let sub_under x0 body = if x0 = x then body else sub body in
   match e with
   | TRUE
   | FALSE
   | XINTVAL _
   | MINTVAL _
   | BOOLARRAYVAL _
   | XINTARRAYVAL _
   | MINTARRAYVAL _ -> e
   | READVAR x0 -> if x0 = x then repl else e
   | READARRAY (a, e1) -> READARRAY (a, sub e1)
   | UPDATEARRAY (a, e1, e2) -> UPDATEARRAY (a, sub e1, sub e2)
   | LET (x0, ty, e1, e2) -> LET (x0, ty, sub e1, sub_under x0 e2)
   | IF (c, t, f) -> IF (sub c, sub t, sub f)
   | NOT e1 -> NOT (sub e1)
   | AND (e1, e2) -> AND (sub e1, sub e2)
   | OR (e1, e2) -> OR (sub e1, sub e2)
   | XOR (e1, e2) -> XOR (sub e1, sub e2)
   | EQ (e1, e2) -> EQ (sub e1, sub e2)
   | IMPLIES (e1, e2) -> IMPLIES (sub e1, sub e2)
   | FORALL (x0, ty, e1) -> FORALL (x0, ty, sub_under x0 e1)
   | EXISTS (x0, ty, e1) -> EXISTS (x0, ty, sub_under x0 e1)
   | MINTUOP (fn, e1) -> MINTUOP (fn, sub e1)
   | MINTBOP (fn, e1, e2) -> MINTBOP (fn, sub e1, sub e2)

(*
 * Peephole rule for the unary machine-integer operator FN applied to
 * the operand E1. Returns Some replacement if the whole application
 * can be rewritten, or None to leave it as it is.
 *
 * The only active rewrite cancels a signedness conversion against the
 * opposite conversion directly underneath it.
 *)
let baseopt'intuop fn e1 =
   match fn, e1 with
(*
   disabled rules, kept for reference:

   | I.NOP, _ -> Some e1
   | I.NEG, INTUOP ("neg", e1) -> Some e1
   | I.NEG, INTBOP ("sub", e1, e2) -> Some (INTBOP ("sub", e2, e1))
   | I.NEG, _ -> None
*)
   | I.TOSIGNED, MINTUOP (I.TOUNSIGNED, inner)
   | I.TOUNSIGNED, MINTUOP (I.TOSIGNED, inner) -> Some inner
   | (I.BITNOT
     | I.SBWIDEN | I.SHWIDEN
     | I.UBWIDEN | I.UHWIDEN | I.U5WIDEN
     | I.TRUNCU5
     | I.TOSIGNED | I.TOUNSIGNED), _ -> None

(*
 * Peephole rules for the binary machine-integer operator FN applied to
 * the operands E1 and E2. Returns Some replacement if the whole
 * application can be rewritten, or None to leave it as it is.
 *
 * The rules are:
 *    - comparisons of two literals are folded to TRUE/FALSE;
 *    - a variable compared with itself is folded;
 *    - identity and absorbing constants (0 and 1) are eliminated.
 * For the commutative operators (ADD, the multiplies, BITAND, BITOR,
 * BITXOR) the constant is recognized on either side; for SUB, the
 * divisions/moduli and the shifts only on the right.
 *
 * NOTE(review): ULT and SLT both fold literals with lt_big_int. That
 * can only be right for both if MINTVAL literals never have their sign
 * bit set (or are never negative); confirm how MINTVAL represents
 * values before relying on these two folds.
 *)
let baseopt'intbop fn e1 e2 =
   let zero = MINTVAL Big_int.zero_big_int in
   let fold b = if b then Some TRUE else Some FALSE in
   match fn, e1, e2 with
   (* comparisons *)
   | I.EQ, MINTVAL a, MINTVAL b -> fold (Big_int.eq_big_int a b)
   | I.EQ, READVAR x, READVAR y -> if x = y then Some TRUE else None
   | I.EQ, _, _ -> None
   | I.ULT, MINTVAL a, MINTVAL b -> fold (Big_int.lt_big_int a b)
   | I.ULT, READVAR x, READVAR y -> if x = y then Some FALSE else None
   | I.ULT, _, _ -> None
   | I.SLT, MINTVAL a, MINTVAL b -> fold (Big_int.lt_big_int a b)
   | I.SLT, READVAR x, READVAR y -> if x = y then Some FALSE else None
   | I.SLT, _, _ -> None
   (* addition and subtraction: 0 is the identity *)
   | I.ADD, e1, MINTVAL z when zerop z -> Some e1
   | I.ADD, MINTVAL z, e2 when zerop z -> Some e2
   | I.ADD, _, _ -> None
   | I.SUB, e1, MINTVAL z when zerop z -> Some e1
   | I.SUB, _, _ -> None
   (* high word of an unsigned product with 0 or 1 is always 0 *)
   | I.UMUL_HI, _, MINTVAL z when zerop z -> Some zero
   | I.UMUL_HI, MINTVAL z, _ when zerop z -> Some zero
   | I.UMUL_HI, _, MINTVAL one when onep one -> Some zero
   | I.UMUL_HI, MINTVAL one, _ when onep one -> Some zero
   | I.UMUL_HI, _, _ -> None
   (*
    * signed: only 0 is safe; multiplying by 1 leaves the sign
    * extension of the other operand in the high word
    *)
   | I.SMUL_HI, _, MINTVAL z when zerop z -> Some zero
   | I.SMUL_HI, MINTVAL z, _ when zerop z -> Some zero
   | I.SMUL_HI, _, _ -> None
   | I.MUL_LO, _, MINTVAL z when zerop z -> Some zero
   | I.MUL_LO, MINTVAL z, _ when zerop z -> Some zero
   | I.MUL_LO, e1, MINTVAL one when onep one -> Some e1
   | I.MUL_LO, MINTVAL one, e2 when onep one -> Some e2
   | I.MUL_LO, _, _ -> None
   (* division and modulus by 1 *)
   | I.UDIV, e1, MINTVAL one when onep one -> Some e1
   | I.UDIV, _, _ -> None
   | I.UMOD, _, MINTVAL one when onep one -> Some zero
   | I.UMOD, _, _ -> None
   | I.SDIV, e1, MINTVAL one when onep one -> Some e1
   | I.SDIV, _, _ -> None
   | I.SMOD, _, MINTVAL one when onep one -> Some zero
   | I.SMOD, _, _ -> None
   (* bitwise operators with 0 *)
   | I.BITAND, _, MINTVAL z when zerop z -> Some zero
   | I.BITAND, MINTVAL z, _ when zerop z -> Some zero
   | I.BITAND, _, _ -> None
   | I.BITOR, e1, MINTVAL z when zerop z -> Some e1
   | I.BITOR, MINTVAL z, e2 when zerop z -> Some e2
   | I.BITOR, _, _ -> None
   | I.BITXOR, e1, MINTVAL z when zerop z -> Some e1
   | I.BITXOR, MINTVAL z, e2 when zerop z -> Some e2
   | I.BITXOR, _, _ -> None
   (* shifting by 0 places *)
   | I.SHL, e1, MINTVAL z when zerop z -> Some e1
   | I.SHL, _, _ -> None
   | I.USHR, e1, MINTVAL z when zerop z -> Some e1
   | I.USHR, _, _ -> None
   | I.SSHR, e1, MINTVAL z when zerop z -> Some e1
   | I.SSHR, _, _ -> None

(*
 * Apply the local rewrite rules to the root of E only; the operands
 * are not visited here (baseopt'expr optimizes them before calling
 * this). Returns the rewritten expression, or E itself if no rule
 * applies.
 *
 * Negations are always built through the local helper "negate", so
 * that a NOT introduced by a rewrite (IMPLIES (e1, FALSE)) gets the
 * same double-negation and De Morgan simplification as a NOT that was
 * already in the input. Otherwise such a rewrite could leave behind
 * NOT (NOT x), which nothing would revisit.
 *)
let baseopt'here e =
   (* build "not e1", simplifying where possible *)
   let negate e1 =
      match e1 with
      | TRUE -> FALSE
      | FALSE -> TRUE
      | NOT e1' -> e1'
      | AND (NOT a, NOT b) -> OR (a, b)
      | OR (NOT a, NOT b) -> AND (a, b)
      | _ -> NOT e1
   in
   match e with
   | LET (x, ty, e1, e2) ->
        (* we already removed trivial e1 *)
        LET (x, ty, e1, e2)
   | IF (TRUE, t, _f) -> t
   | IF (FALSE, _t, f) -> f
   | NOT e1 -> negate e1
   | AND (TRUE, e2) -> e2
   | AND (FALSE, _) -> FALSE
   | AND (e1, TRUE) -> e1
   | AND (_, FALSE) -> FALSE
   | XOR (TRUE, TRUE) -> FALSE
   | XOR (TRUE, FALSE) -> TRUE
   | XOR (FALSE, TRUE) -> TRUE
   | XOR (FALSE, FALSE) -> FALSE
   | OR (TRUE, _) -> TRUE
   | OR (FALSE, e2) -> e2
   | OR (_, TRUE) -> TRUE
   | OR (e1, FALSE) -> e1
   | EQ (TRUE, TRUE) -> TRUE
   | EQ (TRUE, FALSE) -> FALSE
   | EQ (FALSE, TRUE) -> FALSE
   | EQ (FALSE, FALSE) -> TRUE
   (* order matters: the constant-antecedent cases must come first *)
   | IMPLIES (TRUE, e2) -> e2
   | IMPLIES (_, TRUE) -> TRUE
   | IMPLIES (FALSE, _) -> TRUE
   | IMPLIES (e1, FALSE) -> negate e1
   | FORALL (_, _, TRUE) -> TRUE
   | FORALL (_, _, FALSE) -> FALSE
   | EXISTS (_, _, TRUE) -> TRUE
   | EXISTS (_, _, FALSE) -> FALSE
   | MINTUOP (fn, e1) -> begin
        match baseopt'intuop fn e1 with
        | Some e' -> e'
        | None -> MINTUOP (fn, e1)
     end
   | MINTBOP (fn, e1, e2) -> begin
        match baseopt'intbop fn e1 e2 with
        | Some e' -> e'
        | None -> MINTBOP (fn, e1, e2)
     end
   | _ -> e

(*
 * Optimize an expression bottom-up: first optimize the operands and
 * rebuild the node from them, then apply the local rules
 * (baseopt'here) to the rebuilt node.
 *
 * A LET whose bound expression optimizes to something trivial is
 * removed by substituting that expression into the body; the result
 * of the substitution is then optimized in place of the LET.
 *)
let rec baseopt'expr e =
   let opt = baseopt'expr in
   (* optimize two operands, left then right, and rebuild with mk *)
   let both mk lhs rhs =
      let lhs' = opt lhs in
      let rhs' = opt rhs in
      mk lhs' rhs'
   in
   let rebuilt =
      match e with
      (* leaves: nothing underneath to optimize *)
      | TRUE | FALSE
      | XINTVAL _ | MINTVAL _
      | BOOLARRAYVAL _ | XINTARRAYVAL _ | MINTARRAYVAL _
      | READVAR _ -> e
      | READARRAY (a, arg) -> READARRAY (a, opt arg)
      | UPDATEARRAY (a, lhs, rhs) ->
           both (fun l r -> UPDATEARRAY (a, l, r)) lhs rhs
      | LET (x, ty, bound, body) ->
           let bound' = opt bound in
           if trivial bound' then
              opt (subst x bound' body)
           else
              LET (x, ty, bound', opt body)
      | IF (c, t, f) ->
           let c' = opt c in
           let t' = opt t in
           let f' = opt f in
           IF (c', t', f')
      | NOT arg -> NOT (opt arg)
      | AND (lhs, rhs) -> both (fun l r -> AND (l, r)) lhs rhs
      | OR (lhs, rhs) -> both (fun l r -> OR (l, r)) lhs rhs
      | XOR (lhs, rhs) -> both (fun l r -> XOR (l, r)) lhs rhs
      | EQ (lhs, rhs) -> both (fun l r -> EQ (l, r)) lhs rhs
      | IMPLIES (lhs, rhs) -> both (fun l r -> IMPLIES (l, r)) lhs rhs
      | FORALL (x, ty, body) -> FORALL (x, ty, opt body)
      | EXISTS (x, ty, body) -> EXISTS (x, ty, opt body)
      | MINTUOP (fn, arg) -> MINTUOP (fn, opt arg)
      | MINTBOP (fn, lhs, rhs) ->
           both (fun l r -> MINTBOP (fn, l, r)) lhs rhs
   in
   baseopt'here rebuilt

(*
 * Optimize one declaration. Only assertions carry an expression to
 * optimize; bindings and comments are passed through unchanged.
 *)
let baseopt'decl d =
   match d with
   | ASSERT e -> ASSERT (baseopt'expr e)
   | BIND _ | COMMENT _ -> d

(**************************************************************)
(* toplevel *)

(*
 * Entry point: run the baseopt pass over every declaration in DECLS,
 * returning the optimized declarations in the same order.
 *)
let go decls =
   List.map (fun decl -> baseopt'decl decl) decls