(*
 * Copyright (c) 2018-2019
 *	The President and Fellows of Harvard College.
 *
 * Written by David A. Holland.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of the University nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *)

(*
 * Typechecker for logic expressions
 *)

open Symbolic
open Mipsstate
open Mips
module P = Printlogic

(*
 * Typechecking context threaded through every check in this file.
 *)
type ctx = {
   (* type of each variable/symbol currently in scope; extended by
      FORALL, EXISTS, and LET for their bodies *)
   env: typename VarMap.t;
   (* description of the top-level item being checked; every
      diagnostic starts with "In <topwhat>:" *)
   topwhat: string;
}

(* Render a type name for diagnostics; delegates to Printlogic. *)
let string_of_typename ty = P.string_of_typename ty

(*
 * True exactly for the fixed-width machine integer types
 * (UINT5, SINT8, UINT8, SINT16, UINT16, SINT32, UINT32);
 * false for every other type.
 *)
let is_machine_int = function
   | UINT5 | SINT8 | UINT8 | SINT16 | UINT16 | SINT32 | UINT32 -> true
   | _ -> false

(*
 * Report a type error and abort.
 *   FOUNDTY - the type the offending expression actually has
 *   XSTR    - textual description of the expected type(s)
 *   WHAT    - where in the expression the error occurred
 * Does not return normally: it finishes with Util.crash.
 *)
let err' ctx foundty xstr what =
   let header = "In " ^ ctx.topwhat ^ ":" in
   let found = "This expression has type " ^ string_of_typename foundty in
   Util.say header;
   Util.say ("Type error in " ^ what);
   Util.say found;
   Util.say ("The context expects type " ^ xstr);
   Util.crash "Aborting."

(* Type error where exactly one type, XTY, was expected. Aborts. *)
let err ctx foundty xty what =
   err' ctx foundty (string_of_typename xty) what

(* Type error where either XTY1 or XTY2 would have been accepted. Aborts. *)
let err2 ctx foundty xty1 xty2 what =
   let alt1 = string_of_typename xty1
   and alt2 = string_of_typename xty2 in
   err' ctx foundty (alt1 ^ " or " ^ alt2) what

(*
 * Report that the operands of operator WHAT have different types
 * (LTY on the left, RTY on the right), then abort via Util.crash.
 *)
let errmismatch ctx lty rty what =
   let lstr = string_of_typename lty
   and rstr = string_of_typename rty in
   Util.say ("In " ^ ctx.topwhat ^ ":");
   Util.say ("Type mismatch in arguments of " ^ what ^ ":");
   Util.say ("Left-hand-side has type " ^ lstr);
   Util.say ("Right-hand-side has type " ^ rstr);
   Util.crash "Aborting."

(*
 * Report an internal inconsistency: the type computed for WHAT
 * (COMPTY) disagrees with the type recorded for it (RECTY).
 * Aborts via Util.crash.
 *)
let errint ctx compty recty what =
   let computed = "Computed type: " ^ string_of_typename compty in
   let recorded = "Recorded type: " ^ string_of_typename recty in
   Util.say ("In " ^ ctx.topwhat ^ ":");
   Util.say ("Internal type error " ^ what);
   Util.say computed;
   Util.say recorded;
   Util.crash "Whoops."

(**************************************************************)

(*
 * Check that integer constant K is representable in machine integer
 * type TY. Returns unit. Aborts with a diagnostic if K is outside the
 * inclusive range of TY, or if TY is not a machine integer type.
 * The 32-bit bounds assume native OCaml ints are wider than 32 bits
 * (64-bit platform).
 *)
