module LLVM.Analysis.ClassHierarchy (
CHA,
VTable,
resolveVirtualCallee,
classSubtypes,
classTransitiveSubtypes,
classParents,
classAncestors,
classVTable,
functionAtSlot,
runCHA,
classHierarchyToTestFormat
) where
import ABI.Itanium
import Data.Foldable ( foldMap, toList )
import Data.Generics.Uniplate.Data
import Data.List ( stripPrefix )
import Data.Map ( Map )
import qualified Data.Map as M
import Data.Maybe ( fromMaybe, mapMaybe )
import Data.Monoid
import Data.Set ( Set )
import qualified Data.Set as S
import Data.Vector ( Vector, (!?) )
import qualified Data.Vector as V
import LLVM.Analysis hiding ( (!?) )
import LLVM.Analysis.Util.Names
-- | The accumulated result of the class hierarchy analysis over one
-- 'Module'.  It is built by 'runCHA' and queried through the accessor
-- functions of this module.
data CHA = CHA { childrenMap :: Map Name (Set Name)
                 -- ^ For each class name, the names of the classes that
                 -- list it as a parent (filled in by 'addChild').
               , parentMap :: Map Name (Set Name)
                 -- ^ For each class name, the parent class names found
                 -- in its typeinfo initializer (see 'recordTypeInfo').
               , vtblMap :: Map Name VTable
                 -- ^ The virtual table recorded for each class name
                 -- (see 'recordVTable').
               , typeMapping :: Map Name Type
                 -- ^ The LLVM type associated with each class name.  It
                 -- is populated only from the functions selected by
                 -- 'moduleConstructors' (see 'buildTypeMap').
               , chaModule :: Module
                 -- ^ The module that was handed to 'runCHA'.
               }
-- | The virtual function table recorded for a class.
data VTable = ExternalVTable
              -- ^ Recorded by 'recordVTable' when the vtable global has
              -- no initializer, or an initializer that is not a
              -- constant array; no slots can be looked up.
            | VTable (Vector Function)
              -- ^ The function entries extracted by 'makeVTable',
              -- indexed by slot number (see 'functionAtSlot').
            deriving (Show)
-- | Determine the possible targets of a call or invoke instruction.
--
-- * If the callee is a 'Function', that function is the only target.
--
-- * If the callee is the result of a load and the instruction has at
--   least one argument, the first argument is treated as the receiver
--   and the targets are computed by 'virtualDispatch'.
--
-- * Every other instruction (or callee shape) yields 'Nothing'.
resolveVirtualCallee :: CHA -> Instruction -> Maybe [Function]
resolveVirtualCallee cha i =
  case i of
    CallInst { callFunction = callee, callArguments = args } ->
      targetsOf callee args
    InvokeInst { invokeFunction = callee, invokeArguments = args } ->
      targetsOf callee args
    _ -> Nothing
  where
    -- Calls and invokes are resolved identically once the callee and
    -- the argument list have been extracted.
    targetsOf callee args =
      case valueContent' callee of
        FunctionC f -> Just [f]
        InstructionC LoadInst { loadAddress = la }
          | (thisVal, _) : _ <- args -> virtualDispatch cha la thisVal
        _ -> Nothing
-- | Compute the candidate targets of a call made through a function
-- pointer loaded from @loadAddr@, with @thisVal@ as the receiver.
--
-- The slot number is recovered by 'getVFuncSlot'; the candidates are
-- the functions found at that slot in the vtable of every type returned
-- by 'classTransitiveSubtypes' for the receiver's pointee type.  Types
-- without a recorded vtable, and vtables without a function at that
-- slot, contribute nothing.
--
-- Returns 'Nothing' when the receiver is not of pointer type or when no
-- slot number can be determined.  The receiver type is matched with a
-- total @case@ so that a non-pointer receiver is reported as 'Nothing'
-- instead of failing with an irrefutable-pattern error.
virtualDispatch :: CHA -> Value -> Value -> Maybe [Function]
virtualDispatch cha loadAddr thisVal =
  case valueType thisVal of
    TypePointer thisType _ -> do
      slotNumber <- getVFuncSlot cha loadAddr thisVal
      let derivedTypes = classTransitiveSubtypes cha thisType
          vtbls = mapMaybe (classVTable cha) derivedTypes
      return $! mapMaybe (functionAtSlot slotNumber) vtbls
    _ -> Nothing
