diff src/ur/openid.ur @ 8:870d99055dd1

Diffie-Hellman started but not fully tested; successfully checked signature from AOL
author Adam Chlipala <adam@chlipala.net>
date Wed, 29 Dec 2010 12:16:32 -0500
parents 976121190b2d
children 426dd5c88df1
line wrap: on
line diff
--- a/src/ur/openid.ur	Tue Dec 28 19:57:25 2010 -0500
+++ b/src/ur/openid.ur	Wed Dec 29 12:16:32 2010 -0500
@@ -27,48 +27,164 @@
     OpenidFfi.addInput is "openid.ns" "http://specs.openid.net/auth/2.0";
     return is
 
-table associations : { Endpoint : string, Handle : string, Key : string, Expires : time }
+datatype association_type = HMAC_SHA1 | HMAC_SHA256
+datatype association_session_type = NoEncryption | DH_SHA1 | DH_SHA256
+
+table associations : { Endpoint : string, Handle : string, Typ : serialized association_type, Key : string, Expires : time }
   PRIMARY KEY Endpoint
 
-datatype association = Association of {Handle : string, Key : string} | AssError of string
+datatype association = Association of {Handle : string, Typ : association_type, Key : string}
+                     | AssError of string
+                     | AssAlternate of {Atype : association_type, Stype : association_session_type}
 
