LLVM API Documentation

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
SITypeRewriter.cpp
Go to the documentation of this file.
1 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass performs the following type substitution on all
12 /// non-compute shaders:
13 ///
14 /// v16i8 => i128
15 /// - v16i8 is used for constant memory resource descriptors. This type is
16 /// legal for some compute APIs, and we don't want to declare it as legal
17 /// in the backend, because we want the legalizer to expand all v16i8
18 /// operations.
19 /// v1* => *
20 /// - Having v1* types complicates the legalizer and we can easily replace
21 /// them with the element type.
22 //===----------------------------------------------------------------------===//
23 
24 #include "AMDGPU.h"
25 
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/InstVisitor.h"
28 
29 using namespace llvm;
30 
31 namespace {
32 
33 class SITypeRewriter : public FunctionPass,
34  public InstVisitor<SITypeRewriter> {
35 
36  static char ID;
37  Module *Mod;
38  Type *v16i8;
39  Type *i128;
40 
41 public:
42  SITypeRewriter() : FunctionPass(ID) { }
43  virtual bool doInitialization(Module &M);
44  virtual bool runOnFunction(Function &F);
45  virtual const char *getPassName() const {
46  return "SI Type Rewriter";
47  }
48  void visitLoadInst(LoadInst &I);
49  void visitCallInst(CallInst &I);
50  void visitBitCast(BitCastInst &I);
51 };
52 
53 } // End anonymous namespace
54 
// Storage for the pass identifier; it is handed to the FunctionPass base
// class by SITypeRewriter's constructor.
55 char SITypeRewriter::ID = 0;
56 
57 bool SITypeRewriter::doInitialization(Module &M) {
58  Mod = &M;
59  v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
60  i128 = Type::getIntNTy(M.getContext(), 128);
61  return false;
62 }
63 
64 bool SITypeRewriter::runOnFunction(Function &F) {
65  AttributeSet Set = F.getAttributes();
66  Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType");
67 
68  unsigned ShaderType = ShaderType::COMPUTE;
69  if (A.isStringAttribute()) {
70  StringRef Str = A.getValueAsString();
71  Str.getAsInteger(0, ShaderType);
72  }
73  if (ShaderType != ShaderType::COMPUTE) {
74  visit(F);
75  }
76 
77  visit(F);
78 
79  return false;
80 }
81 
82 void SITypeRewriter::visitLoadInst(LoadInst &I) {
83  Value *Ptr = I.getPointerOperand();
84  Type *PtrTy = Ptr->getType();
85  Type *ElemTy = PtrTy->getPointerElementType();
86  IRBuilder<> Builder(&I);
87  if (ElemTy == v16i8) {
88  Value *BitCast = Builder.CreateBitCast(Ptr, Type::getIntNPtrTy(I.getContext(), 128, 2));
89  LoadInst *Load = Builder.CreateLoad(BitCast);
92  for (unsigned i = 0, e = MD.size(); i != e; ++i) {
93  Load->setMetadata(MD[i].first, MD[i].second);
94  }
95  Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
96  I.replaceAllUsesWith(BitCastLoad);
97  I.eraseFromParent();
98  }
99 }
100 
// Rewrites a call whose arguments include <16 x i8> values or single-element
// vectors. Each <16 x i8> argument is bitcast to i128, and the element of a
// single-element vector argument is passed directly. The call is redirected
// to a function whose name is adjusted (".i128" appended per v16i8 argument,
// "v1"+TypeName replaced by TypeName) and whose parameter types match the new
// arguments. Calls with no such arguments are left unchanged.
// NOTE(review): this listing omits original lines 103-104 (presumably the
// declarations of Args and Types), 118 (the right-hand side of the
// element-type comparison; TypeName suggests i32) and 121 (presumably the
// definition of Def). The code below is incomplete as shown.
101 void SITypeRewriter::visitCallInst(CallInst &I) {
102  IRBuilder<> Builder(&I);
105  bool NeedToReplace = false;
  // NOTE(review): F is dereferenced below without a null check. If
  // getCalledFunction() can return null (e.g. for indirect calls), this
  // would crash. Confirm.
106  Function *F = I.getCalledFunction();
107  std::string Name = F->getName().str();
108  for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
109  Value *Arg = I.getArgOperand(i);
110  if (Arg->getType() == v16i8) {
111  Args.push_back(Builder.CreateBitCast(Arg, i128));
112  Types.push_back(i128);
113  NeedToReplace = true;
114  Name = Name + ".i128";
115  } else if (Arg->getType()->isVectorTy() &&
116  Arg->getType()->getVectorNumElements() == 1 &&
117  Arg->getType()->getVectorElementType() ==
119  Type *ElementTy = Arg->getType()->getVectorElementType();
120  std::string TypeName = "i32";
  // Def is expected to be the instruction that built the one-element vector.
  // Its operand 1 is passed instead of the vector. Assumes an insertelement
  // (operand 1 = inserted scalar); confirm against the missing line 121.
122  assert(Def);
123  Args.push_back(Def->getOperand(1));
124  Types.push_back(ElementTy);
125  std::string VecTypeName = "v1" + TypeName;
  // NOTE(review): if VecTypeName does not occur in Name, find() returns npos
  // and std::string::replace throws std::out_of_range. Confirm that callee
  // names always contain it.
126  Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
127  NeedToReplace = true;
128  } else {
  // Any other argument is passed through unchanged.
129  Args.push_back(Arg);
130  Types.push_back(Arg->getType());
131  }
132  }
133 
134  if (!NeedToReplace) {
135  return;
136  }
  // Reuse an existing function with the rewritten name. Otherwise declare one
  // with the callee's return type and attributes and the new parameter types.
