diff src/mysql.sml @ 874:3c7b48040dcf

MySQL demo/sql succeeds in reading no rows
author Adam Chlipala <adamc@hcoop.net>
date Sun, 12 Jul 2009 15:05:40 -0400
parents 41971801b62d
children c50101ddf7fa
line wrap: on
line diff
--- a/src/mysql.sml	Sun Jul 12 13:16:05 2009 -0400
+++ b/src/mysql.sml	Sun Jul 12 15:05:40 2009 -0400
@@ -55,6 +55,278 @@
       | Client => "MYSQL_TYPE_LONG"
       | Nullable t => p_buffer_type t
 
+fun p_sql_type_base t =
+    case t of
+        Int => "bigint"
+      | Float => "double"
+      | String => "longtext"
+      | Bool => "tinyint"
+      | Time => "timestamp"
+      | Blob => "longblob"
+      | Channel => "bigint"
+      | Client => "int"
+      | Nullable t => p_sql_type_base t
+
+val ident = String.translate (fn #"'" => "PRIME"
+                               | ch => str ch)
+
+fun checkRel (table, checkNullable) (s, xts) =
+    let
+        val sl = CharVector.map Char.toLower s
+
+        val q = "SELECT COUNT(*) FROM information_schema." ^ table ^ " WHERE table_name = '"
+                ^ sl ^ "'"
+
+        val q' = String.concat ["SELECT COUNT(*) FROM information_schema.columns WHERE table_name = '",
+                                sl,
+                                "' AND (",
+                                String.concatWith " OR "
+                                                  (map (fn (x, t) =>
+                                                           String.concat ["(column_name = 'uw_",
+                                                                          CharVector.map
+                                                                              Char.toLower (ident x),
+                                                                          "' AND data_type = '",
+                                                                          p_sql_type_base t,
+                                                                          "'",
+                                                                          if checkNullable then
+                                                                              (" AND is_nullable = '"
+                                                                               ^ (if isNotNull t then
+                                                                                      "NO"
+                                                                                  else
+                                                                                      "YES")
+                                                                               ^ "'")
+                                                                          else
+                                                                              "",
+                                                                          ")"]) xts),
+                                ")"]
+
+        val q'' = String.concat ["SELECT COUNT(*) FROM information_schema.columns WHERE table_name = '",
+                                 sl,
+                                 "' AND column_name LIKE 'uw_%'"]
+    in
+        box [string "if (mysql_query(conn->conn, \"",
+             string q,
+             string "\")) {",
+             newline,
+             box [string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Query failed:\\n",
+                  string q,
+                  string "\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
+             string "if ((res = mysql_store_result(conn->conn)) == NULL) {",
+             newline,
+             box [string "mysql_free_result(res);",
+                  newline,
+                  string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Result store failed:\\n",
+                  string q,
+                  string "\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
+             string "if (mysql_num_fields(res) != 1) {",
+             newline,
+             box [string "mysql_free_result(res);",
+                  newline,
+                  string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Bad column count:\\n",
+                  string q,
+                  string "\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
+             string "if ((row = mysql_fetch_row(res)) == NULL) {",
+             newline,
+             box [string "mysql_free_result(res);",
+                  newline,
+                  string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Row fetch failed:\\n",
+                  string q,
+                  string "\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
+             string "if (strcmp(row[0], \"1\")) {",
+             newline,
+             box [string "mysql_free_result(res);",
+                  newline,
+                  string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Table '",
+                  string s,
+                  string "' does not exist.\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+             string "mysql_free_result(res);",
+             newline,
+             newline,
+
+             string "if (mysql_query(conn->conn, \"",
+             string q',
+             string "\")) {",
+             newline,
+             box [string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Query failed:\\n",
+                  string q',
+                  string "\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
+             string "if ((res = mysql_store_result(conn->conn)) == NULL) {",
+             newline,
+             box [string "mysql_free_result(res);",
+                  newline,
+                  string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Result store failed:\\n",
+                  string q',
+                  string "\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
+             string "if (mysql_num_fields(res) != 1) {",
+             newline,
+             box [string "mysql_free_result(res);",
+                  newline,
+                  string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Bad column count:\\n",
+                  string q',
+                  string "\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
+             string "if ((row = mysql_fetch_row(res)) == NULL) {",
+             newline,
+             box [string "mysql_free_result(res);",
+                  newline,
+                  string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Row fetch failed:\\n",
+                  string q',
+                  string "\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
+             string "if (strcmp(row[0], \"",
+             string (Int.toString (length xts)),
+             string "\")) {",
+             newline,
+             box [string "mysql_free_result(res);",
+                  newline,
+                  string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Table '",
+                  string s,
+                  string "' has the wrong column types.\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+             string "mysql_free_result(res);",
+             newline,
+             newline,
+             
+             string "if (mysql_query(conn->conn, \"",
+             string q'',
+             string "\")) {",
+             newline,
+             box [string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Query failed:\\n",
+                  string q'',
+                  string "\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
+             string "if ((res = mysql_store_result(conn->conn)) == NULL) {",
+             newline,
+             box [string "mysql_free_result(res);",
+                  newline,
+                  string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Result store failed:\\n",
+                  string q'',
+                  string "\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
+             string "if (mysql_num_fields(res) != 1) {",
+             newline,
+             box [string "mysql_free_result(res);",
+                  newline,
+                  string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Bad column count:\\n",
+                  string q'',
+                  string "\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
+             string "if ((row = mysql_fetch_row(res)) == NULL) {",
+             newline,
+             box [string "mysql_free_result(res);",
+                  newline,
+                  string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Row fetch failed:\\n",
+                  string q'',
+                  string "\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
+             string "if (strcmp(row[0], \"",
+             string (Int.toString (length xts)),
+             string "\")) {",
+             newline,
+             box [string "mysql_free_result(res);",
+                  newline,
+                  string "mysql_close(conn->conn);",
+                  newline,
+                  string "uw_error(ctx, FATAL, \"Table '",
+                  string s,
+                  string "' has extra columns.\");",
+                  newline],
+             string "}",
+             newline,
+             newline,
+             string "mysql_free_result(res);",
+             newline]
+    end
+
 fun init {dbstring, prepared = ss, tables, views, sequences} =
     let
         val host = ref NONE
@@ -102,8 +374,37 @@
              newline,
              newline,
 
+             string "void uw_client_init(void) {",
+             newline,
+             box [string "if (mysql_library_init(0, NULL, NULL)) {",
+                  newline,
+                  box [string "fprintf(stderr, \"Could not initialize MySQL library\\n\");",
+                       newline,
+                       string "exit(1);",
+                       newline],
+                  string "}",
+                  newline],
+             string "}",
+             newline,
+             newline,
+
              if #persistent (currentProtocol ()) then
-                 box [string "static void uw_db_prepare(uw_context ctx) {",
+                 box [string "static void uw_db_validate(uw_context ctx) {",
+                      newline,
+                      string "uw_conn *conn = uw_get_db(ctx);",
+                      newline,
+                      string "MYSQL_RES *res;",
+                      newline,
+                      string "MYSQL_ROW row;",
+                      newline,
+                      newline,
+                      p_list_sep newline (checkRel ("tables", true)) tables,
+                      p_list_sep newline (checkRel ("views", false)) views,
+                      string "}",
+                      newline,
+                      newline,
+
+                      string "static void uw_db_prepare(uw_context ctx) {",
                       newline,
                       string "uw_conn *conn = uw_get_db(ctx);",
                       newline,
@@ -147,6 +448,10 @@
                                                                uhoh false "Out of memory allocating prepared statement" [],
                                                                string "}",
                                                                newline,
+                                                               string "conn->p",
+                                                               string (Int.toString i),
+                                                               string " = stmt;",
+                                                               newline,
 
                                                                string "if (mysql_stmt_prepare(stmt, \"",
                                                                string (String.toString s),
@@ -162,10 +467,6 @@
                                                                     newline,
                                                                     uhoh true "Error preparing statement: %s" ["msg"]],
                                                                string "}",
-                                                               newline,
-                                                               string "conn->p",
-                                                               string (Int.toString i),
-                                                               string " = stmt;",
                                                                newline]
                                                       end)
                                   ss,
@@ -199,7 +500,7 @@
                | SOME n => string (Int.toString n),
              string ", ",
              stringOf unix_socket,
-             string ", 0)) {",
+             string ", 0) == NULL) {",
              newline,
              box [string "char msg[1024];",
                   newline,
@@ -214,7 +515,7 @@
              newline,
              string "}",
              newline,
-             string "conn = calloc(1, sizeof(conn));",
+             string "conn = calloc(1, sizeof(uw_conn));",
              newline,
              string "conn->conn = mysql;",
              newline,
@@ -471,19 +772,19 @@
 
          string "if (mysql_stmt_execute(stmt)) uw_error(ctx, FATAL, \"",
          string (ErrorMsg.spanToString loc),
-         string ": Error executing query\");",
+         string ": Error executing query: %s\", mysql_error(conn->conn));",
          newline,
          newline,
 
          string "if (mysql_stmt_store_result(stmt)) uw_error(ctx, FATAL, \"",
          string (ErrorMsg.spanToString loc),
-         string ": Error storing query result\");",
+         string ": Error storing query result: %s\", mysql_error(conn->conn));",
          newline,
          newline,
 
          string "if (mysql_stmt_bind_result(stmt, out)) uw_error(ctx, FATAL, \"",
          string (ErrorMsg.spanToString loc),
-         string ": Error binding query result\");",
+         string ": Error binding query result: %s\", mysql_error(conn->conn));",
          newline,
          newline,
 
@@ -496,9 +797,9 @@
          newline,
          newline,
 
-         string "if (r != MYSQL_NO_DATA) uw_error(ctx, FATAL, \"",
+         string "if (r == 1) uw_error(ctx, FATAL, \"",
          string (ErrorMsg.spanToString loc),
-         string ": query result fetching failed\");",
+         string ": query result fetching failed (%d): %s\", r, mysql_error(conn->conn));",
          newline]    
 
 fun query {loc, cols, doCols} =
@@ -514,7 +815,7 @@
          newline,
          string "if (mysql_stmt_prepare(stmt, query, strlen(query))) uw_error(ctx, FATAL, \"",
          string (ErrorMsg.spanToString loc),
-         string "\");",
+         string ": error preparing statement: %s\", mysql_error(conn->conn));",
          newline,
          newline,
 
@@ -760,21 +1061,24 @@
 fun nextval _ = box []
 fun nextvalPrepared _ = box []
 
+fun sqlifyString s = "CAST('" ^ String.translate (fn #"'" => "\\'"
+                                                   | #"\\" => "\\\\"
+                                                   | ch =>
+                                                     if Char.isPrint ch then
+                                                         str ch
+                                                     else
+                                                         (ErrorMsg.error
+                                                              "Non-printing character found in SQL string literal";
+                                                          ""))
+                                                 (String.toString s) ^ "' AS longtext)"
+
+fun p_cast (s, t) = "CAST(" ^ s ^ " AS " ^ p_sql_type t ^ ")"
+
+fun p_blank _ = "?"
+
 val () = addDbms {name = "mysql",
                   header = "mysql/mysql.h",
                   link = "-lmysqlclient",
-                  global_init = box [string "void uw_client_init() {",
-                                     newline,
-                                     box [string "if (mysql_library_init(0, NULL, NULL)) {",
-                                          newline,
-                                          box [string "fprintf(stderr, \"Could not initialize MySQL library\\n\");",
-                                               newline,
-                                               string "exit(1);",
-                                               newline],
-                                          string "}",
-                                          newline],
-                              string "}",
-                                     newline],
                   init = init,
                   p_sql_type = p_sql_type,
                   query = query,
@@ -782,6 +1086,10 @@
                   dml = dml,
                   dmlPrepared = dmlPrepared,
                   nextval = nextval,
-                  nextvalPrepared = nextvalPrepared}
+                  nextvalPrepared = nextvalPrepared,
+                  sqlifyString = sqlifyString,
+                  p_cast = p_cast,
+                  p_blank = p_blank,
+                  supportsDeleteAs = false}
 
 end