diff src/scriptcheck.sml @ 2064:3dd041b00087

Extend ScriptCheck to take RPCs into account
author Adam Chlipala <adam@chlipala.net>
date Sun, 24 Aug 2014 11:43:49 -0400
parents a9159911c3ba
children 25874084bf1f
line wrap: on
line diff
--- a/src/scriptcheck.sml	Sat Aug 23 11:59:34 2014 +0000
+++ b/src/scriptcheck.sml	Sun Aug 24 11:43:49 2014 -0400
@@ -1,4 +1,4 @@
-(* Copyright (c) 2009, Adam Chlipala
+(* Copyright (c) 2009, 2014, Adam Chlipala
  * All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without
@@ -29,6 +29,10 @@
 
 open Mono
 
+structure SM = BinaryMapFn(struct
+                           type ord_key = string
+                           val compare = String.compare
+                           end)
 structure SS = BinarySetFn(struct
                            type ord_key = string
                            val compare = String.compare
@@ -39,37 +43,108 @@
                             ["new_channel",
                              "self"])
 
+datatype rpcmap =
+         Rpc of int (* ID of function definition *)
+       | Module of rpcmap SM.map
+
+fun lookup (r : rpcmap, k : string) =
+    let
+        fun lookup' (r, ks) =
+            case r of
+                Rpc x => SOME x
+              | Module m =>
+                case ks of
+                    [] => NONE
+                  | k :: ks' =>
+                    case SM.find (m, k) of
+                        NONE => NONE
+                      | SOME r' => lookup' (r', ks')
+    in
+        lookup' (r, String.tokens (fn ch => ch = #"/") k)
+    end
+
+fun insert (r : rpcmap, k : string, v) =
+    let
+        fun insert' (r, ks) =
+            case r of
+                Rpc _ => Rpc v
+              | Module m =>
+                case ks of
+                    [] => Rpc v
+                  | k :: ks' =>
+                    let
+                        val r' = case SM.find (m, k) of
+                                     NONE => Module SM.empty
+                                   | SOME r' => r'
+                    in
+                        Module (SM.insert (m, k, insert' (r', ks')))
+                    end
+    in
+        insert' (r, String.tokens (fn ch => ch = #"/") k)
+    end
+
+fun dump (r : rpcmap) =
+    case r of
+        Rpc _ => print "ROOT\n"
+      | Module m => (print "<Module>\n";
+                     SM.appi (fn (k, r') => (print (k ^ ":\n");
+                                             dump r')) m;
+                     print "</Module>\n")
+
 fun classify (ds, ps) =
     let
         val proto = Settings.currentProtocol ()
 
         fun inString {needle, haystack} = String.isSubstring needle haystack
 
-        fun hasClient {basis, funcs, push} =
+        fun hasClient {basis, rpcs, funcs, push} =
             MonoUtil.Exp.exists {typ = fn _ => false,
                                  exp = fn ERecv _ => push
                                         | EFfiApp ("Basis", x, _) => SS.member (basis, x) 
                                         | EJavaScript _ => not push
                                         | ENamed n => IS.member (funcs, n)
+                                        | EServerCall (e, _, _, _) =>
+                                          let
+                                              fun head (e : exp) =
+                                                  case #1 e of
+                                                      EStrcat (e1, _) => head e1
+                                                    | EPrim (Prim.String (_, s)) => SOME s
+                                                    | _ => NONE
+                                          in
+                                              case head e of
+                                                  NONE => true
+                                                | SOME fcall =>
+                                                  case lookup (rpcs, fcall) of
+                                                      NONE => true
+                                                    | SOME n => IS.member (funcs, n)
+                                          end
                                         | _ => false}
 
+        fun decl ((d, _), rpcs) =
+            case d of
+                DExport (Mono.Rpc _, fcall, n, _, _, _) =>
+                insert (rpcs, fcall, n)
+              | _ => rpcs
+
+        val rpcs = foldl decl (Module SM.empty) ds
+
         fun decl ((d, _), (pull_ids, push_ids)) =
             let
-                val hasClientPull = hasClient {basis = SS.empty, funcs = pull_ids, push = false}
-                val hasClientPush = hasClient {basis = pushBasis, funcs = push_ids, push = true}
+                val hasClientPull = hasClient {basis = SS.empty, rpcs = rpcs, funcs = pull_ids, push = false}
+                val hasClientPush = hasClient {basis = pushBasis, rpcs = rpcs, funcs = push_ids, push = true}
             in
                 case d of
                     DVal (_, n, _, e, _) => (if hasClientPull e then
-                                             IS.add (pull_ids, n)
-                                          else
-                                              pull_ids,
-                                          if hasClientPush e then
-                                              IS.add (push_ids, n)
-                                          else
-                                              push_ids)
+                                                 IS.add (pull_ids, n)
+                                             else
+                                                 pull_ids,
+                                             if hasClientPush e then
+                                                 IS.add (push_ids, n)
+                                             else
+                                                 push_ids)
                   | DValRec xes => (if List.exists (fn (_, _, _, e, _) => hasClientPull e) xes then
-                                       foldl (fn ((_, n, _, _, _), pull_ids) => IS.add (pull_ids, n))
-                                             pull_ids xes
+                                        foldl (fn ((_, n, _, _, _), pull_ids) => IS.add (pull_ids, n))
+                                              pull_ids xes
                                     else
                                         pull_ids,
                                     if List.exists (fn (_, _, _, e, _) => hasClientPush e) xes then