-- | Recover the vtable slot number addressed by @loadAddr@, the address
-- from which the called function pointer was loaded.
--
-- Four instruction shapes are recognised.  In each of them the pointer
-- at the bottom of the chain must be equal to @thisArg@; otherwise (and
-- for any other shape) the result is 'Nothing'.
getVFuncSlot :: CHA -> Value -> Value -> Maybe Int
getVFuncSlot cha loadAddr thisArg =
  case valueContent loadAddr of
    -- A single-index GEP, with a constant index, into the value loaded
    -- through a bitcast of the receiver: the index is the slot.
    InstructionC GetElementPtrInst
      { getElementPtrIndices = [valueContent -> ConstantC ConstantInt { constantIntValue = slotNo }]
      , getElementPtrValue =
          (valueContent -> InstructionC LoadInst
            { loadAddress =
                (valueContent -> InstructionC BitcastInst { castedValue = receiver })
            })
      } -> slotFor receiver (fromIntegral slotNo)
    -- The value loaded through a bitcast of the receiver, used without
    -- any further indexing: slot 0.
    InstructionC LoadInst
      { loadAddress =
          (valueContent -> InstructionC BitcastInst { castedValue = receiver })
      } -> slotFor receiver 0
    -- The value loaded from a (0, 0) GEP on the receiver, used without
    -- any further indexing: slot 0.
    InstructionC LoadInst
      { loadAddress =
          (valueContent -> InstructionC GetElementPtrInst
            { getElementPtrIndices =
                [ valueContent -> ConstantC ConstantInt { constantIntValue = 0 }
                , valueContent -> ConstantC ConstantInt { constantIntValue = 0 }
                ]
            , getElementPtrValue = receiver
            })
      } -> slotFor receiver 0
    -- A bitcast of a single-index GEP over a bitcast of the value
    -- loaded from a (0, 0) GEP on the receiver.  The constant index is
    -- handed to 'indexFromOffset' to obtain the slot.
    InstructionC BitcastInst
      { castedValue =
          (valueContent -> InstructionC GetElementPtrInst
            { getElementPtrIndices = [valueContent -> ConstantC ConstantInt { constantIntValue = offset }]
            , getElementPtrValue =
                (valueContent -> InstructionC BitcastInst
                  { castedValue =
                      (valueContent -> InstructionC LoadInst
                        { loadAddress =
                            (valueContent -> InstructionC GetElementPtrInst
                              { getElementPtrIndices =
                                  [ valueContent -> ConstantC ConstantInt { constantIntValue = 0 }
                                  , valueContent -> ConstantC ConstantInt { constantIntValue = 0 }
                                  ]
                              , getElementPtrValue = receiver
                              })
                        })
                  })
            })
      } -> slotFor receiver (indexFromOffset cha (fromIntegral offset))
    _ -> Nothing
  where
    -- Accept the slot only when the chain is rooted at the receiver
    -- that is actually passed to the call.
    slotFor receiver slot
      | thisArg == receiver = Just $! slot
      | otherwise = Nothing
-- | Convert a byte offset into a slot index: the offset is scaled to
-- bits and divided (rounding down) by the 'alignmentPrefSize' of the
-- pointer preferences in the data layout of the analysed module.
--
-- NOTE(review): 'alignmentPrefSize' is presumably the pointer width in
-- bits -- confirm against the data layout definition.
indexFromOffset :: CHA -> Int -> Int
indexFromOffset cha bytes =
  let layout = moduleDataLayout (chaModule cha)
      bitsPerPointer = alignmentPrefSize (targetPointerPrefs layout)
      offsetInBits = bytes * 8
  in offsetInBits `div` bitsPerPointer
-- | The types recorded as direct children of the given class type.
-- A class with no entry in the children map has no subtypes.
classSubtypes :: CHA -> Type -> [Type]
classSubtypes cha t = namesToTypes cha children
  where
    children = fromMaybe S.empty (M.lookup (typeToName t) (childrenMap cha))