137  Function *NewF = Mod->getFunction(Name);
138  if (!NewF) {
139  NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
140  NewF->setAttributes(F->getAttributes());
141  }
142  I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
143  I.eraseFromParent();
144 }
145 
146 void SITypeRewriter::visitBitCast(BitCastInst &I) {
147  IRBuilder<> Builder(&I);
148  if (I.getDestTy() != i128) {
149  return;
150  }
151 
152  if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
153  if (Op->getSrcTy() == i128) {
154  I.replaceAllUsesWith(Op->getOperand(0));
155  I.eraseFromParent();
156  }
157  }
158 }
159 
// Factory for the pass: returns a newly allocated SITypeRewriter instance.
// NOTE(review): this function's signature line (original line 160) is
// missing from this listing. Presumably it is
// `FunctionPass *llvm::createSITypeRewriter() {`. Confirm.
161  return new SITypeRewriter();
162 }
Base class for instruction visitors.
Definition: InstVisitor.h:81
The main container class for the LLVM Intermediate Representation.
Definition: Module.h:112
enable_if_c<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:266
Externally visible function.
Definition: GlobalValue.h:34
std::string str() const
str - Get the contents as an std::string.
Definition: StringRef.h:181
Type * getReturnType() const
Definition: Function.cpp:179
F(f)
Type * getPointerElementType() const
Definition: Type.h:373
StringRef getName() const
Definition: Value.cpp:167
void getAllMetadataOtherThanDebugLoc(SmallVectorImpl< std::pair< unsigned, MDNode * > > &MDs) const
Definition: Instruction.h:162
unsigned getNumArgOperands() const
This provides a uniform API for creating instructions and inserting them into a basic block: either at the end of a BasicBlock, or at a specific iterator location in a block.
Definition: IRBuilder.h:421
ID
LLVM Calling Convention Representation.
Definition: CallingConv.h:26
Type * getVectorElementType() const
Definition: Type.h:371
FunctionPass * createSITypeRewriter()
This class represents a no-op cast from one type to another.
static FunctionType * get(Type *Result, ArrayRef< Type * > Params, bool isVarArg)
Definition: Type.cpp:361
void replaceAllUsesWith(Value *V)
Definition: Value.cpp:303
bool isVectorTy() const
Definition: Type.h:229
Value * getOperand(unsigned i) const
Definition: User.h:88
Value * getPointerOperand()
Definition: Instructions.h:223
enable_if_c< std::numeric_limits< T >::is_signed, bool >::type getAsInteger(unsigned Radix, T &Result) const
Definition: StringRef.h:337
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:517
void setMetadata(unsigned KindID, MDNode *Node)
Definition: Metadata.cpp:589
unsigned getVectorNumElements() const
Definition: Type.cpp:214
Type * getType() const
Definition: Value.h:111
Type * getDestTy() const
Return the destination type, as a convenience.
Definition: InstrTypes.h:610
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition: Type.cpp:244
Function * getCalledFunction() const
AttributeSet getAttributes() const
Return the attribute list for this Function.
Definition: Function.h:170
Value * getArgOperand(unsigned i) const
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:241
#define I(x, y, z)
Definition: MD5.cpp:54
bool isStringAttribute() const
Return true if the attribute is a string (target-dependent) attribute.
Definition: Attributes.cpp:102
void setAttributes(AttributeSet attrs)
Set the attribute list for this Function.
Definition: Function.h:173
Attribute getAttribute(unsigned Index, Attribute::AttrKind Kind) const
Return the attribute object that exists at the given index.
Definition: Attributes.cpp:847
static PointerType * getIntNPtrTy(LLVMContext &C, unsigned N, unsigned AS=0)
Definition: Type.cpp:276
StringRef getValueAsString() const
Return the attribute's value as a string. This requires the attribute to be a string attribute...
Definition: Attributes.cpp:127
LLVM Value Representation.
Definition: Value.h:66
static VectorType * get(Type *ElementType, unsigned NumElements)
Definition: Type.cpp:706
static IntegerType * getInt8Ty(LLVMContext &C)
Definition: Type.cpp:239
LLVMContext & getContext() const
Definition: Module.h:249
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, const Twine &N="", Module *M=0)
Definition: Function.h:128