-fun association url =
-    secret <- oneOrNoRows1 (SELECT associations.Handle, associations.Key
+fun atype_show v =
+    case v of
+        HMAC_SHA1 => "HMAC-SHA1"
+      | HMAC_SHA256 => "HMAC-SHA256"
+
+val show_atype = mkShow atype_show
+
+fun stype_show v =
+    case v of
+        NoEncryption => "no-encryption"
+      | DH_SHA1 => "DH-SHA1"
+      | DH_SHA256 => "DH-SHA256"
+
+val show_stype = mkShow stype_show
+
+fun atype_read s =
+    case s of
+        "HMAC-SHA1" => Some HMAC_SHA1
+      | "HMAC-SHA256" => Some HMAC_SHA256
+      | _ => None
+
+val read_atype = mkRead' atype_read "association type"
+
+fun stype_read s =
+    case s of
+        "no-encryption" => Some NoEncryption
+      | "DH-SHA1" => Some DH_SHA1
+      | "DH-SHA256" => Some DH_SHA256
+      | _ => None
+
+val read_stype = mkRead' stype_read "association session type"
+
+fun atype_eq v1 v2 =
+    case (v1, v2) of
+        (HMAC_SHA1, HMAC_SHA1) => True
+      | (HMAC_SHA256, HMAC_SHA256) => True
+      | _ => False
+
+val eq_atype = mkEq atype_eq
+
+fun stype_eq v1 v2 =
+    case (v1, v2) of
+        (NoEncryption, NoEncryption) => True
+      | (DH_SHA1, DH_SHA1) => True
+      | (DH_SHA256, DH_SHA256) => True
+      | _ => False
+
+val eq_stype = mkEq stype_eq
+
+fun errorResult atype stype os =
+    case OpenidFfi.getOutput os "error" of
+        Some v =>
+        (case (OpenidFfi.getOutput os "error_code", OpenidFfi.getOutput os "assoc_type", OpenidFfi.getOutput os "session_type") of
+             (Some "unsupported-type", at, st) => Some (AssAlternate {Atype = Option.get atype (Option.bind read at),
+                                                                      Stype = Option.get stype (Option.bind read st)})
+           | _ => Some (AssError ("OP error during association: " ^ v)))
+      | None => None
+
+fun associateNoEncryption url atype =
+    is <- createInputs;
+    OpenidFfi.addInput is "openid.mode" "associate";
+    OpenidFfi.addInput is "openid.assoc_type" (show atype);
+    OpenidFfi.addInput is "openid.session_type" (show NoEncryption);
+
+    os <- OpenidFfi.direct url is;
+    case errorResult atype NoEncryption os of
+        Some v => return v
+      | None =>
+        case (OpenidFfi.getOutput os "assoc_handle", OpenidFfi.getOutput os "mac_key", OpenidFfi.getOutput os "expires_in") of
+            (Some handle, Some key, Some expires) =>
+            (case read expires of
+                 None => return (AssError "Invalid 'expires_in' field")
+               | Some expires =>
+                 tm <- now;
+                 dml (INSERT INTO associations (Endpoint, Handle, Typ, Key, Expires)
+                      VALUES ({[url]}, {[handle]}, {[serialize atype]}, {[key]}, {[addSeconds tm expires]}));
+                 return (Association {Handle = handle, Typ = atype, Key = key}))
+          | (None, _, _) => return (AssError "Missing assoc_handle")
+          | (_, None, _) => return (AssError "Missing mac_key")
+          | _ => return (AssError "Missing expires_in")
+
+fun associateDh url atype stype =
+    dh <- OpenidFfi.generate;
+
+    is <- createInputs;
+    OpenidFfi.addInput is "openid.mode" "associate";
+    OpenidFfi.addInput is "openid.assoc_type" (show atype);
+    OpenidFfi.addInput is "openid.session_type" (show stype);
+    OpenidFfi.addInput is "openid.dh_modulus" (OpenidFfi.modulus dh);
+    OpenidFfi.addInput is "openid.dh_gen" (OpenidFfi.generator dh);
+    OpenidFfi.addInput is "openid.dh_consumer_public" (OpenidFfi.public dh);
+
+    os <- OpenidFfi.direct url is;
+    case errorResult atype stype os of
+        Some v => return v
+      | None =>
+        case (OpenidFfi.getOutput os "assoc_handle", OpenidFfi.getOutput os "dh_server_public",
+              OpenidFfi.getOutput os "enc_mac_key", OpenidFfi.getOutput os "expires_in") of
+                (Some handle, Some pub, Some mac, Some expires) =>
+                (case read expires of
+                     None => return (AssError "Invalid 'expires_in' field")
+                   | Some expires =>
+                     key <- OpenidFfi.compute dh pub;
+                     tm <- now;
+                     dml (INSERT INTO associations (Endpoint, Handle, Typ, Key, Expires)
+                          VALUES ({[url]}, {[handle]}, {[serialize atype]}, {[key]}, {[addSeconds tm expires]}));
+                     return (Association {Handle = handle, Typ = atype, Key = key}))
+              | (None, _, _, _) => return (AssError "Missing assoc_handle")
+              | (_, None, _, _) => return (AssError "Missing dh_server_public")
+              | (_, _, None, _) => return (AssError "Missing enc_mac_key")
+              | _ => return (AssError "Missing expires_in")
+
+fun oldAssociation url =
+    secret <- oneOrNoRows1 (SELECT associations.Handle, associations.Typ, associations.Key
                             FROM associations
                             WHERE associations.Endpoint = {[url]});
     case secret of
+        Some r => return (Some (r -- #Typ ++ {Typ = deserialize r.Typ}))
+      | None => return None
+
+fun newAssociation url atype stype =
+    case stype of
+        NoEncryption => associateNoEncryption url atype
+      | _ => associateDh url atype stype
+
+fun association atype stype url =
+    secret <- oldAssociation url;
+    case secret of
         Some r => return (Association r)
       | None =>
-        is <- createInputs;
-        OpenidFfi.addInput is "openid.mode" "associate";
-        OpenidFfi.addInput is "openid.assoc_type" "HMAC-SHA256";
-        OpenidFfi.addInput is "openid.session_type" "no-encryption";
-
-        debug ("Contacting " ^ url);
-
-        os <- OpenidFfi.direct url is;
-        case OpenidFfi.getOutput os "error" of
-            Some v => return (AssError v)
-          | None =>
-            case (OpenidFfi.getOutput os "assoc_handle", OpenidFfi.getOutput os "mac_key", OpenidFfi.getOutput os "expires_in") of
-                (Some handle, Some key, Some expires) =>
-                (case read expires of
-                     None => return (AssError "Invalid 'expires_in' field")
-                   | Some expires =>
-                     tm <- now;
-                     dml (INSERT INTO associations (Endpoint, Handle, Key, Expires)
-                          VALUES ({[url]}, {[handle]}, {[key]}, {[addSeconds tm expires]}));
-                     return (Association {Handle = handle, Key = key}))
-              | (None, _, _) => return (AssError "Missing assoc_handle")
-              | (_, None, _) => return (AssError "Missing mac_key")
-              | _ => return (AssError "Missing fields in response from OP")
+        stype <- return (case (stype, String.isPrefix {Full = url, Prefix = "https://"}) of
+                             (NoEncryption, False) => DH_SHA256
+                           | _ => stype);
+        r <- newAssociation url atype stype;
+        case r of
+            AssAlternate alt =>
+            if alt.Atype = atype && alt.Stype = stype then
+                return (AssError "Suggested new modes match old ones!")
+            else
+                newAssociation url alt.Atype alt.Stype
+          | v => return v
 
 fun eatFragment s =
     case String.split s #"#" of
         Some (_, s') => s'
       | _ => s
 
-datatype handle_result = HandleOk of {Endpoint : string, Key : string} | HandleError of string
+datatype handle_result = HandleOk of {Endpoint : string, Typ : association_type, Key : string} | HandleError of string
 
 fun verifyHandle os id =
     ep <- discover (eatFragment id);
@@ -78,14 +194,14 @@
         case OpenidFfi.getOutput os "openid.assoc_handle" of
             None => return (HandleError "Missing association handle in response")
           | Some handle =>
-            assoc <- association ep;
+            assoc <- oldAssociation ep;
             case assoc of
-                AssError s => return (HandleError s)
-              | Association assoc =>
+                None => return (HandleError "Couldn't find association handle")
+              | Some assoc =>
                 if assoc.Handle <> handle then
                     return (HandleError "Association handles don't match")
                 else
-                    return (HandleOk {Endpoint = ep, Key = assoc.Key})
+                    return (HandleOk {Endpoint = ep, Typ = assoc.Typ, Key = assoc.Key})
 
 table nonces : { Endpoint : string, Nonce : string, Expires : time }
   PRIMARY KEY (Endpoint, Nonce)
@@ -123,7 +239,7 @@
                          VALUES ({[ep]}, {[nonce]}, {[exp]}));
                     return None
 
-fun verifySig os key =
+fun verifySig os atype key =
     case OpenidFfi.getOutput os "openid.signed" of
         None => return (Some "Missing openid.signed in OP response")
       | Some signed =>
@@ -153,7 +269,9 @@
                     None => return (Some "openid.signed mentions missing field")
                   | Some nvps =>
                     let
-                        val sign' = OpenidFfi.sha256 key nvps
+                        val sign' = case atype of
+                                        HMAC_SHA256 => OpenidFfi.sha256 key nvps
+                                      | HMAC_SHA1 => OpenidFfi.sha1 key nvps
                     in
                         debug ("Fields: " ^ signed);
                         debug ("Nvps: " ^ nvps);
@@ -187,7 +305,7 @@
                          errO <- verifyHandle os id;
                          case errO of
                              HandleError s => error <xml>{[s]}</xml>
-                           | HandleOk {Endpoint = ep, Key = key} =>
+                           | HandleOk {Endpoint = ep, Typ = atype, Key = key} =>
                              errO <- verifyReturnTo os;
                              case errO of
                                  Some s => error <xml>{[s]}</xml>
@@ -196,7 +314,7 @@
                                  case errO of
                                      Some s => error <xml>{[s]}</xml>
                                    | None =>
-                                     errO <- verifySig os key;
+                                     errO <- verifySig os atype key;
                                      case errO of
                                          Some s => error <xml>{[s]}</xml>
                                        | None => return <xml>Identity: {[id]}</xml>)
@@ -211,14 +329,15 @@
         else
             return None
 
-fun authenticate id =
+fun authenticate atype stype id =
     dy <- discover id;
     case dy of
         None => return "Discovery failed"
       | Some dy =>
-        assoc <- association dy;
+        assoc <- association atype stype dy;
         case assoc of
-            AssError msg => return msg
+            AssError msg => return ("Association failure: " ^ msg)
+          | AssAlternate _ => return "Association failure: server didn't accept its own alternate association modes"
           | Association assoc =>
             redirect (bless (dy ^ "?openid.ns=http://specs.openid.net/auth/2.0&openid.mode=checkid_setup&openid.claimed_id="
                              ^ id ^ "&openid.identity=http://specs.openid.net/auth/2.0/identifier_select&openid.assoc_handle="