let tc'intval ctx k ty =
   (* abort unless lo <= k <= hi *)
   let range lo hi =
      if k < lo || k > hi then begin
         Util.say ("In " ^ ctx.topwhat ^ ":");
         Util.say "Out of bounds integer constant";
         Util.say ("Value: " ^ string_of_int k);
         Util.say ("Type: " ^ string_of_typename ty);
         Util.crash "Aborting."
      end
   in
   match ty with
   | UINT5 -> range 0 31
   | SINT8 -> range (-128) 127
   | UINT8 -> range 0 255
   | SINT16 -> range (-32768) 32767
   | UINT16 -> range 0 65535
   | SINT32 -> range (-2147483648) 2147483647
     (* largest u32 is 2^32 - 1; 2^32 itself does not fit *)
   | UINT32 -> range 0 4294967295
   | BOOL
   | CACHESTATE
   | REGISTER
   | CREGISTER
   | FREGISTER
   | FCREGISTER
   | HIREGISTER
   | LOREGISTER
   | GOLEM
   | WHICHCACHE
   | WHICHCACHEENTRYID
   | WHICHCACHEENTRY
   | PROGRAMPOINT
   | INSN
   | MAP _ ->
        let tystr = string_of_typename ty in
        Util.say ("In " ^ ctx.topwhat ^ ":");
        Util.crash ("Integer constant has invalid type " ^ tystr)

(*
 * Check that X is bound in ctx.env at exactly type TY. HOW describes
 * the use site and appears in the internal-error message if the
 * recorded type disagrees. Aborts if X is unbound or mistyped.
 *)
let tc'var ctx x ty how =
   let declty =
      try VarMap.find x ctx.env
      with Not_found ->
         Util.say ("In " ^ ctx.topwhat ^ ":");
         Util.crash ("Unbound variable/symbol " ^ string_of_varname x)
   in
   if declty <> ty then
      errint ctx declty ty (how ^ " " ^ string_of_varname x)

(*
let tc'cachegolem ctx g =
   match g with
   | CONCI _ -> BOOL
   | SYMI sym -> tc'var ctx (CONTROLVAR sym) GOLEM "symbolic golem"; BOOL

let tc'stateelement ctx elem =
   let s x ty how res =
      tc'var ctx (None, x) ty ("symbolic " ^ how); res
   in
   match elem with
   | REG (SYM sym) -> s sym REGISTER "register" UINT32
   | CREG (SYM sym) -> s sym CREGISTER "control register" UINT32
   | FREG (SYM sym) -> s sym FREGISTER "floating register" UINT32
   | FCREG (SYM sym) -> s sym FCREGISTER "fp control register" UINT32
   | REG (CONC _)
   | CREG (CONC _)
   | FREG (CONC _)
   | FCREG (CONC _) -> UINT32
   | HI
   | LO -> UINT32
   | CACHEENTRY (SYM sc, SYM sce) ->
        tc'var ctx (None, sc) WHICHCACHE "cache";
        tc'var ctx (None, sce) WHICHCACHEENTRY "cache entry";
        CACHESTATE
   | CACHEENTRY _ -> CACHESTATE
*)

(*
 * Result type of a READSTATE whose element expression has type
 * TY'ELEM. Register handles (general, control, floating, fp control)
 * read as UINT32, and a cache-entry handle reads as CACHESTATE. Any
 * other type aborts with a diagnostic. Like every other error path in
 * this file, that diagnostic is prefixed with "In <topwhat>:".
 *)
let tc'stateelement ctx ty'elem =
   let fail msg =
      Util.say ("In " ^ ctx.topwhat ^ ":");
      Util.crash msg
   in
   match ty'elem with
   | REGISTER
   | CREGISTER
   | FREGISTER
   | FCREGISTER -> UINT32
(* XXX there's currently no encoding for these
   | HI -> UINT32
   | LO -> UINT32
*)
   | WHICHCACHEENTRY -> CACHESTATE
   | WHICHCACHE -> fail "Cannot READSTATE a whole cache"
   | WHICHCACHEENTRYID -> fail "Cannot READSTATE a cacheentryid on its own"
   | _ -> fail ("Inappropriate type " ^ string_of_typename ty'elem ^
                " for READSTATE")

(*
 * Typecheck logic expression E in context CTX and return its type.
 * Errors are reported through err/err2/errmismatch/errint or an
 * inline diagnostic. All of these end in Util.crash, so the first
 * error found aborts the check. FORALL, EXISTS, and LET check their
 * bodies in an environment extended with the bound variable.
 *)
