00001
00002
00003
00004
00005
00006
00007
00008
00009 #include "ExportMap.h"
00010 #include "DomainInfo.h"
00011 #include "DomainInstance.h"
00012 #include "DomainView.h"
00013 #include "comma/ast/Decl.h"
00014 #include "comma/codegen/CodeGen.h"
00015 #include "comma/codegen/CodeGenCapsule.h"
00016 #include "comma/codegen/CodeGenTypes.h"
00017 #include "comma/codegen/CommaRT.h"
00018
00019 #include "llvm/ADT/IndexedMap.h"
00020 #include "llvm/Module.h"
00021 #include "llvm/Value.h"
00022 #include "llvm/Support/IRBuilder.h"
00023
00024 using namespace comma;
00025
00026 using llvm::dyn_cast;
00027 using llvm::cast;
00028 using llvm::isa;
00029
// Construct the Comma runtime support object for the code generator CG.
//
// The constructor proceeds in three phases:
//   1. Create the DomainInfo, DomainView and DomainInstance helpers, recording
//      each helper's pointer type right after the helper is created.
//   2. Call init() on all three helpers, only after every helper and pointer
//      type exists.
//      NOTE(review): presumably this ordering lets each helper's init() refer
//      to the other helpers' pointer types -- confirm before reordering.
//   3. Emit the named runtime types and the runtime function declarations
//      into CG's module.
00030 CommaRT::CommaRT(CodeGen &CG)
00031 : CG(CG),
// The export map is heap-allocated here and released in ~CommaRT.
00032 EM(new ExportMap),
00033 ITableName("_comma_itable_t"),
00034 DomainCtorName("_comma_domain_ctor_t"),
00035
// The helper pointers and their pointer types start out null; they are
// filled in by the constructor body below.
00036 DInfo(0),
00037 DomainInfoPtrTy(0),
00038
00039 DView(0),
00040 DomainViewPtrTy(0),
00041
00042 DInstance(0),
00043 DomainInstancePtrTy(0),
00044
// These two initializers call member functions during construction.
// getExportFnPtrTy() and getITablePtrTy() both use CG.
// NOTE(review): relies on CG being declared (hence initialized) before
// ExportFnPtrTy and ITablePtrTy in CommaRT.h -- confirm.
00045 ExportFnPtrTy(getExportFnPtrTy()),
00046 ITablePtrTy(getITablePtrTy()),
// Assigned in the body from DInfo->getCtorPtrType(), not from
// getDomainCtorPtrTy().
00047 DomainCtorPtrTy(0),
00048
00049 GetDomainName("_comma_get_domain")
00050 {
// Phase 1: create the helpers and capture their pointer types.
00051 DInfo = new DomainInfo(*this);
00052 DomainInfoPtrTy = DInfo->getPointerTypeTo();
00053
00054 DView = new DomainView(*this);
00055 DomainViewPtrTy = DView->getPointerTypeTo();
00056
00057 DInstance = new DomainInstance(*this);
00058 DomainInstancePtrTy = DInstance->getPointerTypeTo();
00059
00060 DomainCtorPtrTy = DInfo->getCtorPtrType();
00061
// Phase 2: initialize the helpers now that all pointer types are known.
00062 DInfo->init();
00063 DView->init();
00064 DInstance->init();
00065
// Phase 3: populate the module with runtime types and declarations.
00066 generateRuntimeTypes();
00067 generateRuntimeFunctions();
00068 }
00069
00070 CommaRT::~CommaRT()
00071 {
00072 delete EM;
00073 delete DInfo;
00074 delete DView;
00075 }
00076
00077 const std::string &CommaRT::getTypeName(TypeId id) const
00078 {
00079 switch (id) {
00080 default:
00081 assert(false && "Invalid type id!");
00082 return InvalidName;
00083 case CRT_ITable:
00084 return ITableName;
00085 case CRT_DomainInfo:
00086 return DInfo->getTypeName();
00087 case CRT_DomainView:
00088 return DView->getTypeName();
00089 case CRT_DomainInstance:
00090 return DInstance->getTypeName();
00091 case CRT_DomainCtor:
00092 return DomainCtorName;
00093 }
00094 }
00095
00096 void CommaRT::generateRuntimeTypes()
00097 {
00098
00099 llvm::Module *M = CG.getModule();
00100 M->addTypeName(getTypeName(CRT_DomainInfo), getType<CRT_DomainInfo>());
00101 M->addTypeName(getTypeName(CRT_DomainView), getType<CRT_DomainView>());
00102 M->addTypeName(getTypeName(CRT_DomainInstance), getType<CRT_DomainInstance>());
00103 }
00104
00105 const llvm::PointerType *CommaRT::getDomainCtorPtrTy()
00106 {
00107 std::vector<const llvm::Type*> args;
00108
00109 args.push_back(DomainInstancePtrTy);
00110
00111 const llvm::Type *ctorTy = llvm::FunctionType::get(llvm::Type::VoidTy, args, false);
00112 return CG.getPointerType(ctorTy);
00113 }
00114
00115 const llvm::PointerType *CommaRT::getExportFnPtrTy()
00116 {
00117 std::vector<const llvm::Type*> args;
00118 llvm::Type *ftype;
00119
00120 ftype = llvm::FunctionType::get(llvm::Type::VoidTy, args, false);
00121 return CG.getPointerType(ftype);
00122 }
00123
00124 const llvm::PointerType *CommaRT::getITablePtrTy()
00125 {
00126 return CG.getPointerType(llvm::Type::Int8Ty);
00127 }
00128
00129 void CommaRT::generateRuntimeFunctions()
00130 {
00131 defineGetDomain();
00132 }
00133
00134
00135 void CommaRT::defineGetDomain()
00136 {
00137 const llvm::Type *retTy = getType<CRT_DomainInstance>();
00138 std::vector<const llvm::Type *> args;
00139
00140 args.push_back(getType<CRT_DomainInfo>());
00141
00142
00143
00144
00145 llvm::FunctionType *fnTy = llvm::FunctionType::get(retTy, args, true);
00146
00147 getDomainFn = CG.makeFunction(fnTy, GetDomainName);
00148 }
00149
00150 void CommaRT::registerSignature(const Sigoid *sigoid)
00151 {
00152 EM->addSignature(sigoid);
00153 }
00154
00155 llvm::GlobalVariable *CommaRT::registerCapsule(CodeGenCapsule &CGC)
00156 {
00157 return DInfo->generateInstance(CGC);
00158 }
00159
00160 llvm::Value *CommaRT::getDomain(llvm::IRBuilder<> &builder,
00161 llvm::GlobalValue *capsuleInfo) const
00162 {
00163 return builder.CreateCall(getDomainFn, capsuleInfo);
00164 }
00165
00166 llvm::Value *CommaRT::getDomain(llvm::IRBuilder<> &builder,
00167 std::vector<llvm::Value *> &args) const
00168 {
00169 assert(args.front()->getType() == getType<CRT_DomainInfo>()
00170 && "First argument is not a domain_info_t!");
00171 return builder.CreateCall(getDomainFn, args.begin(), args.end());
00172 }
00173
00174 llvm::Value *CommaRT::getLocalCapsule(llvm::IRBuilder<> &builder,
00175 llvm::Value *percent, unsigned ID) const
00176 {
00177 assert(percent->getType() == DomainInstancePtrTy &&
00178 "Bad type for percent value!");
00179
00180
00181 llvm::Value *elt;
00182 elt = builder.CreateStructGEP(percent, 4);
00183
00184
00185 elt = builder.CreateLoad(elt);
00186
00187
00188 elt = builder.CreateGEP(elt,
00189 llvm::ConstantInt::get(llvm::Type::Int32Ty, ID - 1));
00190
00191
00192 return builder.CreateLoad(elt);
00193 }
00194
00195 unsigned CommaRT::getSignatureOffset(Domoid *domoid, SignatureType *target)
00196 {
00197 typedef SignatureSet::iterator iterator;
00198 SignatureSet &SS = domoid->getSignatureSet();
00199
00200 unsigned offset = 0;
00201 for (iterator iter = SS.begin(); iter != SS.end(); ++iter) {
00202 if (target->equals(*iter))
00203 return offset;
00204 offset++;
00205 }
00206 assert(false && "Could not find target signature!");
00207 return 0;
00208 }
00209
00210 llvm::Value *CommaRT::genAbstractCall(llvm::IRBuilder<> &builder,
00211 llvm::Value *percent,
00212 const SubroutineDecl *srDecl,
00213 const std::vector<llvm::Value *> &args) const
00214 {
00215 const AbstractDomainDecl *param;
00216 const FunctorDecl *context;
00217 SignatureType *target;
00218
00219 param = cast<AbstractDomainDecl>(srDecl->getDeclRegion());
00220 context = cast<FunctorDecl>(param->getDeclRegion());
00221 target = param->getSignatureType();
00222
00223 unsigned paramIdx = context->getFormalIndex(param);
00224
00225
00226
00227
00228 unsigned sigIdx = ExportMap::getSignatureOffset(srDecl);
00229 unsigned exportIdx = EM->getLocalIndex(srDecl);
00230
00231
00232 llvm::Value *view = DInstance->loadParam(builder, percent, paramIdx);
00233
00234
00235
00236 llvm::Value *viewIndexAdr = builder.CreateStructGEP(view, 1);
00237 llvm::Value *viewIndex = builder.CreateLoad(viewIndexAdr);
00238 llvm::Value *sigIndex =
00239 builder.CreateAdd(viewIndex,
00240 llvm::ConstantInt::get(llvm::Type::Int64Ty, sigIdx));
00241
00242
00243
00244
00245 llvm::Value *abstractInstance = DView->loadInstance(builder, view);
00246 llvm::Value *abstractInfo = DInstance->loadInfo(builder, abstractInstance);
00247 llvm::Value *sigOffset = DInfo->indexSigOffset(builder, abstractInfo, sigIndex);
00248 llvm::Value *exportOffset =
00249 builder.CreateAdd(sigOffset,
00250 llvm::ConstantInt::get(llvm::Type::Int64Ty, exportIdx));
00251 llvm::Value *exportFn = DInfo->loadExportFn(builder, abstractInfo, exportOffset);
00252
00253
00254 CodeGenTypes &CGT = CG.getTypeGenerator();
00255 const llvm::FunctionType *funcTy = CGT.lowerType(srDecl->getType());
00256 llvm::PointerType *funcPtrTy = CG.getPointerType(funcTy);
00257 llvm::Value *func = builder.CreateBitCast(exportFn, funcPtrTy);
00258
00259
00260 std::vector<llvm::Value *> arguments;
00261 arguments.push_back(abstractInstance);
00262 arguments.insert(arguments.end(), args.begin(), args.end());
00263 return builder.CreateCall(func, arguments.begin(), arguments.end());
00264 }