diff src/cjr_print.sml @ 282:0236d9412ad2

Ran a prepared statement with one string parameter
author Adam Chlipala <adamc@hcoop.net>
date Sun, 07 Sep 2008 09:28:13 -0400
parents fdd7a698be01
children c0e4ac23522d
line wrap: on
line diff
--- a/src/cjr_print.sml	Thu Sep 04 10:27:21 2008 -0400
+++ b/src/cjr_print.sml	Sun Sep 07 09:28:13 2008 -0400
@@ -333,6 +333,45 @@
               Print.eprefaces' [("Type", p_typ env tAll)];
               string "ERROR")
 
+datatype sql_type =
+         Int
+       | Float
+       | String
+       | Bool
+
+fun p_sql_type t =
+    string (case t of
+                Int => "lw_Basis_int"
+              | Float => "lw_Basis_float"
+              | String => "lw_Basis_string"
+              | Bool => "lw_Basis_bool")
+
+fun getPargs (e, _) =
+    case e of
+        EPrim (Prim.String _) => []
+      | EFfiApp ("Basis", "strcat", [e1, e2]) => getPargs e1 @ getPargs e2
+
+      | EFfiApp ("Basis", "sqlifyInt", [e]) => [(e, Int)]
+      | EFfiApp ("Basis", "sqlifyFloat", [e]) => [(e, Float)]
+      | EFfiApp ("Basis", "sqlifyString", [e]) => [(e, String)]
+      | EFfiApp ("Basis", "sqlifyBool", [e]) => [(e, Bool)]
+
+      | _ => raise Fail "CjrPrint: getPargs"
+
+fun p_ensql t e =
+    case t of
+        Int => box [string "(char *)&", e]
+      | Float => box [string "(char *)&", e]
+      | String => e
+      | Bool => box [string "lw_Basis_ensqlBool(", e, string ")"]
+
+fun p_ensql_len t e =
+    case t of
+        Int => string "sizeof(lw_Basis_int)"
+      | Float => string "sizeof(lw_Basis_float)"
+      | String => box [string "strlen(", e, string ")"]
+      | Bool => string "sizeof(lw_Basis_bool)"
+
 fun p_exp' par env (e, loc) =
     case e of
         EPrim p => Prim.p_t_GCC p
@@ -560,7 +599,7 @@
                                     newline,
                                     string "})"]
 
