(*
 * Copyright (c) 2018-2019
 *	The President and Fellows of Harvard College.
 *
 * Written by David A. Holland.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of the University nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *)

(* typechecker for smt representation *)

module I = Mintops
open Smt

(* Mutable state threaded through the whole typechecking pass. *)
type ctx = {
   (* how many type errors have been reported so far (bumped by choke) *)
   errors : int ref;
   (* the type of every variable currently in scope *)
   vars : typename VarMap.t ref;
}

(*
 * Print, on stderr, the expression in which a type error was found.
 * The context argument is not used; it is accepted so that all the
 * reporting functions (snivel, whine, choke) take the same arguments.
 *)
let snivel _ctx e =
   prerr_string "smtcheck: in this expression:\n";
   let shown = Printsmt.print_expr "   " e in
   prerr_string shown;
   prerr_string "\n"

(*
 * Print one "smtcheck: "-prefixed diagnostic line on stderr and flush
 * stderr. Does not count as an error; use choke for that. The context
 * argument is unused.
 *)
let whine _ctx msg =
   prerr_endline ("smtcheck: " ^ msg)

(*
 * Report msg (via whine) and record it as a type error by bumping the
 * error count in ctx.
 *)
let choke ctx msg =
   whine ctx msg;
   incr ctx.errors

(* Render a type for use in diagnostics. *)
let string_of_typename = Printsmt.print_typename

(**************************************************************)
(* recursive pass *)

(*
 * Result type of a binary machine-integer operator: the comparisons
 * EQ, ULT and SLT produce BOOL; every other operator produces MINT.
 *)
let mint_rettype = function
   | I.EQ | I.ULT | I.SLT -> BOOL
   | _ -> MINT

(*
 * Typecheck the expression e and return its type.
 *
 * Type errors are not raised. Each one is reported on stderr (the
 * enclosing expression via snivel, then message lines via whine/choke)
 * and counted in ctx.errors by choke. After reporting, a fallback type
 * is returned so checking can continue and find further errors in the
 * same pass.
 *
 * ctx.vars maps variable names to their types. It is temporarily
 * extended while checking the body of a LET, FORALL or EXISTS, and
 * restored afterwards.
 *)