let rec tc'expr ctx e =
   (* Warn (but keep going) when binding X in WHAT shadows a variable
      that is already in scope. *)
   let checkrebind x what =
      let already_bound =
         try ignore (VarMap.find x ctx.env); true
         with Not_found -> false
      in
      if already_bound then begin
         Util.say ("In " ^ ctx.topwhat ^ ":");
         Util.say ("Warning: variable " ^ string_of_varname x ^
                   " rebound in " ^ what)
      end
   in
   let tc'two e1 e2 = (tc'expr ctx e1, tc'expr ctx e2) in
   let demand ty demandty what =
      if ty <> demandty then
         err ctx ty demandty what
   in
   let checkbool ty what =
      demand ty BOOL what
   in
   let booluop e1 what =
      let ty'e1 = tc'expr ctx e1 in
      checkbool ty'e1 ("argument of " ^ what);
      BOOL
   in
   let boolbop e1 e2 what =
      let (ty'e1, ty'e2) = tc'two e1 e2 in
      checkbool ty'e1 ("left-hand-side of " ^ what);
      checkbool ty'e2 ("right-hand-side of " ^ what);
      BOOL
   in

   let intuop e1 what =
      let ty'e1 = tc'expr ctx e1 in
      if ty'e1 <> SINT32 && ty'e1 <> UINT32 then
         err2 ctx ty'e1 SINT32 UINT32 ("argument of " ^ what);
      ty'e1
   in

   (* Check TY against an operand constraint: exactly one allowed
      type (Types.OIL) or either of two (Types.WATER). *)
   let checkextra extra ty what' =
      match extra with
      | Types.OIL xty ->
           if ty <> xty then
              err ctx ty xty what'
      | Types.WATER (xty1, xty2) ->
           if ty <> xty1 && ty <> xty2 then
              err2 ctx ty xty1 xty2 what'
   in

   (* binary integer op: both operands share one type, which must
      satisfy EXTRA; that type is the result *)
   let intbop extra e1 e2 what =
      let (ty'e1, ty'e2) = tc'two e1 e2 in
      if ty'e1 <> ty'e2 then
         errmismatch ctx ty'e1 ty'e2 what;
      checkextra extra ty'e1 ("arguments of " ^ what);
      ty'e1
   in
   let sintbop = intbop (Types.OIL SINT32) in
   let uintbop = intbop (Types.OIL UINT32) in
   let pintbop = intbop (Types.WATER (SINT32, UINT32)) in

   (* shift: the amount is UINT5, the shifted value must satisfy EXTRA
      and determines the result type *)
   let intshift extra e1 e2 what =
      let (ty'e1, ty'e2) = tc'two e1 e2 in
      if ty'e2 <> UINT5 then
         err ctx ty'e2 UINT5 ("right-hand-side of " ^ what);
      checkextra extra ty'e1 ("left-hand-side of " ^ what);
      ty'e1
   in
   let sintshift = intshift (Types.OIL SINT32) in
   let uintshift = intshift (Types.OIL UINT32) in
   let pintshift = intshift (Types.WATER (SINT32, UINT32)) in

   let widenop e1 ty rty =
      let ty'e1 = tc'expr ctx e1 in
      if ty'e1 <> ty then
         err ctx ty'e1 ty ("Widen operator for " ^ string_of_typename ty);
      rty
   in
   let truncop e1 ty rty =
      let ty'e1 = tc'expr ctx e1 in
      if ty'e1 <> ty then
         err ctx ty'e1 ty ("Truncate operator for " ^ string_of_typename rty);
      rty
   in

   (* FORALL / EXISTS: BINDER is the quantifier's name, used in both
      the rebinding warning and the error message for the body. *)
   let quantifier binder x ty e1 =
      checkrebind x binder;
      let ctx' = { ctx with env = VarMap.add x ty (ctx.env); } in
      let ty'e1 = tc'expr ctx' e1 in
      checkbool ty'e1 ("result of " ^ binder ^ " " ^ string_of_varname x);
      BOOL
   in

   match e with

     (* constants *)
   | TRUE -> BOOL
   | FALSE -> BOOL
   | CACHESTATEVALUE _ -> CACHESTATE

     (* constructors *)
   | INTVALUE (k, ty) ->
        tc'intval ctx k ty; ty
   | REGVALUE _ -> REGISTER
   | CREGVALUE _ -> CREGISTER
   | FREGVALUE _ -> FREGISTER
   | FCREGVALUE _ -> FCREGISTER
   | HIREGVALUE -> HIREGISTER
   | LOREGVALUE -> LOREGISTER
   | GOLEMVALUE _ -> GOLEM
   | WHICHCACHEVALUE _ -> WHICHCACHE
   | WHICHCACHEENTRYIDVALUE _ -> WHICHCACHEENTRYID
   | WHICHCACHEENTRYVALUE _ -> WHICHCACHEENTRY
   | PPVALUE _ -> PROGRAMPOINT
   | INSNVALUE _ -> INSN

     (* access to processor state, as of a particular program point *)
   | READSTATE (elem, _pp, _zs) ->
        let ty'elem = tc'expr ctx elem in
        tc'stateelement ctx ty'elem
   | READGOLEM g ->
        let ty'g = tc'expr ctx g in
        demand ty'g GOLEM ("reading golem");
        BOOL
