changeset 734:f2a2be93331c

Cookie signing working for forms
author Adam Chlipala <adamc@hcoop.net>
date Thu, 16 Apr 2009 19:12:12 -0400
parents 15ddd64a5113
children 5ccb67665d05
files demo/cookie.urp include/urweb.h src/c/driver.c src/c/urweb.c src/cjr_print.sml src/compiler.sml src/monoize.sml tests/cookieSec.ur tests/cookieSec.urp tests/cookieSec.urs
diffstat 10 files changed, 230 insertions(+), 10 deletions(-) [+]
line wrap: on
line diff
--- a/demo/cookie.urp	Thu Apr 16 15:38:01 2009 -0400
+++ b/demo/cookie.urp	Thu Apr 16 19:12:12 2009 -0400
@@ -1,2 +1,3 @@
+debug
 
 cookie
--- a/include/urweb.h	Thu Apr 16 15:38:01 2009 -0400
+++ b/include/urweb.h	Thu Apr 16 19:12:12 2009 -0400
@@ -152,3 +152,7 @@
 uw_Basis_client uw_Basis_self(uw_context, uw_unit);
 
 uw_Basis_string uw_Basis_bless(uw_context, uw_Basis_string);
+
+uw_Basis_string uw_unnull(uw_Basis_string);
+uw_Basis_string uw_Basis_makeSigString(uw_context, uw_Basis_string);
+uw_Basis_string uw_Basis_sigString(uw_context, uw_unit);
--- a/src/c/driver.c	Thu Apr 16 15:38:01 2009 -0400
+++ b/src/c/driver.c	Thu Apr 16 19:12:12 2009 -0400
@@ -10,6 +10,8 @@
 
 #include <pthread.h>
 
+#include <mhash.h>
+
 #include "urweb.h"
 
 int uw_backlog = 10;
@@ -102,6 +104,46 @@
   return ctx;
 }
 
+#define KEYSIZE 16
+#define PASSSIZE 4
+
+#define HASH_ALGORITHM MHASH_SHA256
+#define HASH_BLOCKSIZE 32
+#define KEYGEN_ALGORITHM KEYGEN_MCRYPT
+
+int uw_hash_blocksize = HASH_BLOCKSIZE;
+
+static int password[PASSSIZE];
+static unsigned char private_key[KEYSIZE];
+
+static void init_crypto() {
+  KEYGEN kg = {{HASH_ALGORITHM, HASH_ALGORITHM}};
+  int i;
+
+  assert(mhash_get_block_size(HASH_ALGORITHM) == HASH_BLOCKSIZE);
+
+  for (i = 0; i < PASSSIZE; ++i)
+    password[i] = rand();
+
+  if (mhash_keygen_ext(KEYGEN_ALGORITHM, kg,
+                       private_key, sizeof(private_key),
+                       (unsigned char*)password, sizeof(password)) < 0) {
+    printf("Key generation failed\n");
+    exit(1);
+  }
+}
+
+void uw_sign(const char *in, char *out) {
+  MHASH td;
+
+  td = mhash_hmac_init(HASH_ALGORITHM, private_key, sizeof(private_key),
+                       mhash_get_hash_pblock(HASH_ALGORITHM));
+  
+  mhash(td, in, strlen(in));
+  if (mhash_hmac_deinit(td, out) < 0)
+    printf("Signing failed");
+}
+
 static void *worker(void *data) {
   int me = *(int *)data, retries_left = MAX_RETRIES;
   uw_context ctx = new_context();
@@ -344,9 +386,13 @@
 }
 
 static void initialize() {
-  uw_context ctx = new_context();
+  uw_context ctx;
   failure_kind fk;
 
+  init_crypto();
+
+  ctx = new_context();
+
   if (!ctx)
     exit(1);
 
@@ -411,6 +457,7 @@
     }
   }
 
+  uw_global_init();
   initialize();
 
   names = calloc(nthreads, sizeof(int));
@@ -444,8 +491,6 @@
 
   sin_size = sizeof their_addr;
 
