changeset 735:5ccb67665d05

Only use cookie signatures when cookies might be read
author Adam Chlipala <adamc@hcoop.net>
date Thu, 23 Apr 2009 14:10:10 -0400
parents f2a2be93331c
children 796e42c93c48
files src/cjr_print.sml src/effectize.sml src/export.sig src/export.sml src/monoize.sml
diffstat 5 files changed, 137 insertions(+), 77 deletions(-) [+]
line wrap: on
line diff
--- a/src/cjr_print.sml	Thu Apr 16 19:12:12 2009 -0400
+++ b/src/cjr_print.sml	Thu Apr 23 14:10:10 2009 -0400
@@ -2227,14 +2227,17 @@
 
         val fields = foldl (fn ((ek, _, _, ts, _, _), fields) =>
                                case ek of
-                                   Core.Link => fields
-                                 | Core.Rpc _ => fields
-                                 | Core.Action _ =>
+                                   Link => fields
+                                 | Rpc _ => fields
+                                 | Action eff =>
                                    case List.nth (ts, length ts - 2) of
                                        (TRecord i, _) =>
                                        let
                                            val xts = E.lookupStruct env i
-                                           val xts = (sigName xts, (TRecord 0, ErrorMsg.dummySpan)) :: xts
+                                           val xts = case eff of
+                                                         ReadCookieWrite =>
+                                                         (sigName xts, (TRecord 0, ErrorMsg.dummySpan)) :: xts
+                                                       | _ => xts
                                            val xtsSet = SS.addList (SS.empty, map #1 xts)
                                        in
                                            foldl (fn ((x, _), fields) =>
@@ -2424,10 +2427,26 @@
                 fun couldWrite ek =
                     case ek of
                         Link => false
-                      | Action ef => ef = ReadWrite
-                      | Rpc ef => ef = ReadWrite
+                      | Action ef => ef = ReadCookieWrite
+                      | Rpc ef => ef = ReadCookieWrite
             in
-                box [if couldWrite ek then
+                box [string "if (!strncmp(request, \"",
+                     string (String.toString s),
+                     string "\", ",
+                     string (Int.toString (size s)),
+                     string ") && (request[",
+                     string (Int.toString (size s)),
+                     string "] == 0 || request[",
+                     string (Int.toString (size s)),
+                     string "] == '/')) {",
+                     newline,
+                     string "request += ",
+                     string (Int.toString (size s)),
+                     string ";",
+                     newline,
+                     string "if (*request == '/') ++request;",
+                     newline,
+                     if couldWrite ek then
                          box [string "{",
                               newline,
                               string "uw_Basis_string sig = ",
@@ -2450,23 +2469,6 @@
                               newline]
                      else
                          box [],
-                     
-                     string "if (!strncmp(request, \"",
-                     string (String.toString s),
-                     string "\", ",
-                     string (Int.toString (size s)),
-                     string ") && (request[",
-                     string (Int.toString (size s)),
-                     string "] == 0 || request[",
-                     string (Int.toString (size s)),
-                     string "] == '/')) {",
-                     newline,
-                     string "request += ",
-                     string (Int.toString (size s)),
-                     string ";",
-                     newline,
-                     string "if (*request == '/') ++request;",
-                     newline,
                      box (case ek of
                               Core.Rpc _ => [string "uw_write_header(ctx, \"Content-type: text/plain\\r\\n\");",
                                              newline]
--- a/src/effectize.sml	Thu Apr 16 19:12:12 2009 -0400
+++ b/src/effectize.sml	Thu Apr 23 14:10:10 2009 -0400
@@ -37,7 +37,7 @@
                            val compare = String.compare
                            end)
 
-val effectful = ["dml", "nextval", "send"]
+val effectful = ["dml", "nextval", "send", "setCookie"]
 val effectful = SS.addList (SS.empty, effectful)
 
 fun effectize file =
@@ -54,21 +54,47 @@
                                            con = fn _ => false,
                                            exp = exp evs}
 
-        fun doDecl (d, evs) =
+        fun exp evs e =
+            case e of
+                EFfi ("Basis", "getCookie") => true
+              | ENamed n => IM.inDomain (evs, n)
+              | EServerCall (n, _, _, _) => IM.inDomain (evs, n)
+              | _ => false
+
+        fun couldReadCookie evs = U.Exp.exists {kind = fn _ => false,
+                                                con = fn _ => false,
+                                                exp = exp evs}
+
+        fun doDecl (d, evs as (writers, readers)) =
             case #1 d of
                 DVal (x, n, t, e, s) =>
-                (d, if couldWrite evs e then
-                        IM.insert (evs, n, (#2 d, s))
-                    else
-                        evs)
+                (d, (if couldWrite writers e then
+                         IM.insert (writers, n, (#2 d, s))
+                     else
+                         writers,
+                     if couldReadCookie readers e then
+                         IM.insert (readers, n, (#2 d, s))
+                     else
+                         readers))
               | DValRec vis =>
                 let
                     fun oneRound evs =
-                        foldl (fn ((_, n, _, e, s), (changed, evs)) =>
-                                if couldWrite evs e andalso not (IM.inDomain (evs, n)) then
-                                    (true, IM.insert (evs, n, (#2 d, s)))
-                                else
-                                    (changed, evs)) (false, evs) vis
+                        foldl (fn ((_, n, _, e, s), (changed, (writers, readers))) =>
+                                  let
+                                      val (changed, writers) =
+                                          if couldWrite writers e andalso not (IM.inDomain (writers, n)) then
+                                              (true, IM.insert (writers, n, (#2 d, s)))
+                                          else
+                                              (changed, writers)
+
+                                      val (changed, readers) =
+                                          if couldReadCookie readers e andalso not (IM.inDomain (readers, n)) then
+                                              (true, IM.insert (readers, n, (#2 d, s)))
+                                          else
+                                              (changed, readers)
+                                  in
+                                      (changed, (writers, readers))
+                                  end) (false, evs) vis
 
                     fun loop evs =
                         let
@@ -80,28 +106,34 @@
                                 evs
                         end
                 in
-                    (d, loop evs)
+                    (d, loop (writers, readers))
                 end
               | DExport (Link, n) =>
-                (case IM.find (evs, n) of
+                (case IM.find (writers, n) of
                      NONE => ()
                    | SOME (loc, s) => ErrorMsg.errorAt loc ("A link (" ^ s ^ ") could cause side effects; try implementing it with a form instead");
                  (d, evs))
               | DExport (Action _, n) =>
-                ((DExport (Action (if IM.inDomain (evs, n) then
-                                       ReadWrite
+                ((DExport (Action (if IM.inDomain (writers, n) then
+                                       if IM.inDomain (readers, n) then
+                                           ReadCookieWrite
+                                       else
+                                           ReadWrite
                                    else
                                        ReadOnly), n), #2 d),
                  evs)
               | DExport (Rpc _, n) =>
-                ((DExport (Rpc (if IM.inDomain (evs, n) then
-                                    ReadWrite
+                ((DExport (Rpc (if IM.inDomain (writers, n) then
+                                    if IM.inDomain (readers, n) then
+                                        ReadCookieWrite
+                                    else
+                                        ReadWrite
                                 else
                                     ReadOnly), n), #2 d),
                  evs)
               | _ => (d, evs)
 
-        val (file, _) = ListUtil.foldlMap doDecl IM.empty file
+        val (file, _) = ListUtil.foldlMap doDecl (IM.empty, IM.empty) file
     in
         file
     end
--- a/src/export.sig	Thu Apr 16 19:12:12 2009 -0400
+++ b/src/export.sig	Thu Apr 23 14:10:10 2009 -0400
@@ -29,6 +29,7 @@
 
 datatype effect =
          ReadOnly
+       | ReadCookieWrite
        | ReadWrite
 
 datatype export_kind =
--- a/src/export.sml	Thu Apr 16 19:12:12 2009 -0400
+++ b/src/export.sml	Thu Apr 23 14:10:10 2009 -0400
@@ -25,13 +25,14 @@
  * POSSIBILITY OF SUCH DAMAGE.
  *)
 
-structure Export = struct
+structure Export :> EXPORT = struct
 
 open Print.PD
 open Print
 
 datatype effect =
          ReadOnly
+       | ReadCookieWrite
        | ReadWrite
 
 datatype export_kind =
@@ -42,6 +43,7 @@
 fun p_effect ef =
     case ef of
         ReadOnly => string "r"
+      | ReadCookieWrite => string "rcw"
       | ReadWrite => string "rw"
 
 fun p_export_kind ck =
--- a/src/monoize.sml	Thu Apr 16 19:12:12 2009 -0400
+++ b/src/monoize.sml	Thu Apr 23 14:10:10 2009 -0400
@@ -34,6 +34,7 @@
 structure L' = Mono
 
 structure IM = IntBinaryMap
+structure IS = IntBinarySet
 
 val urlPrefix = ref "/"
 
@@ -538,6 +539,8 @@
 
 fun strcatR loc e xs = strcatComma loc (map (fn (x, _) => (L'.EField (e, x), loc)) xs)
 
+val readCookie = ref IS.empty
+
 fun monoExp (env, st, fm) (all as (e, loc)) =
     let
         val strcat = strcat loc
@@ -2453,53 +2456,64 @@
                            | _ => findSubmit xml)
                       | _ => NotFound
 
-                val (action, fm) = case findSubmit xml of
-                    NotFound => ((L'.EPrim (Prim.String ""), loc), fm)
+                val (func, action, fm) = case findSubmit xml of
+                    NotFound => (0, (L'.EPrim (Prim.String ""), loc), fm)
                   | Error => raise Fail "Not ready for multi-submit lforms yet"
                   | Found (action, actionT) =>
                     let
+                        val func = case #1 action of
+                                       L.EClosure (n, _) => n
+                                     | _ => raise Fail "Monoize: Action is not a closure"
                         val actionT = monoType env actionT
                         val (action, fm) = monoExp (env, st, fm) action
                         val (action, fm) = urlifyExp env fm (action, actionT)
                     in
-                        ((L'.EStrcat ((L'.EPrim (Prim.String " action=\""), loc),
+                        (func,
+                         (L'.EStrcat ((L'.EPrim (Prim.String " action=\""), loc),
                                       (L'.EStrcat (action,
                                                    (L'.EPrim (Prim.String "\""), loc)), loc)), loc),
                          fm)
                     end
-                
-                fun inFields s = List.exists (fn ((L.CName s', _), _) => s' = s
-                                               | _ => true) fields
-
-                fun getSigName () =
-                    let
-                        fun getSigName' n =
-                            let
-                                val s = "Sig" ^ Int.toString n
-                            in
-                                if inFields s then
-                                    getSigName' (n + 1)
-                                else
-                                    s
-                            end
-                    in
-                        if inFields "Sig" then
-                            getSigName' 0
-                        else
-                            "Sig"
-                    end
-
-                val sigName = getSigName ()
-                val sigSet = (L'.EFfiApp ("Basis", "sigString", [(L'.ERecord [], loc)]), loc)
-                val sigSet = (L'.EStrcat ((L'.EPrim (Prim.String ("<input type=\"hidden\" name=\""
-                                                                  ^ sigName
-                                                                  ^ "\" value=\"")), loc),
-                                          sigSet), loc)
-                val sigSet = (L'.EStrcat (sigSet,
-                                          (L'.EPrim (Prim.String "\">"), loc)), loc)
 
                 val (xml, fm) = monoExp (env, st, fm) xml
-                val xml = (L'.EStrcat (sigSet, xml), loc)
+
+                val xml =
+                    if IS.member (!readCookie, func) then
+                        let
+                            fun inFields s = List.exists (fn ((L.CName s', _), _) => s' = s
+                                                           | _ => true) fields
+
+                            fun getSigName () =
+                                let
+                                    fun getSigName' n =
+                                        let
+                                            val s = "Sig" ^ Int.toString n
+                                        in
+                                            if inFields s then
+                                                getSigName' (n + 1)
+                                            else
+                                                s
+                                        end
+                                in
+                                    if inFields "Sig" then
+                                        getSigName' 0
+                                    else
+                                        "Sig"
+                                end
+
+                            val sigName = getSigName ()
+                            val sigSet = (L'.EFfiApp ("Basis", "sigString", [(L'.ERecord [], loc)]), loc)
+                            val sigSet = (L'.EStrcat ((L'.EPrim (Prim.String ("<input type=\"hidden\" name=\""
+                                                                              ^ sigName
+                                                                              ^ "\" value=\"")), loc),
+                                                      sigSet), loc)
+                            val sigSet = (L'.EStrcat (sigSet,
+                                                      (L'.EPrim (Prim.String "\">"), loc)), loc)
+                        in
+                            (L'.EStrcat (sigSet, xml), loc)
+                        end
+                    else
+                        xml
             in
                 ((L'.EStrcat ((L'.EStrcat ((L'.EPrim (Prim.String "<form method=\"post\""), loc),
                                            (L'.EStrcat (action,
@@ -2793,6 +2807,15 @@
             else
                 ()
 
+        (* Calculate which exported functions need cookie signature protection *)
+        val rcook = foldl (fn ((d, _), rcook) =>
+                              case d of
+                                  L.DExport (L.Action L.ReadCookieWrite, n) => IS.add (rcook, n)
+                                | L.DExport (L.Rpc L.ReadCookieWrite, n) => IS.add (rcook, n)
+                                | _ => rcook)
+                          IS.empty file
+        val () = readCookie := rcook
+
         val loc = E.dummySpan
         val client = (L'.TFfi ("Basis", "client"), loc)
         val unit = (L'.TRecord [], loc)