-- | Every type reachable from the given class type by repeatedly
-- following the children map (see 'transitiveTypes').
classTransitiveSubtypes :: CHA -> Type -> [Type]
classTransitiveSubtypes cha t = transitiveTypes childrenMap cha t
-- | The types recorded as direct parents of the given class type.
-- A class with no entry in the parent map has no parents.
classParents :: CHA -> Type -> [Type]
classParents cha t = namesToTypes cha parents
  where
    parents = fromMaybe S.empty (M.lookup (typeToName t) (parentMap cha))
-- | Every type reachable from the given class type by repeatedly
-- following the parent map (see 'transitiveTypes').
classAncestors :: CHA -> Type -> [Type]
classAncestors cha t = transitiveTypes parentMap cha t
-- | Compute the set of types reachable from @t0@ through the relation
-- chosen by @selector@ ('childrenMap' or 'parentMap').
--
-- The starting type itself is part of the result.  Names without an
-- entry in the selected map are treated as having no successors.
--
-- The traversal keeps track of the names it has already reached and
-- only ever expands newly discovered ones.  It therefore terminates
-- even if the relation contains a cycle, and a name reachable along
-- several paths is expanded once.
transitiveTypes :: (CHA -> Map Name (Set Name)) -> CHA -> Type -> [Type]
transitiveTypes selector cha t0 =
  namesToTypes cha (go (S.singleton root) (S.singleton root))
  where
    root = typeToName t0
    relation = selector cha
    successors n = M.findWithDefault S.empty n relation
    -- @seen@ holds every name reached so far; @frontier@ holds the
    -- names discovered in the previous round that still need expanding.
    go seen frontier
      | S.null fresh = seen
      | otherwise = go (seen `S.union` fresh) fresh
      where
        fresh = foldMap successors frontier `S.difference` seen
-- | Look up the virtual table recorded for a class type, if any.
classVTable :: CHA -> Type -> Maybe VTable
classVTable cha = flip M.lookup (vtblMap cha) . typeToName
-- | The function stored at the given slot of a vtable.  'Nothing' for
-- an 'ExternalVTable' or when the slot is out of range.
functionAtSlot :: Int -> VTable -> Maybe Function
functionAtSlot slot vtbl =
  case vtbl of
    ExternalVTable -> Nothing
    VTable entries -> entries !? slot
-- | Run the class hierarchy analysis over a module.
--
-- The global variables are folded with 'recordParents' first, and the
-- functions selected by 'moduleConstructors' are then folded with
-- 'buildTypeMap' to fill in the name-to-type mapping.
runCHA :: Module -> CHA
runCHA m = foldr buildTypeMap hierarchy (moduleConstructors m)
  where
    emptyCHA = CHA { childrenMap = mempty
                   , parentMap = mempty
                   , vtblMap = mempty
                   , typeMapping = mempty
                   , chaModule = m
                   }
    hierarchy = foldr recordParents emptyCHA (moduleGlobalVariables m)
-- | The defined functions of a module that satisfy 'isC2Constructor'.
moduleConstructors :: Module -> [Function]
moduleConstructors m =
  [ f | f <- moduleDefinedFunctions m, isC2Constructor f ]
-- | Record the type constructed by a constructor function under its
-- parsed class name in the 'typeMapping' of the analysis result.
--
-- Calls 'error' if the constructed type is not a named struct or if its
-- name (after 'stripNamePrefix') cannot be parsed by 'parseTypeName'.
buildTypeMap :: Function -> CHA -> CHA
buildTypeMap f cha = either parseFailure register (parseTypeName className)
  where
    classTy = constructedType f
    className =
      case classTy of
        TypeStruct (Just tn) _ _ -> stripNamePrefix tn
        _ -> error ("LLVM.Analysis.ClassHierarchy.buildTypeMap: Expected class type: " ++ show classTy)
    parseFailure e = error ("LLVM.Analysis.ClassHierarchy.buildTypeMap: " ++ e)
    register n = cha { typeMapping = M.insert n classTy (typeMapping cha) }