(*
   | READALLSTATE (ty, _pp) -> begin
        match ty with
        | REGISTER
        | CREGISTER
        | FREGISTER
        | FCREGISTER -> MAP (ty, UINT32)
        | GOLEM -> MAP (ty, BOOL)
        | _ -> err' ctx ty "*register or golem" "readallstate"
     end
   | UPDATEMAP (m, k, v) ->
        let ty'm = tc'expr ctx m in
        let ty'k = tc'expr ctx k in
        let ty'v = tc'expr ctx v in
        begin match ty'm with
        | MAP (kt, vt) ->
             demand ty'k kt ("key in map update");
             demand ty'v vt ("value in map update");
        | _ -> err' ctx ty'm "map" "map update"
        end;
        ty'm
*)

     (* access to logic variables *)
   | READVAR (x, ty) -> tc'var ctx x ty ("using variable"); ty

     (* access to logic functions *)
   | GOLEMIZE (g, e1) ->
        let ty'g = tc'expr ctx g in
        (* NOTE(review): READGOLEM demands GOLEM for its golem operand,
           but this demands BOOL -- confirm which is intended *)
        demand ty'g BOOL ("golem function golem");
        let ty'e1 = tc'expr ctx e1 in
        demand ty'e1 UINT32 ("golem function argument");
        BOOL

     (* logical operators *)
   | FORALL (x, ty, e1) -> quantifier "forall" x ty e1
   | EXISTS (x, ty, e1) -> quantifier "exists" x ty e1
   | LET (x, ty, e1, e2) ->
        checkrebind x "let";
        let ty'e1 = tc'expr ctx e1 in
        if ty'e1 <> ty then
           errint ctx ty'e1 ty ("let-binding " ^ string_of_varname x);
        let ctx' = { ctx with env = VarMap.add x ty (ctx.env); } in
        tc'expr ctx' e2
   | IF (c, t, f) ->
        let ty'c = tc'expr ctx c in
        let ty't = tc'expr ctx t in
        let ty'f = tc'expr ctx f in
        checkbool ty'c "condition of if";
        demand ty'f ty't "second branch of if";
        ty't
   | IMPLIES (e1, e2) -> boolbop e1 e2 "->"
   | LOGNOT e1 -> booluop e1 "!"
   | LOGAND (e1, e2) -> boolbop e1 e2 "&&"
   | LOGOR (e1, e2) -> boolbop e1 e2 "||"
   | LOGXOR (e1, e2) -> boolbop e1 e2 "^^"
   | LOGEQ (e1, e2) ->
        (* logic equality; machine integers must use MINTBOP EQ *)
        let ty'e1 = tc'expr ctx e1 in
        let ty'e2 = tc'expr ctx e2 in
        if ty'e1 <> ty'e2 then
           errmismatch ctx ty'e1 ty'e2 "==x";
        if is_machine_int ty'e1 then begin
           Util.say ("In " ^ ctx.topwhat ^ ":");
           Util.say ("Logic equality used on machine integer type");
           Util.say ("Type: " ^ string_of_typename ty'e1);
           Util.crash ("Internal error; aborting.")
        end;
        BOOL

     (* other operators *)
   | MINTBOP (EQ, e1, e2) ->
        (* machine integer equality; logic types must use LOGEQ *)
        let ty'e1 = tc'expr ctx e1 in
        let ty'e2 = tc'expr ctx e2 in
        if ty'e1 <> ty'e2 then
           errmismatch ctx ty'e1 ty'e2 "==m";
        if not (is_machine_int ty'e1) then begin
           Util.say ("In " ^ ctx.topwhat ^ ":");
           Util.say ("Machine integer equality used on logic type");
           Util.say ("Type: " ^ string_of_typename ty'e1);
           Util.crash ("Internal error; aborting.")
        end;
        BOOL
   | MINTBOP (ULT, e1, e2) -> ignore (uintbop e1 e2 "<u"); BOOL
   | MINTBOP (SLT, e1, e2) -> ignore (sintbop e1 e2 "<s"); BOOL
   | MINTBOP (ADD, e1, e2) -> pintbop e1 e2 "+"
   | MINTBOP (SUB, e1, e2) -> pintbop e1 e2 "-"
   | MINTBOP (SMUL_HI, e1, e2) -> sintbop e1 e2 "*sh"
   | MINTBOP (UMUL_HI, e1, e2) -> uintbop e1 e2 "*uh"
     (* NOTE(review): accepts SINT32 or UINT32 operands but always
        yields UINT32 -- confirm this is intended *)
   | MINTBOP (MUL_LO, e1, e2) -> ignore (pintbop e1 e2 "*l"); UINT32
   | MINTBOP (SDIV, e1, e2) -> sintbop e1 e2 "/s"
   | MINTBOP (SMOD, e1, e2) -> sintbop e1 e2 "%s"
   | MINTBOP (UDIV, e1, e2) -> uintbop e1 e2 "/u"
   | MINTBOP (UMOD, e1, e2) -> uintbop e1 e2 "%u"
   | MINTUOP (BITNOT, e1) -> intuop e1 "~"
   | MINTBOP (BITAND, e1, e2) -> pintbop e1 e2 "&"
   | MINTBOP (BITOR, e1, e2) -> pintbop e1 e2 "|"
   | MINTBOP (BITXOR, e1, e2) -> pintbop e1 e2 "^"
   | MINTBOP (SHL, e1, e2) -> pintshift e1 e2 "<<"
   | MINTBOP (USHR, e1, e2) -> uintshift e1 e2 ">>u"
   | MINTBOP (SSHR, e1, e2) -> sintshift e1 e2 ">>s"
     (* these all widen to 32 bit *)
   | MINTUOP (SBWIDEN, e1) -> widenop e1 SINT8 SINT32
   | MINTUOP (SHWIDEN, e1) -> widenop e1 SINT16 SINT32
   | MINTUOP (UBWIDEN, e1) -> widenop e1 UINT8 UINT32
   | MINTUOP (UHWIDEN, e1) -> widenop e1 UINT16 UINT32
   | MINTUOP (U5WIDEN, e1) -> widenop e1 UINT5 UINT32
     (* u32 -> u5 *)
   | MINTUOP (TRUNCU5, e1) -> truncop e1 UINT32 UINT5
     (* these only apply to 32 bit *)
   | MINTUOP (TOSIGNED, e1) ->
        let ty'e1 = tc'expr ctx e1 in
        demand ty'e1 UINT32 "tosigned";
        SINT32
   | MINTUOP (TOUNSIGNED, e1) ->
        let ty'e1 = tc'expr ctx e1 in
        demand ty'e1 SINT32 "tounsigned";
        UINT32

(*
 * Entry point: typecheck E in environment ENV and require that its
 * type be XTY. TOPWHAT names the item being checked; it heads every
 * diagnostic. Returns unit and aborts on any type error.
 *)
let go env e xty topwhat =
   let ctx = { topwhat = topwhat; env = env; } in
   let ty = tc'expr ctx e in
   if ty = xty then ()
   else err ctx ty xty "overall result"

