diff src/cjr_print.sml @ 734:f2a2be93331c

Cookie signing working for forms
author Adam Chlipala <adamc@hcoop.net>
date Thu, 16 Apr 2009 19:12:12 -0400
parents e0dd85ea58e1
children 5ccb67665d05
line wrap: on
line diff
--- a/src/cjr_print.sml	Thu Apr 16 15:38:01 2009 -0400
+++ b/src/cjr_print.sml	Thu Apr 16 19:12:12 2009 -0400
@@ -2198,6 +2198,26 @@
         (TOption _, _) => false
       | _ => true
 
+fun sigName fields =
+    let
+        fun inFields s = List.exists (fn (s', _) => s' = s) fields
+
+        fun getSigName n =
+            let
+                val s = "Sig" ^ Int.toString n
+            in
+                if inFields s then
+                    getSigName (n + 1)
+                else
+                    s
+            end
+    in
+        if inFields "Sig" then
+            getSigName 0
+        else
+            "Sig"
+    end
+
 fun p_file env (ds, ps) =
     let
         val (pds, env) = ListUtil.foldlMap (fn (d, env) =>
@@ -2214,6 +2234,7 @@
                                        (TRecord i, _) =>
                                        let
                                            val xts = E.lookupStruct env i
+                                           val xts = (sigName xts, (TRecord 0, ErrorMsg.dummySpan)) :: xts
                                            val xtsSet = SS.addList (SS.empty, map #1 xts)
                                        in
                                            foldl (fn ((x, _), fields) =>
@@ -2245,6 +2266,8 @@
                                   end)
                     SM.empty fields
 
+        val cookies = List.mapPartial (fn (DCookie s, _) => SOME s | _ => NONE) ds
+
         fun makeSwitch (fnums, i) =
             case SM.foldl (fn (n, NotFound) => Found n
                             | (n, Error) => Error
@@ -2328,10 +2351,10 @@
 
         fun p_page (ek, s, n, ts, ran, side) =
             let
-                val (ts, defInputs, inputsVar) =
+                val (ts, defInputs, inputsVar, fields) =
                     case ek of
-                        Core.Link => (List.take (ts, length ts - 1), string "", string "")
-                      | Core.Rpc _ => (List.take (ts, length ts - 1), string "", string "")
+                        Core.Link => (List.take (ts, length ts - 1), string "", string "", NONE)
+                      | Core.Rpc _ => (List.take (ts, length ts - 1), string "", string "", NONE)
                       | Core.Action _ =>
                         case List.nth (ts, length ts - 2) of
                             (TRecord i, _) =>
@@ -2392,12 +2415,43 @@
                                       newline],
                                  box [string ",",
                                       space,
-                                      string "uw_inputs"])
+                                      string "uw_inputs"],
+                                 SOME xts)
                             end
 
                           | _ => raise Fail "CjrPrint: Last argument to an action isn't a record"
+
+                fun couldWrite ek =
+                    case ek of
+                        Link => false
+                      | Action ef => ef = ReadWrite
+                      | Rpc ef => ef = ReadWrite
             in
-                box [string "if (!strncmp(request, \"",
+                box [if couldWrite ek then
+                         box [string "{",
+                              newline,
+                              string "uw_Basis_string sig = ",
+                              case fields of
+                                  NONE => string "uw_Basis_requestHeader(ctx, \"UrWeb-Sig\")"
+                                | SOME fields =>
+                                  case SM.find (fnums, sigName fields) of
+                                      NONE => raise Fail "CjrPrint: sig name wasn't assigned a number"
+                                    | SOME inum =>
+                                      string ("uw_get_input(ctx, " ^ Int.toString inum ^ ")"),
+                              string ";",
+                              newline,
+                              string "if (sig == NULL) uw_error(ctx, FATAL, \"Missing cookie signature\");",
+                              newline,
+                              string "if (strcmp(sig, uw_cookie_sig(ctx)))",
+                              newline,
+                              box [string "uw_error(ctx, FATAL, \"Wrong cookie signature\");",
+                                   newline],
+                              string "}",
+                              newline]
+                     else
+                         box [],
+                     
+                     string "if (!strncmp(request, \"",
                      string (String.toString s),
                      string "\", ",
                      string (Int.toString (size s)),
@@ -2745,6 +2799,18 @@
                  string "}"]
 
         val hasDb = List.exists (fn (DDatabase _, _) => true | _ => false) ds
+
+        val cookies = List.mapPartial (fn (DCookie s, _) => SOME s | _ => NONE) ds
+
+        val cookieCode = foldl (fn (cookie, acc) =>
+                                   SOME (case acc of
+                                             NONE => string ("uw_unnull(uw_Basis_get_cookie(ctx, \""
+                                                             ^ cookie ^ "\"))")
+                                           | SOME acc => box [string ("uw_Basis_strcat(ctx, uw_unnull(uw_Basis_get_cookie(ctx, \""
+                                                                      ^ cookie ^ "\")), uw_Basis_strcat(ctx, \"/\", "),
+                                                              acc,
+                                                              string "))"]))
+                         NONE cookies
     in
         box [string "#include <stdio.h>",
              newline,
@@ -2783,6 +2849,27 @@
              string "}",
              newline,
              newline,
+             
+             string "extern void uw_sign(const char *in, char *out);",
+             newline,
+             string "extern int uw_hash_blocksize;",
+             newline,
+             string "uw_Basis_string uw_cookie_sig(uw_context ctx) {",
+             newline,
+             box [string "uw_Basis_string r = uw_malloc(ctx, uw_hash_blocksize);",
+                  newline,
+                  string "uw_sign(",
+                  case cookieCode of
+                      NONE => string "\"\""
+                    | SOME code => code,
+                  string ", r);",
+                  newline,
+                  string "return uw_Basis_makeSigString(ctx, r);",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
              string "void uw_handle(uw_context ctx, char *request) {",
              newline,
              string "if (!strcmp(request, \"/app.js\")) {",