diff --git a/src/Language/Java/Classfile/Flags.hs b/src/Language/Java/Classfile/Flags.hs index 9a2e549..b52c725 100644 --- a/src/Language/Java/Classfile/Flags.hs +++ b/src/Language/Java/Classfile/Flags.hs @@ -6,10 +6,11 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE TypeFamilies #-} -module Language.Java.Classfile.Flags (Flags(..), FlagMask(..)) where +{-# LANGUAGE DefaultSignatures #-} +module Language.Java.Classfile.Flags (Flags(..), FlagMask(..), containsFlag) where -import Data.Bits (Bits((.&.))) +import Data.Bits (Bits((.&.), zeroBits)) import Data.Enum.Util (enumerate) import Data.Set (Set) @@ -18,6 +19,8 @@ import qualified Data.Set as Set import Language.Java.Classfile.Extract (Extract) import Language.Java.Classfile.Extractable (Extractable (extract)) import Data.Kind (Type) +import Control.Arrow ((>>>)) +import qualified Data.List as List -- | Using the 'FlagMask' instance of the type parameter, this will extract all the flags whose mask produced a non-zero value using '.&.' @@ -39,3 +42,13 @@ class FlagMask a where type FlagType a :: Type maskOf :: a -> FlagType a + ofMask :: FlagType a -> Set a + + default ofMask :: (Enum a, Bounded a, Ord a, Bits (FlagType a)) => FlagType a -> Set a + ofMask mask = List.filter (containsFlag mask) + >>> Set.fromList + $ enumerate @a + +containsFlag :: (Bits (FlagType a), FlagMask a) => FlagType a -> a -> Bool +containsFlag mask flag = mask .&. maskOf flag /= zeroBits +