-      | EQuery {exps, tables, rnum, state, query, body, initial} =>
+      | EQuery {exps, tables, rnum, state, query, body, initial, prepared} =>
         let
             val exps = map (fn (x, t) => ("__lwf_" ^ x, t)) exps
             val tables = ListUtil.mapConcat (fn (x, xts) =>
@@ -573,10 +612,54 @@
                  newline,
                  string "PGconn *conn = lw_get_db(ctx);",
                  newline,
-                 string "char *query = ",
-                 p_exp env query,
-                 string ";",
-                 newline,
+                 case prepared of
+                     NONE => box [string "char *query = ",
+                                  p_exp env query,
+                                  string ";",
+                                  newline]
+                   | SOME _ =>
+                     let
+                         val ets = getPargs query
+                     in
+                         box [p_list_sepi newline
+                                          (fn i => fn (e, t) =>
+                                                      box [p_sql_type t,
+                                                           space,
+                                                           string "arg",
+                                                           string (Int.toString (i + 1)),
+                                                           space,
+                                                           string "=",
+                                                           space,
+                                                           p_exp env e,
+                                                           string ";"])
+                                          ets,
+                              newline,
+                              newline,
+
+                              string "const char *paramValues[] = { ",
+                              p_list_sepi (box [string ",", space])
+                              (fn i => fn (_, t) => p_ensql t (box [string "arg",
+                                                                    string (Int.toString (i + 1))]))
+                              ets,
+                              string " };",
+                              newline,
+                              newline,
+
+                              string "const int paramLengths[] = { ",
+                              p_list_sepi (box [string ",", space])
+                              (fn i => fn (_, t) => p_ensql_len t (box [string "arg",
+                                                                        string (Int.toString (i + 1))]))
+                              ets,
+                              string " };",
+                              newline,
+                              newline,
+                              
+                              string "const static int paramFormats[] = { ",
+                              p_list_sep (box [string ",", space]) (fn _ => string "1") ets,
+                              string " };",
+                              newline,
+                              newline]
+                     end,
                  string "int n, i;",
                  newline,
                  p_typ env state,
@@ -588,7 +671,14 @@
                  p_exp env initial,
                  string ";",
                  newline,
-                 string "PGresult *res = PQexecParams(conn, query, 0, NULL, NULL, NULL, NULL, 1);",
+                 string "PGresult *res = ",
+                 case prepared of
+                     NONE => string "PQexecParams(conn, query, 0, NULL, NULL, NULL, NULL, 1);"
+                   | SOME n => box [string "PQexecPrepared(conn, \"lw",
+                                    string (Int.toString n),
+                                    string "\", ",
+                                    string (Int.toString (length (getPargs query))),
+                                    string ", paramValues, paramLengths, paramFormats, 1);"],
                  newline,
                  newline,
 
@@ -602,7 +692,11 @@
                       newline,
                       string "lw_error(ctx, FATAL, \"",
                       string (ErrorMsg.spanToString loc),
-                      string ": Query failed:\\n%s\\n%s\", query, PQerrorMessage(conn));",
+                      string ": Query failed:\\n%s\\n%s\", ",
+                      case prepared of
+                          NONE => string "query"
+                        | SOME _ => p_exp env query,
+                      string ", PQerrorMessage(conn));",
                       newline],
                  string "}",
                  newline,
@@ -814,6 +908,8 @@
                               newline]
       | DDatabase s => box [string "static void lw_db_validate(lw_context);",
                             newline,
+                            string "static void lw_db_prepare(lw_context);",
+                            newline,
                             newline,
                             string "void lw_db_init(lw_context ctx) {",
                             newline,
@@ -843,6 +939,8 @@
                             newline,
                             string "lw_db_validate(ctx);",
                             newline,
+                            string "lw_db_prepare(ctx);",
+                            newline,
                             string "}",
                             newline,
                             newline,
@@ -853,6 +951,48 @@
                             string "}",
                             newline]
 
+      | DPreparedStatements ss =>
+        box [string "static void lw_db_prepare(lw_context ctx) {",
+             newline,
+             string "PGconn *conn = lw_get_db(ctx);",
+             newline,
+             string "PGresult *res;",
+             newline,
+             newline,
+
+             p_list_sepi newline (fn i => fn (s, n) =>
+                                             box [string "res = PQprepare(conn, \"lw",
+                                                  string (Int.toString i),
+                                                  string "\", \"",
+                                                  string (String.toString s),
+                                                  string "\", ",
+                                                  string (Int.toString n),
+                                                  string ", NULL);",
+                                                  newline,
+                                                  string "if (PQresultStatus(res) != PGRES_COMMAND_OK) {",
+                                                  newline,
+                                                  box [string "char msg[1024];",
+                                                       newline,
+                                                       string "strncpy(msg, PQerrorMessage(conn), 1024);",
+                                                       newline,
+                                                       string "msg[1023] = 0;",
+                                                       newline,
+                                                       string "PQclear(res);",
+                                                       newline,
+                                                       string "PQfinish(conn);",
+                                                       newline,
+                                                       string "lw_error(ctx, FATAL, \"Unable to create prepared statement:\\n",
+                                                       string (String.toString s),
+                                                       string "\\n%s\", msg);",
+                                                       newline],
+                                                  string "}",
+                                                  newline,
+                                                  string "PQclear(res);",
+                                                  newline])
+                         ss,
+             
+             string "}"]
+
 datatype 'a search =
          Found of 'a
        | NotFound