changeset 275:73456bfde988

Validating schema of a live database
author Adam Chlipala <adamc@hcoop.net>
date Tue, 02 Sep 2008 14:40:57 -0400
parents e4baf03a3a64
children ed4af33681d8
files src/cjr_print.sml src/expl_env.sml src/list_util.sig src/list_util.sml src/print.sig src/print.sml
diffstat 6 files changed, 233 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/src/cjr_print.sml	Tue Sep 02 13:44:54 2008 -0400
+++ b/src/cjr_print.sml	Tue Sep 02 14:40:57 2008 -0400
@@ -692,9 +692,10 @@
                               string x,
                               string " */",
                               newline]
-      | DDatabase s => box [string "void lw_db_init(lw_context ctx) {",
+      | DDatabase s => box [string "static void lw_db_validate(lw_context);",
                             newline,
-                            string "PGresult *res;",
+                            newline,
+                            string "void lw_db_init(lw_context ctx) {",
                             newline,
                             string "PGconn *conn = PQconnectdb(\"",
                             string (String.toString s),
@@ -720,6 +721,8 @@
                             newline,
                             string "lw_set_db(ctx, conn);",
                             newline,
+                            string "lw_db_validate(ctx);",
+                            newline,
                             string "}",
                             newline,
                             newline,
@@ -735,6 +738,17 @@
        | NotFound
        | Error
 
+fun p_sqltype' env (tAll as (t, loc)) =
+    case t of
+        TFfi ("Basis", "int") => "int8"
+      | TFfi ("Basis", "float") => "float8"
+      | TFfi ("Basis", "string") => "text"
+      | TFfi ("Basis", "bool") => "bool"
+      | _ => (ErrorMsg.errorAt loc "Don't know SQL equivalent of type";
+              Print.eprefaces' [("Type", p_typ env tAll)];
+              "ERROR")
+
+fun p_sqltype env t = string (p_sqltype' env t)
 
 fun p_file env (ds, ps) =
     let
@@ -1204,6 +1218,195 @@
             end
 
         val pds' = map p_page ps
+
+        val tables = List.mapPartial (fn (DTable (s, xts), _) => SOME (s, xts)
+                                       | _ => NONE) ds
+
+        val validate =
+            box [string "static void lw_db_validate(lw_context ctx) {",
+                 newline,
+                 string "PGconn *conn = lw_get_db(ctx);",
+                 newline,
+                 string "PGresult *res;",
+                 newline,
+                 newline,
+                 p_list_sep newline
+                            (fn (s, xts) =>
+                                let
+                                    val q = "SELECT COUNT(*) FROM pg_class WHERE relname = '"
+                                            ^ s ^ "'"
+
+                                    val q' = String.concat ["SELECT COUNT(*) FROM pg_attribute WHERE attrelid = (SELECT oid FROM pg_class WHERE relname = '",
+                                                            s,
+                                                            "') AND (",
+                                                            String.concatWith " OR "
+                                                              (map (fn (x, t) =>
+                                                                       String.concat ["(attname = 'lw_",
+                                                                                      CharVector.map
+                                                                                          Char.toLower x,
+                                                                                      "' AND atttypid = (SELECT oid FROM pg_type",
+                                                                                      " WHERE typname = '",
+                                                                                      p_sqltype' env t,
+                                                                                      "'))"]) xts),
+                                                            ")"]
+
+                                    val q'' = String.concat ["SELECT COUNT(*) FROM pg_attribute WHERE attrelid = (SELECT oid FROM pg_class WHERE relname = '",
+                                                             s,
+                                                             "') AND attnum >= 0"]
+                                in
+                                    box [string "res = PQexec(conn, \"",
+                                         string q,
+                                         string "\");",
+                                         newline,
+                                         newline,
+                                         string "if (res == NULL) {",
+                                         newline,
+                                         box [string "PQfinish(conn);",
+                                              newline,
+                                              string "lw_error(ctx, FATAL, \"Out of memory allocating query result.\");",
+                                              newline],
+                                         string "}",
+                                         newline,
+                                         newline,
+                                         string "if (PQresultStatus(res) != PGRES_TUPLES_OK) {",
+                                         newline,
+                                         box [string "char msg[1024];",
+                                              newline,
+                                              string "strncpy(msg, PQerrorMessage(conn), 1024);",
+                                              newline,
+                                              string "msg[1023] = 0;",
+                                              newline,
+                                              string "PQclear(res);",
+                                              newline,
+                                              string "PQfinish(conn);",
+                                              newline,
+                                              string "lw_error(ctx, FATAL, \"Query failed:\\n",
+                                              string q,
+                                              string "\\n%s\", msg);",
+                                              newline],
+                                         string "}",
+                                         newline,
+                                         newline,
+                                         string "if (strcmp(PQgetvalue(res, 0, 0), \"1\")) {",
+                                         newline,
+                                         box [string "PQclear(res);",
+                                              newline,
+                                              string "PQfinish(conn);",
+                                              newline,
+                                              string "lw_error(ctx, FATAL, \"Table '",
+                                              string s,
+                                              string "' does not exist.\");",
+                                              newline],
+                                         string "}",
+                                         newline,
+                                         newline,
+                                         string "PQclear(res);",
+                                         newline,
+
+                                         string "res = PQexec(conn, \"",
+                                         string q',
+                                         string "\");",
+                                         newline,
+                                         newline,
+                                         string "if (res == NULL) {",
+                                         newline,
+                                         box [string "PQfinish(conn);",
+                                              newline,
+                                              string "lw_error(ctx, FATAL, \"Out of memory allocating query result.\");",
+                                              newline],
+                                         string "}",
+                                         newline,
+                                         newline,
+                                         string "if (PQresultStatus(res) != PGRES_TUPLES_OK) {",
+                                         newline,
+                                         box [string "char msg[1024];",
+                                              newline,
+                                              string "strncpy(msg, PQerrorMessage(conn), 1024);",
+                                              newline,
+                                              string "msg[1023] = 0;",
+                                              newline,
+                                              string "PQclear(res);",
+                                              newline,
+                                              string "PQfinish(conn);",
+                                              newline,
+                                              string "lw_error(ctx, FATAL, \"Query failed:\\n",
+                                              string q',
+                                              string "\\n%s\", msg);",
+                                              newline],
+                                         string "}",
+                                         newline,
+                                         newline,
+                                         string "if (strcmp(PQgetvalue(res, 0, 0), \"",
+                                         string (Int.toString (length xts)),
+                                         string "\")) {",
+                                         newline,
+                                         box [string "PQclear(res);",
+                                              newline,
+                                              string "PQfinish(conn);",
+                                              newline,
+                                              string "lw_error(ctx, FATAL, \"Table '",
+                                              string s,
+                                              string "' has the wrong column types.\");",
+                                              newline],
+                                         string "}",
+                                         newline,
+                                         newline,
+                                         string "PQclear(res);",
+                                         newline,
+                                         newline,
+
+                                         string "res = PQexec(conn, \"",
+                                         string q'',
+                                         string "\");",
+                                         newline,
+                                         newline,
+                                         string "if (res == NULL) {",
+                                         newline,
+                                         box [string "PQfinish(conn);",
+                                              newline,
+                                              string "lw_error(ctx, FATAL, \"Out of memory allocating query result.\");",
+                                              newline],
+                                         string "}",
+                                         newline,
+                                         newline,
+                                         string "if (PQresultStatus(res) != PGRES_TUPLES_OK) {",
+                                         newline,
+                                         box [string "char msg[1024];",
+                                              newline,
+                                              string "strncpy(msg, PQerrorMessage(conn), 1024);",
+                                              newline,
+                                              string "msg[1023] = 0;",
+                                              newline,
+                                              string "PQclear(res);",
+                                              newline,
+                                              string "PQfinish(conn);",
+                                              newline,
+                                              string "lw_error(ctx, FATAL, \"Query failed:\\n",
+                                              string q'',
+                                              string "\\n%s\", msg);",
+                                              newline],
+                                         string "}",
+                                         newline,
+                                         newline,
+                                         string "if (strcmp(PQgetvalue(res, 0, 0), \"",
+                                         string (Int.toString (length xts)),
+                                         string "\")) {",
+                                         newline,
+                                         box [string "PQclear(res);",
+                                              newline,
+                                              string "PQfinish(conn);",
+                                              newline,
+                                              string "lw_error(ctx, FATAL, \"Table '",
+                                              string s,
+                                              string "' has extra columns.\");",
+                                              newline],
+                                         string "}",
+                                         newline,
+                                         newline,
+                                         string "PQclear(res);",
+                                         newline]
+                                end) tables,
+                 string "}"]
     in
         box [string "#include <stdio.h>",
              newline,
@@ -1235,23 +1438,12 @@
              p_list_sep newline (fn x => x) pds',
              newline,
              string "}",
+             newline,
+             newline,
+             validate,
              newline]
     end
 
-fun p_sqltype env (tAll as (t, loc)) =
-    let
-        val s = case t of
-                    TFfi ("Basis", "int") => "int8"
-                  | TFfi ("Basis", "float") => "float8"
-                  | TFfi ("Basis", "string") => "text"
-                  | TFfi ("Basis", "bool") => "bool"
-                  | _ => (ErrorMsg.errorAt loc "Don't know SQL equivalent of type";
-                          Print.eprefaces' [("Type", p_typ env tAll)];
-                          "ERROR")
-    in
-        string s
-    end
-
 fun p_sql env (ds, _) =
     let
         val (pps, _) = ListUtil.foldlMap
@@ -1264,9 +1456,7 @@
                                                  string "(",
                                                  p_list (fn (x, t) =>
                                                             box [string "lw_",
-                                                                 string x,
-                                                                 space,
-                                                                 string ":",
+                                                                 string (CharVector.map Char.toLower x),
                                                                  space,
                                                                  p_sqltype env t,
                                                                  space,
--- a/src/expl_env.sml	Tue Sep 02 13:44:54 2008 -0400
+++ b/src/expl_env.sml	Tue Sep 02 14:40:57 2008 -0400
@@ -288,6 +288,7 @@
         in
             pushENamed env x n t
         end
+      | DDatabase _ => env
 
 fun sgiBinds env (sgi, loc) =
     case sgi of
--- a/src/list_util.sig	Tue Sep 02 13:44:54 2008 -0400
+++ b/src/list_util.sig	Tue Sep 02 14:40:57 2008 -0400
@@ -42,5 +42,6 @@
 
     val mapi : (int * 'a -> 'b) -> 'a list -> 'b list
     val foldli : (int * 'a * 'b -> 'b) -> 'b -> 'a list -> 'b
+    val foldri : (int * 'a * 'b -> 'b) -> 'b -> 'a list -> 'b
 
 end
--- a/src/list_util.sml	Tue Sep 02 13:44:54 2008 -0400
+++ b/src/list_util.sml	Tue Sep 02 14:40:57 2008 -0400
@@ -156,4 +156,11 @@
         m 0
     end
 
+fun foldri f i ls =
+    let
+        val len = length ls
+    in
+        foldli (fn (n, x, s) => f (len - n - 1, x, s)) i (rev ls)
+    end
+
 end
--- a/src/print.sig	Tue Sep 02 13:44:54 2008 -0400
+++ b/src/print.sig	Tue Sep 02 14:40:57 2008 -0400
@@ -42,6 +42,8 @@
     val p_list_sep : PD.pp_desc -> 'a printer -> 'a list printer
     val p_list : 'a printer -> 'a list printer
 
+    val p_list_sepi : PD.pp_desc -> (int -> 'a printer) -> 'a list printer
+
     val fprint : PD.PPS.stream -> PD.pp_desc -> unit
     val print : PD.pp_desc -> unit
     val eprint : PD.pp_desc -> unit
--- a/src/print.sml	Tue Sep 02 13:44:54 2008 -0400
+++ b/src/print.sml	Tue Sep 02 14:40:57 2008 -0400
@@ -59,6 +59,19 @@
         end
 fun p_list f = p_list_sep (box [PD.string ",", space]) f
 
+fun p_list_sepi sep f ls =
+    case ls of
+        [] => PD.string ""
+      | [x] => f 0 x
+      | x :: rest =>
+        let
+            val tokens = ListUtil.foldri (fn (n, x, tokens) =>
+                                             sep :: PD.cut :: f (n + 1) x :: tokens)
+                                         [] rest
+        in
+            box (f 0 x :: tokens)
+        end
+
 fun fprint f d = (PD.description (f, d);
                   PD.PPS.flushStream f)
 val print = fprint out