changeset 867:e7f80d78075b

Moved query code into Settings
author Adam Chlipala <adamc@hcoop.net>
date Sun, 28 Jun 2009 16:03:00 -0400
parents 03e7f111fe99
children 06497beb265b
files src/cjr_print.sml src/mysql.sml src/postgres.sml src/settings.sig src/settings.sml
diffstat 5 files changed, 384 insertions(+), 186 deletions(-) [+]
line wrap: on
line diff
--- a/src/cjr_print.sml	Sun Jun 28 13:49:32 2009 -0400
+++ b/src/cjr_print.sml	Sun Jun 28 16:03:00 2009 -0400
@@ -470,20 +470,8 @@
                            string ")"]),
              string ")"]
 
-datatype sql_type =
-         Int
-       | Float
-       | String
-       | Bool
-       | Time
-       | Blob
-       | Channel
-       | Client
-       | Nullable of sql_type
-
-fun isBlob Blob = true
-  | isBlob (Nullable t) = isBlob t
-  | isBlob _ = false
+datatype sql_type = datatype Settings.sql_type
+val isBlob = Settings.isBlob
 
 fun isFile (t : typ) =
     case #1 t of
@@ -1250,6 +1238,21 @@
         urlify' IS.empty 0 t
     end
 
+fun sql_type_in env (tAll as (t, loc)) =
+    case t of
+        TFfi ("Basis", "int") => Int
+      | TFfi ("Basis", "float") => Float
+      | TFfi ("Basis", "string") => String
+      | TFfi ("Basis", "bool") => Bool
+      | TFfi ("Basis", "time") => Time
+      | TFfi ("Basis", "blob") => Blob
+      | TFfi ("Basis", "channel") => Channel
+      | TFfi ("Basis", "client") => Client
+      | TOption t' => Nullable (sql_type_in env t')
+      | _ => (ErrorMsg.errorAt loc "Don't know SQL equivalent of type";
+              Print.eprefaces' [("Type", p_typ env tAll)];
+              Int)
+
 fun p_exp' par env (e, loc) =
     case e of
         EPrim p => Prim.p_t_GCC p
@@ -1570,6 +1573,56 @@
 
             val wontLeakStrings = notLeaky env true state
             val wontLeakAnything = notLeaky env false state
