diff src/unpoly.sml @ 399:2d64457eedb1

listFun uses length
author Adam Chlipala <adamc@hcoop.net>
date Tue, 21 Oct 2008 13:41:03 -0400
parents e457d8972ff1
children dc3fc3f3b834
line wrap: on
line diff
--- a/src/unpoly.sml	Tue Oct 21 13:24:54 2008 -0400
+++ b/src/unpoly.sml	Tue Oct 21 13:41:03 2008 -0400
@@ -46,17 +46,18 @@
 val liftConInExp = E.liftConInExp
 val subConInExp = E.subConInExp
 
+val isOpen = U.Con.exists {kind = fn _ => false,
+                           con = fn c =>
+                                    case c of
+                                        CRel _ => true
+                                      | _ => false}
+
 fun unpolyNamed (xn, rep) =
     U.Exp.map {kind = fn k => k,
                con = fn c => c,
                exp = fn e =>
                         case e of
-                            ENamed xn' =>
-                            if xn' = xn then
-                                rep
-                            else
-                                e
-                          | ECApp (e', _) =>
+                            ECApp (e', _) =>
                             let
                                 fun isTheOne (e, _) =
                                     case e of
@@ -65,7 +66,7 @@
                                       | _ => false
                             in
                                 if isTheOne e' then
-                                    #1 e'
+                                    rep
                                 else
                                     e
                             end
@@ -96,71 +97,74 @@
             case unravel (e, []) of
                 NONE => (e, st)
               | SOME (n, cargs) =>
-                case IM.find (#funcs st, n) of
-                    NONE => (e, st)
-                  | SOME (ks, vis) =>
-                    let
-                        val (vis, nextName) = ListUtil.foldlMap
-                                             (fn ((x, n, t, e, s), nextName) =>
-                                                 ((x, nextName, n, t, e, s), nextName + 1))
-                                             (#nextName st) vis
+                if List.exists isOpen cargs then
+                    (e, st)
+                else
+                    case IM.find (#funcs st, n) of
+                        NONE => (e, st)
+                      | SOME (ks, vis) =>
+                        let
+                            val (vis, nextName) = ListUtil.foldlMap
+                                                      (fn ((x, n, t, e, s), nextName) =>
+                                                          ((x, nextName, n, t, e, s), nextName + 1))
+                                                      (#nextName st) vis
 
-                        fun specialize (x, n, n_old, t, e, s) =
-                            let
-                                fun trim (t, e, cargs) =
-                                    case (t, e, cargs) of
-                                        ((TCFun (_, _, t), _),
-                                         (ECAbs (_, _, e), _),
-                                         carg :: cargs) =>
-                                        let
-                                            val t = subConInCon (length cargs, carg) t
-                                            val e = subConInExp (length cargs, carg) e
-                                        in
-                                            trim (t, e, cargs)
-                                        end
-                                      | (_, _, []) =>
-                                        let
-                                            val e = foldl (fn ((_, n, n_old, _, _, _), e) =>
-                                                              unpolyNamed (n_old, ENamed n) e)
-                                                          e vis
-                                        in
-                                            SOME (t, e)
-                                        end
-                                      | _ => NONE
-                            in
-                                (*Print.prefaces "specialize"
-                                               [("t", CorePrint.p_con CoreEnv.empty t),
-                                                ("e", CorePrint.p_exp CoreEnv.empty e),
-                                                ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*)
-                                Option.map (fn (t, e) => (x, n, n_old, t, e, s))
-                                (trim (t, e, cargs))
-                            end
+                            fun specialize (x, n, n_old, t, e, s) =
+                                let
+                                    fun trim (t, e, cargs) =
+                                        case (t, e, cargs) of
+                                            ((TCFun (_, _, t), _),
+                                             (ECAbs (_, _, e), _),
+                                             carg :: cargs) =>
+                                            let
+                                                val t = subConInCon (length cargs, carg) t
+                                                val e = subConInExp (length cargs, carg) e
+                                            in
+                                                trim (t, e, cargs)
+                                            end
+                                          | (_, _, []) =>
+                                            let
+                                                val e = foldl (fn ((_, n, n_old, _, _, _), e) =>
+                                                                  unpolyNamed (n_old, ENamed n) e)
+                                                              e vis
+                                            in
+                                                SOME (t, e)
+                                            end
+                                          | _ => NONE
+                                in
+                                    (*Print.prefaces "specialize"
+                                                     [("t", CorePrint.p_con CoreEnv.empty t),
+                                                      ("e", CorePrint.p_exp CoreEnv.empty e),
+                                                      ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*)
+                                    Option.map (fn (t, e) => (x, n, n_old, t, e, s))
+                                               (trim (t, e, cargs))
+                                end
 
-                        val vis = List.map specialize vis
-                    in
-                        if List.exists (not o Option.isSome) vis orelse length cargs > length ks then
-                            (e, st)
-                        else
-                            let
-                                val vis = List.mapPartial (fn x => x) vis
-                                val vis = map (fn (x, n, n_old, t, e, s) =>
-                                                   (x ^ "_unpoly", n, n_old, t, e, s)) vis
-                                val vis' = map (fn (x, n, _, t, e, s) =>
-                                                   (x, n, t, e, s)) vis
+                            val vis = List.map specialize vis
+                        in
+                            if List.exists (not o Option.isSome) vis orelse length cargs > length ks then
+                                (e, st)
+                            else
+                                let
+                                    val vis = List.mapPartial (fn x => x) vis
+                                    val vis = map (fn (x, n, n_old, t, e, s) =>
+                                                      (x ^ "_unpoly", n, n_old, t, e, s)) vis
+                                    val vis' = map (fn (x, n, _, t, e, s) =>
+                                                       (x, n, t, e, s)) vis
 
-                                val ks' = List.drop (ks, length cargs)
-                            in
-                                case List.find (fn (_, _, n_old, _, _, _) => n_old = n) vis of
-                                    NONE => raise Fail "Unpoly: Inconsistent 'val rec' record"
-                                  | SOME (_, n, _, _, _, _) =>
-                                    (ENamed n,
-                                     {funcs = foldl (fn (vi, funcs) =>
-                                                        IM.insert (funcs, #2 vi, (ks', vis')))
-                                                    (#funcs st) vis',
-                                      decls = (DValRec vis', ErrorMsg.dummySpan) :: #decls st,
-                                      nextName = nextName})
-                            end
-                    end
+                                    val ks' = List.drop (ks, length cargs)
+                                in
+                                    case List.find (fn (_, _, n_old, _, _, _) => n_old = n) vis of
+                                        NONE => raise Fail "Unpoly: Inconsistent 'val rec' record"
+                                      | SOME (_, n, _, _, _, _) =>
+                                        (ENamed n,
+                                         {funcs = foldl (fn (vi, funcs) =>
+                                                            IM.insert (funcs, #2 vi, (ks', vis')))
+                                                        (#funcs st) vis',
+                                          decls = (DValRec vis', ErrorMsg.dummySpan) :: #decls st,
+                                          nextName = nextName})
+                                end
+                        end
         end
       | _ => (e, st)