let rec tc'expr ctx e =
   (* Require subexpression e1 to be boolean. On failure the enclosing
      expression e (not e1) is printed, to give context. *)
   let req'bool e1 what =
      let ty'e1 = tc'expr ctx e1 in
      if ty'e1 <> BOOL then begin
         snivel ctx e;
         let got = string_of_typename ty'e1 in
         choke ctx (what ^ " is not boolean (got " ^ got ^ ")")
      end
   in
   (* Require subexpression e1 to be a machine integer; as req'bool. *)
   let req'mint e1 what =
      let ty'e1 = tc'expr ctx e1 in
      if ty'e1 <> MINT then begin
         snivel ctx e;
         let got = string_of_typename ty'e1 in
         choke ctx (what ^ " is not a machine integer (got " ^ got ^ ")")
      end
   in
   (* Typecheck e1 with x bound to ty, then restore the previous scope. *)
   let subwith x ty e1 =
      let oldctx = !(ctx.vars) in
      ctx.vars := VarMap.add x ty oldctx;
      let ty'e1 = tc'expr ctx e1 in
      ctx.vars := oldctx;
      ty'e1
   in
   (* Look up the array variable name and return (array type, index
      type, element type). If name is unbound, report it and assume
      ARRAY (XINT 32, MINT); if it is bound to a non-array type, report
      it and use XINT 32 / MINT as index / element types. Shared by
      READARRAY and UPDATEARRAY. *)
   let array'types name =
      let ty'a =
         try
            VarMap.find name !(ctx.vars)
         with Not_found ->
            snivel ctx e;
            choke ctx ("Unbound array " ^ string_of_varname name);
            ARRAY (XINT 32, MINT)
      in
      match ty'a with
      | ARRAY (k, v) -> ty'a, k, v
      | _ ->
           snivel ctx e;
           choke ctx ("Indexing non-array " ^ string_of_varname name);
           ty'a, XINT 32, MINT
   in
   (* Check the body e1 of a quantifier (which is "forall" or "exists")
      with x bound to ty. The quantifier itself is always boolean, so
      BOOL is returned even when the body is ill-typed; returning the
      body's type would make the parent report spurious cascading
      errors (every other boolean connective already behaves this way). *)
   let quantifier which x ty e1 =
      let ty'e1 = subwith x ty e1 in
      if ty'e1 <> BOOL then begin
         snivel ctx e;
         let x' = string_of_varname x in
         choke ctx ("Expression under " ^ which ^ " " ^ x' ^ " is not boolean")
      end;
      BOOL
   in

   match e with
   | TRUE
   | FALSE -> BOOL
   | XINTVAL (_, w) -> XINT w
   | MINTVAL _ -> MINT
   | BOOLARRAYVAL (_, _, kw) -> ARRAY (XINT kw, BOOL)
   | XINTARRAYVAL (_, _, kw, vw) -> ARRAY (XINT kw, XINT vw)
   | MINTARRAYVAL (_, _, kw) -> ARRAY (XINT kw, MINT)
   | READVAR name -> begin
        try
           VarMap.find name !(ctx.vars)
        with Not_found ->
           snivel ctx e;
           choke ctx ("Unbound variable " ^ string_of_varname name);
           MINT
     end
   | READARRAY (name, e1) ->
        let ty'a, ty'k, ty'v = array'types name in
        let ty'e1 = tc'expr ctx e1 in
        if ty'e1 <> ty'k then begin
           let name' = string_of_varname name in
           snivel ctx e;
           whine ctx ("Indexing array " ^ name' ^ " with wrong type");
           whine ctx ("Array type: " ^ string_of_typename ty'a);
           choke ctx ("Index type: " ^ string_of_typename ty'e1)
        end;
        ty'v
   | UPDATEARRAY (name, e1, e2) ->
        let ty'a, ty'k, ty'v = array'types name in
        let ty'e1 = tc'expr ctx e1 in
        if ty'e1 <> ty'k then begin
           let name' = string_of_varname name in
           snivel ctx e;
           whine ctx ("Updating array " ^ name' ^ " with wrong index type");
           whine ctx ("Array type: " ^ string_of_typename ty'a);
           choke ctx ("Index type: " ^ string_of_typename ty'e1)
        end;
        let ty'e2 = tc'expr ctx e2 in
        if ty'e2 <> ty'v then begin
           let name' = string_of_varname name in
           snivel ctx e;
           whine ctx ("Updating array " ^ name' ^ " with wrong value type");
           whine ctx ("Array type: " ^ string_of_typename ty'a);
           choke ctx ("Value type: " ^ string_of_typename ty'e2)
        end;
        ty'a
   | LET (x, ty, e1, e2) ->
        let ty'e1 = tc'expr ctx e1 in
        if ty'e1 <> ty then begin
           let x' = string_of_varname x in
           snivel ctx e;
           whine ctx ("Stored type in let expression (" ^ x' ^ ") is wrong");
           whine ctx ("Stored type: " ^ string_of_typename ty);
           choke ctx ("Computed type: " ^ string_of_typename ty'e1)
        end;
        (* the body sees x at its declared type, even if e1 disagreed *)
        subwith x ty e2
   | IF (c, t, f) ->
        req'bool c "Control expression of if";
        let ty't = tc'expr ctx t in
        let ty'f = tc'expr ctx f in
        if ty't <> ty'f then begin
           snivel ctx e;
           whine ctx ("Results of if do not match");
           whine ctx ("True case: " ^ string_of_typename ty't);
           choke ctx ("False case: " ^ string_of_typename ty'f)
        end;
        ty't
   | NOT e1 ->
        req'bool e1 "Argument of logical not";
        BOOL
   | AND (e1, e2) ->
        req'bool e1 "First argument of logical and";
        req'bool e2 "Second argument of logical and";
        BOOL
   | OR (e1, e2) ->
        req'bool e1 "First argument of logical or";
        req'bool e2 "Second argument of logical or";
        BOOL
   | XOR (e1, e2) ->
        req'bool e1 "First argument of logical xor";
        req'bool e2 "Second argument of logical xor";
        BOOL
   | EQ (e1, e2) ->
        let ty'e1 = tc'expr ctx e1 in
        let ty'e2 = tc'expr ctx e2 in
        if ty'e1 <> ty'e2 then begin
           snivel ctx e;
           whine ctx ("Arguments of equal do not match");
           whine ctx ("Left: " ^ string_of_typename ty'e1);
           choke ctx ("Right: " ^ string_of_typename ty'e2)
        end;
        (* machine ints must be compared with the MINTBOP equality
           operator, not the logical EQ node *)
        if ty'e1 = MINT then begin
           snivel ctx e;
           choke ctx ("Logical == used on machine int value")
        end;
        BOOL
   | IMPLIES (e1, e2) ->
        req'bool e1 "First argument of logical implies";
        req'bool e2 "Second argument of logical implies";
        BOOL
   | FORALL (x, ty, e1) -> quantifier "forall" x ty e1
   | EXISTS (x, ty, e1) -> quantifier "exists" x ty e1
   | MINTUOP (op, e1) ->
        req'mint e1 ("Argument of " ^ I.name_of_uop op);
        MINT
   | MINTBOP (op, e1, e2) ->
        req'mint e1 ("First argument of " ^ I.name_of_bop op);
        req'mint e2 ("Second argument of " ^ I.name_of_bop op);
        mint_rettype op

(*
 * Typecheck one top-level declaration.
 *   COMMENT: ignored.
 *   BIND:    add the variable to the global scope. Binding a name that
 *            is already bound is an error, and the existing binding is
 *            kept.
 *   ASSERT:  the asserted expression must be boolean.
 * Errors are reported and counted via choke, not raised.
 *)
let tc'decl ctx d =
   match d with
   | COMMENT _ -> ()
   | BIND (name, ty) -> begin
        let scope = !(ctx.vars) in
        match VarMap.find name scope with
        | _ ->
             choke ctx ("Redeclaration of variable " ^ string_of_varname name)
        | exception Not_found ->
             ctx.vars := VarMap.add name ty scope
     end
   | ASSERT e ->
        let ty = tc'expr ctx e in
        if ty <> BOOL then begin
           snivel ctx e;
           let got = string_of_typename ty in
           choke ctx ("Assertion expression not boolean (got " ^ got ^ ")")
        end

(**************************************************************)
(* external interface *)

(*
 * Entry point: typecheck the declarations in order, starting from an
 * empty scope. If any type errors were reported, hand a message with
 * the error count to Util.crash (presumably aborts; see Util).
 *)
let go decls =
   let ctx = { vars = ref VarMap.empty; errors = ref 0 } in
   List.iter (fun decl -> tc'decl ctx decl) decls;
   let count = !(ctx.errors) in
   if count > 0 then begin
      let msg = string_of_int count ^ " fatal internal type errors" in
      Util.crash msg
   end