+
+            val inputs =
+                case prepared of
+                    NONE => []
+                  | SOME _ => getPargs query
+
+            fun doCols p_getcol =
+                box [string "struct __uws_",
+                     string (Int.toString rnum),
+                     string " __uwr_r_",
+                     string (Int.toString (E.countERels env)),
+                     string ";",
+                     newline,
+                     p_typ env state,
+                     space,
+                     string "__uwr_acc_",
+                     string (Int.toString (E.countERels env + 1)),
+                     space,
+                     string "=",
+                     space,
+                     string "acc;",
+                     newline,
+                     newline,
+                     p_list_sepi (box []) (fn i =>
+                                           fn (proj, t) =>
+                                              box [string "__uwr_r_",
+                                                   string (Int.toString (E.countERels env)),
+                                                   string ".",
+                                                   string proj,
+                                                   space,
+                                                   string "=",
+                                                   space,
+                                                   p_getcol {wontLeakStrings = wontLeakStrings,
+                                                             col = i,
+                                                             typ = sql_type_in env t},
+                                                   string ";",
+                                                   newline]) outputs,
+                     newline,
+                     newline,
+
+                     string "acc",
+                     space,
+                     string "=",
+                     space,
+                     p_exp (E.pushERel
+                                (E.pushERel env "r" (TRecord rnum, loc))
+                                "acc" state) 
+                           body,
+                     string ";",
+                     newline]
         in
             box [if wontLeakAnything then
                      string "(uw_begin_region(ctx), "
@@ -1577,8 +1630,6 @@
                      box [],
                  string "({",
                  newline,
-                 string "PGconn *conn = uw_get_db(ctx);",
-                 newline,
                  p_typ env state,
                  space,
                  string "acc",
@@ -1588,176 +1639,46 @@
                  p_exp env initial,
                  string ";",
                  newline,
-                 string "int n, i, dummy = (uw_begin_region(ctx), 0);",
+                 string "int dummy = (uw_begin_region(ctx), 0);",
                  newline,
                  
                  case prepared of
-                     NONE => box [string "char *query = ",
-                                  p_exp env query,
-                                  string ";",
-                                  newline]
-                   | SOME _ =>
-                     let
-                         val ets = getPargs query
-                     in
-                         box [p_list_sepi newline
-                                          (fn i => fn (e, t) =>
-                                                      box [p_sql_type t,
-                                                           space,
-                                                           string "arg",
-                                                           string (Int.toString (i + 1)),
-                                                           space,
-                                                           string "=",
-                                                           space,
-                                                           p_exp env e,
-                                                           string ";"])
-                                          ets,
-                              newline,
-                              newline,
+                     NONE =>
+                     box [string "char *query = ",
+                          p_exp env query,
+                          string ";",
+                          newline,
+                          newline,
 
-                              string "const int paramFormats[] = { ",
-                              p_list_sep (box [string ",", space])
-                              (fn (_, t) => if isBlob t then string "1" else string "0") ets,
-                              string " };",
-                              newline,
-                              string "const int paramLengths[] = { ",
-                              p_list_sepi (box [string ",", space])
-                              (fn i => fn (_, Blob) => string ("arg" ^ Int.toString (i + 1) ^ ".size")
-                                        | (_, Nullable Blob) => string ("arg" ^ Int.toString (i + 1)
-                                                                        ^ "?arg" ^ Int.toString (i + 1) ^ "->size:0")
-                                        | _ => string "0") ets,
-                              string " };",
-                              newline,
-                              string "const char *paramValues[] = { ",
-                              p_list_sepi (box [string ",", space])
-                              (fn i => fn (_, t) => p_ensql t (box [string "arg",
-                                                                    string (Int.toString (i + 1))]))
-                              ets,
-                              string " };",
-                              newline,
-                              newline]
-                     end,
+                          #query (Settings.currentDbms ())
+                                 {loc = loc,
+                                  numCols = length outputs,
+                                  doCols = doCols}]
+                   | SOME (id, query) =>
+                     box [p_list_sepi newline
+                                      (fn i => fn (e, t) =>
+                                                  box [p_sql_type t,
+                                                       space,
+                                                       string "arg",
+                                                       string (Int.toString (i + 1)),
+                                                       space,
+                                                       string "=",
+                                                       space,
+                                                       p_exp env e,
+                                                       string ";"])
+                                      inputs,
+                          newline,
+                          newline,
 
-                 string "PGresult *res = ",
-                 case prepared of
-                     NONE => string "PQexecParams(conn, query, 0, NULL, NULL, NULL, NULL, 0);"
-                   | SOME (n, s) =>
-                     if #persistent (Settings.currentProtocol ()) then
-                         box [string "PQexecPrepared(conn, \"uw",
-                              string (Int.toString n),
-                              string "\", ",
-                              string (Int.toString (length (getPargs query))),
-                              string ", paramValues, paramLengths, paramFormats, 0);"]
-                     else
-                         box [string "PQexecParams(conn, \"",
-                              string (String.toString s),
-                              string "\", ",
-                              string (Int.toString (length (getPargs query))),
-                              string ", NULL, paramValues, paramLengths, paramFormats, 0);"],
-                 newline,
+                          #queryPrepared (Settings.currentDbms ())
+                                         {loc = loc,
+                                          id = id,
+                                          query = query,
+                                          inputs = map #2 inputs,
+                                          numCols = length outputs,
+                                          doCols = doCols}],
                  newline,
 
-                 string "if (res == NULL) uw_error(ctx, FATAL, \"Out of memory allocating query result.\");",
-                 newline,
-                 newline,
-
-                 string "if (PQresultStatus(res) != PGRES_TUPLES_OK) {",
-                 newline,
-                 box [string "PQclear(res);",
-                      newline,
-                      string "uw_error(ctx, FATAL, \"",
-                      string (ErrorMsg.spanToString loc),
-                      string ": Query failed:\\n%s\\n%s\", ",
-                      case prepared of
-                          NONE => string "query"
-                        | SOME _ => p_exp env query,
-                      string ", PQerrorMessage(conn));",
-                      newline],
-                 string "}",
-                 newline,
-                 newline,
-
-                 string "if (PQnfields(res) != ",
-                 string (Int.toString (length outputs)),
-                 string ") {",
-                 newline,
-                 box [string "int nf = PQnfields(res);",
-                      newline,
-                      string "PQclear(res);",
-                      newline,
-                      string "uw_error(ctx, FATAL, \"",
-                      string (ErrorMsg.spanToString loc),
-                      string ": Query returned %d columns instead of ",
-                      string (Int.toString (length outputs)),
-                      string ":\\n%s\\n%s\", ",
-                      case prepared of
-                          NONE => string "query"
-                        | SOME _ => p_exp env query,
-                      string ", nf, PQerrorMessage(conn));",
-                      newline],
-                 string "}",
-                 newline,
-                 newline,
-
-                 string "uw_end_region(ctx);",
-                 newline,
-                 string "uw_push_cleanup(ctx, (void (*)(void *))PQclear, res);",
-                 newline,
-                 string "n = PQntuples(res);",
-                 newline,
-                 string "for (i = 0; i < n; ++i) {",
-                 newline,
-                 box [string "struct",
-                      space,
-                      string "__uws_",
-                      string (Int.toString rnum),
-                      space,
-                      string "__uwr_r_",
-                      string (Int.toString (E.countERels env)),
-                      string ";",
-                      newline,
-                      p_typ env state,
-                      space,
-                      string "__uwr_acc_",
-                      string (Int.toString (E.countERels env + 1)),
-                      space,
-                      string "=",
-                      space,
-                      string "acc;",
-                      newline,
-                      newline,
-
-                      p_list_sepi (box []) (fn i =>
-                                            fn (proj, t) =>
-                                               box [string "__uwr_r_",
-                                                    string (Int.toString (E.countERels env)),
-                                                    string ".",
-                                                    string proj,
-                                                    space,
-                                                    string "=",
-                                                    space,
-                                                    p_getcol wontLeakStrings env t i,
-                                                    string ";",
-                                                    newline]) outputs,
-             
-                      newline,
-                      newline,
-
-                      string "acc",
-                      space,
-                      string "=",
-                      space,
-                      p_exp (E.pushERel
-                                 (E.pushERel env "r" (TRecord rnum, loc))
-                                 "acc" state) 
-                            body,
-                      string ";",
-                      newline],
-                 string "}",
-                 newline,
-                 newline,
-                 string "uw_pop_cleanup(ctx);",
-                 newline,
                  if wontLeakAnything then
                      box [string "uw_end_region(ctx);",
                           newline]
--- a/src/mysql.sml	Sun Jun 28 13:49:32 2009 -0400
+++ b/src/mysql.sml	Sun Jun 28 16:03:00 2009 -0400
@@ -186,7 +186,7 @@
              newline,
              string "}",
              newline,
-             string "conn = malloc(sizeof(conn));",
+             string "conn = calloc(1, sizeof(conn));",
              newline,
              string "conn->conn = mysql;",
              newline,
@@ -253,6 +253,9 @@
              newline]
     end
 