-  uw_global_init();
-
   printf("Listening on port %d....\n", uw_port);
 
   {
--- a/src/c/urweb.c	Thu Apr 16 15:38:01 2009 -0400
+++ b/src/c/urweb.c	Thu Apr 16 19:12:12 2009 -0400
@@ -1981,3 +1981,25 @@
 uw_Basis_string uw_Basis_bless(uw_context ctx, uw_Basis_string s) {
   return s;
 }
+
+uw_Basis_string uw_unnull(uw_Basis_string s) {
+  return s ? s : "";
+}
+
+extern int uw_hash_blocksize;
+
+uw_Basis_string uw_Basis_makeSigString(uw_context ctx, uw_Basis_string sig) {
+  uw_Basis_string r = uw_malloc(ctx, 2 * uw_hash_blocksize + 1);
+  int i;
+  
+  for (i = 0; i < uw_hash_blocksize; ++i)
+    sprintf(&r[2*i], "%.02X", ((unsigned char *)sig)[i]);
+
+  return r;
+}
+
+extern uw_Basis_string uw_cookie_sig(uw_context);
+
+uw_Basis_string uw_Basis_sigString(uw_context ctx, uw_unit u) {
+  return uw_cookie_sig(ctx);
+}
--- a/src/cjr_print.sml	Thu Apr 16 15:38:01 2009 -0400
+++ b/src/cjr_print.sml	Thu Apr 16 19:12:12 2009 -0400
@@ -2198,6 +2198,26 @@
         (TOption _, _) => false
       | _ => true
 
+fun sigName fields =
+    let
+        fun inFields s = List.exists (fn (s', _) => s' = s) fields
+
+        fun getSigName n =
+            let
+                val s = "Sig" ^ Int.toString n
+            in
+                if inFields s then
+                    getSigName (n + 1)
+                else
+                    s
+            end
+    in
+        if inFields "Sig" then
+            getSigName 0
+        else
+            "Sig"
+    end
+
 fun p_file env (ds, ps) =
     let
         val (pds, env) = ListUtil.foldlMap (fn (d, env) =>
@@ -2214,6 +2234,7 @@
                                        (TRecord i, _) =>
                                        let
                                            val xts = E.lookupStruct env i
+                                           val xts = (sigName xts, (TRecord 0, ErrorMsg.dummySpan)) :: xts
                                            val xtsSet = SS.addList (SS.empty, map #1 xts)
                                        in
                                            foldl (fn ((x, _), fields) =>
@@ -2245,6 +2266,8 @@
                                   end)
                     SM.empty fields
 
+        val cookies = List.mapPartial (fn (DCookie s, _) => SOME s | _ => NONE) ds
+
         fun makeSwitch (fnums, i) =
             case SM.foldl (fn (n, NotFound) => Found n
                             | (n, Error) => Error
@@ -2328,10 +2351,10 @@
 
         fun p_page (ek, s, n, ts, ran, side) =
             let
-                val (ts, defInputs, inputsVar) =
+                val (ts, defInputs, inputsVar, fields) =
                     case ek of
-                        Core.Link => (List.take (ts, length ts - 1), string "", string "")
-                      | Core.Rpc _ => (List.take (ts, length ts - 1), string "", string "")
+                        Core.Link => (List.take (ts, length ts - 1), string "", string "", NONE)
+                      | Core.Rpc _ => (List.take (ts, length ts - 1), string "", string "", NONE)
                       | Core.Action _ =>
                         case List.nth (ts, length ts - 2) of
                             (TRecord i, _) =>
@@ -2392,12 +2415,43 @@
                                       newline],
                                  box [string ",",
                                       space,
-                                      string "uw_inputs"])
+                                      string "uw_inputs"],
+                                 SOME xts)
                             end
 
                           | _ => raise Fail "CjrPrint: Last argument to an action isn't a record"
+
+                fun couldWrite ek =
+                    case ek of
+                        Link => false
+                      | Action ef => ef = ReadWrite
+                      | Rpc ef => ef = ReadWrite
             in
-                box [string "if (!strncmp(request, \"",
+                box [if couldWrite ek then
+                         box [string "{",
+                              newline,
+                              string "uw_Basis_string sig = ",
+                              case fields of
+                                  NONE => string "uw_Basis_requestHeader(ctx, \"UrWeb-Sig\")"
+                                | SOME fields =>
+                                  case SM.find (fnums, sigName fields) of
+                                      NONE => raise Fail "CjrPrint: sig name wasn't assigned a number"
+                                    | SOME inum =>
+                                      string ("uw_get_input(ctx, " ^ Int.toString inum ^ ")"),
+                              string ";",
+                              newline,
+                              string "if (sig == NULL) uw_error(ctx, FATAL, \"Missing cookie signature\");",
+                              newline,
+                              string "if (strcmp(sig, uw_cookie_sig(ctx)))",
+                              newline,
+                              box [string "uw_error(ctx, FATAL, \"Wrong cookie signature\");",
+                                   newline],
+                              string "}",
+                              newline]
+                     else
+                         box [],
+                     
+                     string "if (!strncmp(request, \"",
                      string (String.toString s),
                      string "\", ",
                      string (Int.toString (size s)),
@@ -2745,6 +2799,18 @@
                  string "}"]
 
         val hasDb = List.exists (fn (DDatabase _, _) => true | _ => false) ds
+
+        val cookies = List.mapPartial (fn (DCookie s, _) => SOME s | _ => NONE) ds
+
+        val cookieCode = foldl (fn (cookie, acc) =>
+                                   SOME (case acc of
+                                             NONE => string ("uw_unnull(uw_Basis_get_cookie(ctx, \""
+                                                             ^ cookie ^ "\"))")
+                                           | SOME acc => box [string ("uw_Basis_strcat(ctx, uw_unnull(uw_Basis_get_cookie(ctx, \""
+                                                                      ^ cookie ^ "\")), uw_Basis_strcat(ctx, \"/\", "),
+                                                              acc,
+                                                              string "))"]))
+                         NONE cookies
     in
         box [string "#include <stdio.h>",
              newline,
@@ -2783,6 +2849,27 @@
              string "}",
              newline,
              newline,
+             
+             string "extern void uw_sign(const char *in, char *out);",
+             newline,
+             string "extern int uw_hash_blocksize;",
+             newline,
+             string "uw_Basis_string uw_cookie_sig(uw_context ctx) {",
+             newline,
+             box [string "uw_Basis_string r = uw_malloc(ctx, uw_hash_blocksize);",
+                  newline,
+                  string "uw_sign(",
+                  case cookieCode of
+                      NONE => string "\"\""
+                    | SOME code => code,
+                  string ", r);",
+                  newline,
+                  string "return uw_Basis_makeSigString(ctx, r);",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
              string "void uw_handle(uw_context ctx, char *request) {",
              newline,
              string "if (!strcmp(request, \"/app.js\")) {",
--- a/src/compiler.sml	Thu Apr 16 15:38:01 2009 -0400
+++ b/src/compiler.sml	Thu Apr 16 19:12:12 2009 -0400
@@ -611,7 +611,7 @@
         val driver_o = clibFile "driver.o"
 
         val compile = "gcc " ^ Config.gccArgs ^ " -Wstrict-prototypes -Werror -O3 -I include -c " ^ cname ^ " -o " ^ oname
-        val link = "gcc -Werror -O3 -lm -pthread " ^ libs ^ " " ^ urweb_o ^ " " ^ oname ^ " " ^ driver_o ^ " -o " ^ ename
+        val link = "gcc -Werror -O3 -lm -lmhash -pthread " ^ libs ^ " " ^ urweb_o ^ " " ^ oname ^ " " ^ driver_o ^ " -o " ^ ename
 
         val (compile, link) =
             if profile then
--- a/src/monoize.sml	Thu Apr 16 15:38:01 2009 -0400
+++ b/src/monoize.sml	Thu Apr 16 19:12:12 2009 -0400
@@ -2399,7 +2399,7 @@
 
           | L.EApp ((L.ECApp (
                      (L.ECApp ((L.EFfi ("Basis", "form"), _), _), _),
-                     _), _),
+                     (L.CRecord (_, fields), _)), _),
                     xml) =>
             let
                 fun findSubmit (e, _) =
@@ -2468,7 +2468,38 @@
                          fm)
                     end
                 
+                fun inFields s = List.exists (fn ((L.CName s', _), _) => s' = s
+                                               | _ => true) fields
+
+                fun getSigName () =
+                    let
+                        fun getSigName' n =
+                            let
+                                val s = "Sig" ^ Int.toString n
+                            in
+                                if inFields s then
+                                    getSigName' (n + 1)
+                                else
+                                    s
+                            end
+                    in
+                        if inFields "Sig" then
+                            getSigName' 0
+                        else
+                            "Sig"
+                    end
+
+                val sigName = getSigName ()
+                val sigSet = (L'.EFfiApp ("Basis", "sigString", [(L'.ERecord [], loc)]), loc)
+                val sigSet = (L'.EStrcat ((L'.EPrim (Prim.String ("<input type=\"hidden\" name=\""
+                                                                  ^ sigName
+                                                                  ^ "\" value=\"")), loc),
+                                          sigSet), loc)
+                val sigSet = (L'.EStrcat (sigSet,
+                                          (L'.EPrim (Prim.String "\">"), loc)), loc)
+
                 val (xml, fm) = monoExp (env, st, fm) xml
+                val xml = (L'.EStrcat (sigSet, xml), loc)
             in
                 ((L'.EStrcat ((L'.EStrcat ((L'.EPrim (Prim.String "<form method=\"post\""), loc),
                                            (L'.EStrcat (action,
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/cookieSec.ur	Thu Apr 16 19:12:12 2009 -0400
@@ -0,0 +1,24 @@
+table t : {Id : int}
+
+cookie c : int
+
+fun setter r =
+    setCookie c (readError r.Id);
+    return <xml>Done</xml>
+
+fun writer () =
+    ido <- getCookie c;
+    case ido of
+        None => error <xml>No cookie</xml>
+      | Some id => dml (INSERT INTO t (Id) VALUES ({[id]}));
+                   return <xml>Done</xml>
+
+fun main () = return <xml><body>
+  <form>
+    <textbox{#Id}/> <submit value="Get cookie" action={setter}/>
+  </form>
+
+  <form>
+    <submit value="Write to database" action={writer}/>
+  </form>
+</body></xml>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/cookieSec.urp	Thu Apr 16 19:12:12 2009 -0400
@@ -0,0 +1,5 @@
+debug
+database dbname=cookiesec
+sql cookieSec.sql
+
+cookieSec
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/cookieSec.urs	Thu Apr 16 19:12:12 2009 -0400
@@ -0,0 +1,1 @@
+val main : unit -> transaction page