diff --git a/src/Language/Json/Type.hs b/src/Language/Json/Type.hs index 672d4bf..9b8d384 100644 --- a/src/Language/Json/Type.hs +++ b/src/Language/Json/Type.hs @@ -2,7 +2,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE DeriveGeneric #-} -module Language.Json.Type (infer, Type(..)) where +module Language.Json.Type (infer, Type(..), isObject) where import Data.Map.Strict (Map) import Data.Text (Text) import Data.Set (Set) @@ -38,15 +38,27 @@ instance Semigroup Type where mergeMatched = (Map.zipWithMatched $ const (<>)) -- merge keys present in both maps unionNull = (Map.mapMissing $ const (<> Null)) -- mark keys missing in either map as nullable in Object $ Map.merge unionNull unionNull mergeMatched fieldsA fieldsB - (Union typesA, Union typesB) -> Union $ Set.union typesA typesB - (Union typesA, _) -> Union $ Set.insert b typesA - (_, Union typesB) -> Union $ Set.insert a typesB - _ -> Union $ Set.fromList [a, b] + (Union typesA, Union typesB) -> Union . mergeUnionObjects $ Set.union typesA typesB + (Union typesA, _) -> Union . mergeUnionObjects $ Set.insert b typesA + (_, Union typesB) -> Union . mergeUnionObjects $ Set.insert a typesB + _ -> Union . mergeUnionObjects $ Set.fromList [a, b] instance Monoid Type where mempty :: Type mempty = All +mergeUnionObjects :: Set Type -> Set Type +mergeUnionObjects set = let + (objects, rest) = Set.partition isObject set + in if Set.null objects + then rest + else Set.insert (Set.foldl (<>) mempty objects) rest + +isObject :: Type -> Bool +isObject = \case + Object _ -> True + _ -> False + infer :: Json.Value -> Type infer = \case Value.Null -> Null