+fun query _ = raise Fail "MySQL query"
+fun queryPrepared _ = raise Fail "MySQL queryPrepared"
+
 val () = addDbms {name = "mysql",
                   header = "mysql/mysql.h",
                   link = "-lmysqlclient",
@@ -268,6 +271,8 @@
                                           newline],
                               string "}",
                                      newline],
-                  init = init}
+                  init = init,
+                  query = query,
+                  queryPrepared = queryPrepared}
 
 end
--- a/src/postgres.sml	Sun Jun 28 13:49:32 2009 -0400
+++ b/src/postgres.sml	Sun Jun 28 16:03:00 2009 -0400
@@ -189,12 +189,216 @@
          newline,
          string "}"]
 
+fun p_getcol {wontLeakStrings, col = i, typ = t} =
+    let
+        fun p_unsql t e eLen =
+            case t of
+                Int => box [string "uw_Basis_stringToInt_error(ctx, ", e, string ")"]
+              | Float => box [string "uw_Basis_stringToFloat_error(ctx, ", e, string ")"]
+              | String =>
+                if wontLeakStrings then
+                    e
+                else
+                    box [string "uw_strdup(ctx, ", e, string ")"]
+              | Bool => box [string "uw_Basis_stringToBool_error(ctx, ", e, string ")"]
+              | Time => box [string "uw_Basis_stringToTime_error(ctx, ", e, string ")"]
+              | Blob => box [string "uw_Basis_stringToBlob_error(ctx, ",
+                             e,
+                             string ", ",
+                             eLen,
+                             string ")"]
+              | Channel => box [string "uw_Basis_stringToChannel_error(ctx, ", e, string ")"]
+              | Client => box [string "uw_Basis_stringToClient_error(ctx, ", e, string ")"]
+
+              | Nullable _ => raise Fail "Postgres: Recursive Nullable"
+
+        fun getter t =
+            case t of
+                Nullable t =>
+                box [string "(PQgetisnull(res, i, ",
+                     string (Int.toString i),
+                     string ") ? NULL : ",
+                     case t of
+                         String => getter t
+                       | _ => box [string "({",
+                                   newline,
+                                   p_sql_type t,
+                                   space,
+                                   string "*tmp = uw_malloc(ctx, sizeof(",
+                                   p_sql_type t,
+                                   string "));",
+                                   newline,
+                                   string "*tmp = ",
+                                   getter t,
+                                   string ";",
+                                   newline,
+                                   string "tmp;",
+                                   newline,
+                                   string "})"],
+                     string ")"]
+              | _ =>
+                box [string "(PQgetisnull(res, i, ",
+                     string (Int.toString i),
+                     string ") ? ",
+                     box [string "({",
+                          p_sql_type t,
+                          space,
+                          string "tmp;",
+                          newline,
+                          string "uw_error(ctx, FATAL, \"Unexpectedly NULL field #",
+                          string (Int.toString i),
+                          string "\");",
+                          newline,
+                          string "tmp;",
+                          newline,
+                          string "})"],
+                     string " : ",
+                     p_unsql t
+                             (box [string "PQgetvalue(res, i, ",
+                                   string (Int.toString i),
+                                   string ")"])
+                             (box [string "PQgetlength(res, i, ",
+                                   string (Int.toString i),
+                                   string ")"]),
+                     string ")"]
+    in
+        getter t
+    end
+
+fun queryCommon {loc, query, numCols, doCols} =
+    box [string "int n, i;",
+         newline,
+         newline,
+
+         string "if (res == NULL) uw_error(ctx, FATAL, \"Out of memory allocating query result.\");",
+         newline,
+         newline,
+
+         string "if (PQresultStatus(res) != PGRES_TUPLES_OK) {",
+         newline,
+         box [string "PQclear(res);",
+              newline,
+              string "uw_error(ctx, FATAL, \"",
+              string (ErrorMsg.spanToString loc),
+              string ": Query failed:\\n%s\\n%s\", ",
+              query,
+              string ", PQerrorMessage(conn));",
+              newline],
+         string "}",
+         newline,
+         newline,
+
+         string "if (PQnfields(res) != ",
+         string (Int.toString numCols),
+         string ") {",
+         newline,
+         box [string "int nf = PQnfields(res);",
+              newline,
+              string "PQclear(res);",
+              newline,
+              string "uw_error(ctx, FATAL, \"",
+              string (ErrorMsg.spanToString loc),
+              string ": Query returned %d columns instead of ",
+              string (Int.toString numCols),
+              string ":\\n%s\\n%s\", nf, ",
+              query,
+              string ", PQerrorMessage(conn));",
+              newline],
+         string "}",
+         newline,
+         newline,
+
+         string "uw_end_region(ctx);",
+         newline,
+         string "uw_push_cleanup(ctx, (void (*)(void *))PQclear, res);",
+         newline,
+         string "n = PQntuples(res);",
+         newline,
+         string "for (i = 0; i < n; ++i) {",
+         newline,
+         doCols p_getcol,
+         string "}",
+         newline,
+         newline,
+         string "uw_pop_cleanup(ctx);",
+         newline]    
+
+fun query {loc, numCols, doCols} =
+    box [string "PGconn *conn = uw_get_db(ctx);",
+         newline,
+         string "PGresult *res = PQexecParams(conn, query, 0, NULL, NULL, NULL, NULL, 0);",
+         newline,
+         newline,
+         queryCommon {loc = loc, numCols = numCols, doCols = doCols, query = string "query"}]
+
+fun p_ensql t e =
+    case t of
+        Int => box [string "uw_Basis_attrifyInt(ctx, ", e, string ")"]
+      | Float => box [string "uw_Basis_attrifyFloat(ctx, ", e, string ")"]
+      | String => e
+      | Bool => box [string "(", e, string " ? \"TRUE\" : \"FALSE\")"]
+      | Time => box [string "uw_Basis_attrifyTime(ctx, ", e, string ")"]
+      | Blob => box [e, string ".data"]
+      | Channel => box [string "uw_Basis_attrifyChannel(ctx, ", e, string ")"]
+      | Client => box [string "uw_Basis_attrifyClient(ctx, ", e, string ")"]
+      | Nullable String => e
+      | Nullable t => box [string "(",
+                           e,
+                           string " == NULL ? NULL : ",
+                           p_ensql t (box [string "(*", e, string ")"]),
+                           string ")"]
+
+fun queryPrepared {loc, id, query, inputs, numCols, doCols} =
+    box [string "PGconn *conn = uw_get_db(ctx);",
+         newline,
+         string "const int paramFormats[] = { ",
+         p_list_sep (box [string ",", space])
+                    (fn t => if isBlob t then string "1" else string "0") inputs,
+         string " };",
+         newline,
+         string "const int paramLengths[] = { ",
+         p_list_sepi (box [string ",", space])
+                     (fn i => fn Blob => string ("arg" ^ Int.toString (i + 1) ^ ".size")
+                               | Nullable Blob => string ("arg" ^ Int.toString (i + 1)
+                                                          ^ "?arg" ^ Int.toString (i + 1) ^ "->size:0")
+                               | _ => string "0") inputs,
+         string " };",
+         newline,
+         string "const char *paramValues[] = { ",
+         p_list_sepi (box [string ",", space])
+                     (fn i => fn t => p_ensql t (box [string "arg",
+                                                      string (Int.toString (i + 1))]))
+                     inputs,
+         string " };",
+         newline,
+         newline,
+         string "PGresult *res = ",
+         if #persistent (Settings.currentProtocol ()) then
+             box [string "PQexecPrepared(conn, \"uw",
+                  string (Int.toString id),
+                  string "\", ",
+                  string (Int.toString (length inputs)),
+                  string ", paramValues, paramLengths, paramFormats, 0);"]
+         else
+             box [string "PQexecParams(conn, \"",
+                  string (String.toString query),
+                  string "\", ",
+                  string (Int.toString (length inputs)),
+                  string ", NULL, paramValues, paramLengths, paramFormats, 0);"],
+         newline,
+         newline,
+         queryCommon {loc = loc, numCols = numCols, doCols = doCols, query = box [string "\"",
+                                                                                  string (String.toString query),
+                                                                                  string "\""]}]
+
 val () = addDbms {name = "postgres",
                   header = "postgresql/libpq-fe.h",
                   link = "-lpq",
                   global_init = box [string "void uw_client_init() { }",
                                      newline],
-                  init = init}
+                  init = init,
+                  query = query,
+                  queryPrepared = queryPrepared}
 val () = setDbms "postgres"
 
 end