-- | Fold step over the global variables of a module.
--
-- The name of the global is demangled.  A virtual table for a class is
-- passed to 'recordVTable' and a typeinfo object for a class is passed
-- to 'recordTypeInfo', each together with the global's initializer.
-- Globals whose names do not demangle, or demangle to anything else,
-- leave the accumulator unchanged.
--
-- Calls 'error' for a virtual table or typeinfo whose subject is not a
-- class name.
recordParents :: GlobalVariable -> CHA -> CHA
recordParents gv acc =
  case demangleName (identifierAsString (globalVariableName gv)) of
    Right (VirtualTable (ClassEnumType typeName)) ->
      recordVTable acc typeName initializer
    Right (VirtualTable tn) ->
      error ("LLVM.Analysis.ClassHierarchy.recordParents: Expected a class name for virtual table: " ++ show tn)
    Right (TypeInfo (ClassEnumType typeName)) ->
      recordTypeInfo acc typeName initializer
    Right (TypeInfo tn) ->
      error ("LLVM.Analysis.ClassHierarchy.recordParents: Expected a class name for typeinfo: " ++ show tn)
    _ -> acc
  where
    initializer = globalVariableInitializer gv
-- | Record the virtual table of a class from the initializer of its
-- vtable global.
--
-- A constant-array initializer is turned into a 'VTable' with
-- 'makeVTable'.  A missing initializer, or one of any other shape, is
-- recorded as 'ExternalVTable'.
recordVTable :: CHA -> Name -> Maybe Value -> CHA
recordVTable cha typeName initializer =
  case initializer of
    Just (valueContent' -> ConstantC (ConstantArray _ _ vs)) ->
      install (makeVTable vs)
    _ -> install ExternalVTable
  where
    install tbl = cha { vtblMap = M.insert typeName tbl (vtblMap cha) }
-- | Build a 'VTable' from the entries of a vtable initializer.
--
-- Leading entries that are not functions are skipped, and the table
-- consists of the first contiguous run of function entries; everything
-- after that run is ignored.
makeVTable :: [Value] -> VTable
makeVTable entries = VTable (V.fromList (map unsafeToFunction functionRun))
  where
    (_, fromFirstFunction) = break isVTableFunctionType entries
    (functionRun, _) = span isVTableFunctionType fromFirstFunction
-- | Project the 'Function' out of a value whose content (as seen by
-- 'valueContent'') is a function.  Calls 'error' for any other value.
unsafeToFunction :: Value -> Function
unsafeToFunction v
  | FunctionC f <- valueContent' v = f
  | otherwise =
      error ("LLVM.Analysis.ClassHierarchy.unsafeToFunction: Expected vtable function entry: " ++ show v)
-- | Whether the content of a value (as seen by 'valueContent'') is a
-- function.
isVTableFunctionType :: Value -> Bool
isVTableFunctionType v
  | FunctionC _ <- valueContent' v = True
  | otherwise = False
-- | Record the parents of a class from the initializer of its typeinfo
-- global.
--
-- Every field of the constant struct that 'toParentClassName'
-- recognises is a parent: the parents are merged into the class's entry
-- in 'parentMap', and the class is added as a child of each of them in
-- 'childrenMap'.  A typeinfo without an initializer leaves the result
-- unchanged.
--
-- Calls 'error' if the initializer is not a constant struct.
recordTypeInfo :: CHA -> Name -> Maybe Value -> CHA
recordTypeInfo cha _ Nothing = cha
recordTypeInfo cha name (Just tbl) =
  case valueContent tbl of
    ConstantC (ConstantStruct _ _ vs) ->
      let parentClassNames = mapMaybe toParentClassName vs
          parentSet = S.fromList parentClassNames
          -- 'M.alter' with a forced result performs the same strict
          -- merge as the old @insertWith'@, which is deprecated and no
          -- longer exported by recent versions of containers.
          mergeParents existing =
            Just $! maybe parentSet (S.union parentSet) existing
      in cha { parentMap = M.alter mergeParents name (parentMap cha)
             , childrenMap = foldr (addChild name) (childrenMap cha) parentClassNames
             }
    _ -> error ("LLVM.Analysis.ClassHierarchy.recordTypeInfo: Expected typeinfo literal " ++ show tbl)
-- | Extract a parent class name from one field of a typeinfo
-- initializer.
--
-- The field must be a constant bitcast of a global variable whose name
-- demangles to the typeinfo of a class; the result is that class name.
-- Any other value yields 'Nothing'.
toParentClassName :: Value -> Maybe Name
toParentClassName v =
  case valueContent v of
    ConstantC ConstantValue
      { constantInstruction = BitcastInst
          { castedValue =
              (valueContent -> GlobalVariableC GlobalVariable { globalVariableName = gvn })
          }
      }
      | Right (TypeInfo (ClassEnumType n)) <- demangleName (identifierAsString gvn) ->
          Just n
    _ -> Nothing
-- | @addChild thisType parentType m@ adds @thisType@ to the set of
-- children stored under @parentType@ in @m@, creating the entry if it
-- does not exist yet.
--
-- 'M.alter' with a forced result performs the same strict merge as the
-- old @insertWith'@, which is deprecated and no longer exported by
-- recent versions of containers.
addChild :: Name -> Name -> Map Name (Set Name) -> Map Name (Set Name)
addChild thisType parentType = M.alter withChild parentType
  where
    child = S.singleton thisType
    withChild existing = Just $! maybe child (S.union child) existing
-- | The struct type pointed to by the first parameter of a function.
-- Calls 'error' unless the first parameter is a pointer to a named
-- struct.
constructedType :: Function -> Type
constructedType f =
  case paramTypes of
    TypePointer t@(TypeStruct (Just _) _ _) _ : _ -> t
    _ -> error ("LLVM.Analysis.ClassHierarchy.constructedType: Expected pointer to struct type: " ++ show paramTypes)
  where
    paramTypes = map argumentType (functionParameters f)
-- | Whether the name of a function demangles to a structured name that
-- contains exactly one constructor/destructor marker, and that marker
-- is @C2@.  Functions whose names do not demangle are rejected.
isC2Constructor :: Function -> Bool
isC2Constructor f =
  case demangleName (identifierAsString (functionName f)) of
    Right structuredName
      | [C2] <- universeBi structuredName -> True
    _ -> False
-- | Remove a prefix from a string if it is present; a string that does
-- not start with the prefix is returned unchanged.
stripPrefix' :: String -> String -> String
stripPrefix' pfx s =
  case stripPrefix pfx s of
    Just rest -> rest
    Nothing -> s
-- | Drop a leading @class.@ and then a leading @struct.@ from a type
-- name, where present.
stripNamePrefix :: String -> String
stripNamePrefix name = stripPrefix' "struct." (stripPrefix' "class." name)
-- | The parsed class name of a named struct type.  Calls 'error' for
-- any other type, or when the name (after 'stripNamePrefix') cannot be
-- parsed by 'parseTypeName'.
typeToName :: Type -> Name
typeToName t =
  case t of
    TypeStruct (Just n) _ _ ->
      either parseFailure id (parseTypeName (stripNamePrefix n))
    _ -> error ("LLVM.Analysis.ClassHierarchy.typeToName: Expected named struct type: " ++ show t)
  where
    parseFailure e = error ("LLVM.Analysis.ClassHierarchy.typeToName: " ++ e)
-- | Render a class name as a string with 'unparseTypeName'.  Calls
-- 'error' if the name cannot be rendered.
nameToString :: Name -> String
nameToString n =
  case unparseTypeName n of
    Just rendered -> rendered
    Nothing -> error ("Could not encode name as string: " ++ show n)
-- | The LLVM type recorded for a class name.  Calls 'error' if the name
-- has no entry in 'typeMapping'.
nameToType :: CHA -> Name -> Type
nameToType cha n =
  case M.lookup n (typeMapping cha) of
    Just ty -> ty
    Nothing -> error ("Expected name in typeMapping for CHA: " ++ show n)
-- | Translate a set of class names into their types with 'nameToType',
-- in the ascending order of the names.
namesToTypes :: CHA -> Set Name -> [Type]
namesToTypes cha names = [ nameToType cha n | n <- S.toList names ]
-- | Render the children map with every class name converted to a
-- string by 'nameToString'.  Entries whose keys render to the same
-- string have their child sets merged.
classHierarchyToTestFormat :: CHA -> Map String (Set String)
classHierarchyToTestFormat cha =
  M.foldrWithKey addEntry mempty (childrenMap cha)
  where
    addEntry ty subtypes =
      M.insertWith S.union (nameToString ty) (S.map nameToString subtypes)