--- a/src/settings.sig	Sun Jun 28 13:49:32 2009 -0400
+++ b/src/settings.sig	Sun Jun 28 16:03:00 2009 -0400
@@ -101,6 +101,20 @@
     val currentProtocol : unit -> protocol
 
     (* Different DBMSes *)
+    datatype sql_type =
+             Int
+           | Float
+           | String
+           | Bool
+           | Time
+           | Blob
+           | Channel
+           | Client
+           | Nullable of sql_type
+
+    val p_sql_type : sql_type -> Print.PD.pp_desc
+    val isBlob : sql_type -> bool
+
     type dbms = {
          name : string,
          (* Call it this on the command line *)
@@ -110,8 +124,18 @@
          (* Pass these linker arguments *)
          global_init : Print.PD.pp_desc,
          (* Define uw_client_init() *)
-         init : string * (string * int) list -> Print.PD.pp_desc
-         (* Define uw_db_init() from dbstring and prepared statements *)
+         init : string * (string * int) list -> Print.PD.pp_desc,
+         (* Define uw_db_init(), uw_db_close(), uw_db_begin(), uw_db_commit(), and uw_db_rollback()
+          * from dbstring and prepared statements *)
+         query : {loc : ErrorMsg.span, numCols : int,
+                  doCols : ({wontLeakStrings : bool, col : int, typ : sql_type} -> Print.PD.pp_desc)
+                           -> Print.PD.pp_desc}
+                 -> Print.PD.pp_desc,
+         queryPrepared : {loc : ErrorMsg.span, id : int, query : string,
+                          inputs : sql_type list, numCols : int,
+                          doCols : ({wontLeakStrings : bool, col : int, typ : sql_type} -> Print.PD.pp_desc)
+                                   -> Print.PD.pp_desc}
+                         -> Print.PD.pp_desc
     }
 
     val addDbms : dbms -> unit
--- a/src/settings.sml	Sun Jun 28 13:49:32 2009 -0400
+++ b/src/settings.sml	Sun Jun 28 16:03:00 2009 -0400
@@ -274,12 +274,54 @@
 fun setDebug b = debug := b
 fun getDebug () = !debug
 
+datatype sql_type =
+         Int
+       | Float
+       | String
+       | Bool
+       | Time
+       | Blob
+       | Channel
+       | Client
+       | Nullable of sql_type
+
+fun p_sql_type t =
+    let
+        open Print.PD
+        open Print
+    in
+        case t of
+            Int => string "uw_Basis_int"
+          | Float => string "uw_Basis_float"
+          | String => string "uw_Basis_string"
+          | Bool => string "uw_Basis_bool"
+          | Time => string "uw_Basis_time"
+          | Blob => string "uw_Basis_blob"
+          | Channel => string "uw_Basis_channel"
+          | Client => string "uw_Basis_client"
+          | Nullable String => string "uw_Basis_string"
+          | Nullable t => box [p_sql_type t, string "*"]
+    end
+
+fun isBlob Blob = true
+  | isBlob (Nullable t) = isBlob t
+  | isBlob _ = false
+
 type dbms = {
      name : string,
      header : string,
      link : string,
      global_init : Print.PD.pp_desc,
-     init : string * (string * int) list -> Print.PD.pp_desc
+     init : string * (string * int) list -> Print.PD.pp_desc,
+     query : {loc : ErrorMsg.span, numCols : int,
+              doCols : ({wontLeakStrings : bool, col : int, typ : sql_type} -> Print.PD.pp_desc)
+                       -> Print.PD.pp_desc}
+             -> Print.PD.pp_desc,
+     queryPrepared : {loc : ErrorMsg.span, id : int, query : string,
+                      inputs : sql_type list, numCols : int,
+                      doCols : ({wontLeakStrings : bool, col : int, typ : sql_type} -> Print.PD.pp_desc)
+                               -> Print.PD.pp_desc}
+                     -> Print.PD.pp_desc
 }
 
 val dbmses = ref ([] : dbms list)
@@ -287,7 +329,9 @@
                   header = "",
                   link = "",
                   global_init = Print.box [],
-                  init = fn _ => Print.box []} : dbms)
+                  init = fn _ => Print.box [],
+                  query = fn _ => Print.box [],
+                  queryPrepared = fn _ => Print.box []} : dbms)
 
 fun addDbms v = dbmses := v :: !dbmses
 fun setDbms s =