Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR] Add cir.global_addr attribute #1248

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,40 @@ def GlobalViewAttr : CIR_Attr<"GlobalView", "global_view", [TypedAttrInterface]>
}];
}

//===----------------------------------------------------------------------===//
// GlobalAddrAttr
//===----------------------------------------------------------------------===//

def GlobalAddrAttr
: CIR_Attr<"GlobalAddr", "global_addr", [TypedAttrInterface]> {
let summary = "Get access to a constant integral address of a global";
let description = [{
Get constant address of a global `symbol` as an integer value. The type of
the `#cir.global_addr` attribute must be an integer type.

Example:

```
cir.global external @str = @"hello": !cir.ptr<i8>
cir.global external @str_addr = #cir.global_addr<@str> : !s64i
```
}];

let parameters = (ins AttributeSelfTypeParameter<"", "cir::IntType">:$type,
"mlir::FlatSymbolRefAttr":$symbol);

let builders = [
AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
"mlir::FlatSymbolRefAttr":$symbol), [{
return $_get(type.getContext(), mlir::cast<cir::IntType>(type), symbol);
}]>
];

let assemblyFormat = [{
`<` $symbol `>`
}];
}

//===----------------------------------------------------------------------===//
// TypeInfoAttr
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
}

if (mlir::isa<cir::GlobalViewAttr>(attrType) ||
mlir::isa<cir::GlobalAddrAttr>(attrType) ||
mlir::isa<cir::TypeInfoAttr>(attrType) ||
mlir::isa<cir::ConstArrayAttr>(attrType) ||
mlir::isa<cir::ConstVectorAttr>(attrType) ||
Expand Down
135 changes: 87 additions & 48 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,36 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstVectorAttr constVec,
mlirValues));
}

static void lookupGlobalSymbolInfo(mlir::ModuleOp module,
mlir::FlatSymbolRefAttr symbolRef,
mlir::Type *sourceType,
unsigned *sourceAddrSpace,
llvm::StringRef *symName,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter &converter) {
auto *sourceSymbol = mlir::SymbolTable::lookupSymbolIn(module, symbolRef);
if (auto llvmSymbol = dyn_cast<mlir::LLVM::GlobalOp>(sourceSymbol)) {
*sourceType = llvmSymbol.getType();
*symName = llvmSymbol.getSymName();
*sourceAddrSpace = llvmSymbol.getAddrSpace();
} else if (auto cirSymbol = dyn_cast<cir::GlobalOp>(sourceSymbol)) {
*sourceType = converter.convertType(cirSymbol.getSymType());
*symName = cirSymbol.getSymName();
*sourceAddrSpace =
getGlobalOpTargetAddrSpace(rewriter, &converter, cirSymbol);
} else if (auto llvmFun = dyn_cast<mlir::LLVM::LLVMFuncOp>(sourceSymbol)) {
*sourceType = llvmFun.getFunctionType();
*symName = llvmFun.getSymName();
*sourceAddrSpace = 0;
} else if (auto fun = dyn_cast<cir::FuncOp>(sourceSymbol)) {
*sourceType = converter.convertType(fun.getFunctionType());
*symName = fun.getSymName();
*sourceAddrSpace = 0;
} else {
llvm_unreachable("Unexpected GlobalOp type");
}
}

// GlobalViewAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr,
Expand All @@ -575,28 +605,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr,
mlir::Type sourceType;
unsigned sourceAddrSpace = 0;
llvm::StringRef symName;
auto *sourceSymbol =
mlir::SymbolTable::lookupSymbolIn(module, globalAttr.getSymbol());
if (auto llvmSymbol = dyn_cast<mlir::LLVM::GlobalOp>(sourceSymbol)) {
sourceType = llvmSymbol.getType();
symName = llvmSymbol.getSymName();
sourceAddrSpace = llvmSymbol.getAddrSpace();
} else if (auto cirSymbol = dyn_cast<cir::GlobalOp>(sourceSymbol)) {
sourceType = converter->convertType(cirSymbol.getSymType());
symName = cirSymbol.getSymName();
sourceAddrSpace =
getGlobalOpTargetAddrSpace(rewriter, converter, cirSymbol);
} else if (auto llvmFun = dyn_cast<mlir::LLVM::LLVMFuncOp>(sourceSymbol)) {
sourceType = llvmFun.getFunctionType();
symName = llvmFun.getSymName();
sourceAddrSpace = 0;
} else if (auto fun = dyn_cast<cir::FuncOp>(sourceSymbol)) {
sourceType = converter->convertType(fun.getFunctionType());
symName = fun.getSymName();
sourceAddrSpace = 0;
} else {
llvm_unreachable("Unexpected GlobalOp type");
}
lookupGlobalSymbolInfo(module, globalAttr.getSymbol(), &sourceType,
&sourceAddrSpace, &symName, rewriter, *converter);

auto loc = parentOp->getLoc();
mlir::Value addrOp = rewriter.create<mlir::LLVM::AddressOfOp>(
Expand Down Expand Up @@ -637,36 +647,53 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr,
addrOp);
}

// GlobalViewAddr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalAddrAttr globalAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
auto module = parentOp->getParentOfType<mlir::ModuleOp>();
mlir::Type sourceType;
unsigned sourceAddrSpace = 0;
llvm::StringRef symName;
lookupGlobalSymbolInfo(module, globalAttr.getSymbol(), &sourceType,
&sourceAddrSpace, &symName, rewriter, *converter);

auto loc = parentOp->getLoc();
auto addrTy =
mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
mlir::Value addrOp =
rewriter.create<mlir::LLVM::AddressOfOp>(loc, addrTy, symName);

auto llvmDstTy = converter->convertType(globalAttr.getType());
return rewriter.create<mlir::LLVM::PtrToIntOp>(parentOp->getLoc(), llvmDstTy,
addrOp);
}

/// Switches on the type of attribute and calls the appropriate conversion.
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
if (const auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr))
return lowerCirAttrAsValue(parentOp, intAttr, rewriter, converter);
if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(attr))
return lowerCirAttrAsValue(parentOp, fltAttr, rewriter, converter);
if (const auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(attr))
return lowerCirAttrAsValue(parentOp, ptrAttr, rewriter, converter);
if (const auto constStruct = mlir::dyn_cast<cir::ConstStructAttr>(attr))
return lowerCirAttrAsValue(parentOp, constStruct, rewriter, converter);
if (const auto constArr = mlir::dyn_cast<cir::ConstArrayAttr>(attr))
return lowerCirAttrAsValue(parentOp, constArr, rewriter, converter);
if (const auto constVec = mlir::dyn_cast<cir::ConstVectorAttr>(attr))
return lowerCirAttrAsValue(parentOp, constVec, rewriter, converter);
if (const auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr))
return lowerCirAttrAsValue(parentOp, boolAttr, rewriter, converter);
if (const auto zeroAttr = mlir::dyn_cast<cir::ZeroAttr>(attr))
return lowerCirAttrAsValue(parentOp, zeroAttr, rewriter, converter);
if (const auto undefAttr = mlir::dyn_cast<cir::UndefAttr>(attr))
return lowerCirAttrAsValue(parentOp, undefAttr, rewriter, converter);
if (const auto poisonAttr = mlir::dyn_cast<cir::PoisonAttr>(attr))
return lowerCirAttrAsValue(parentOp, poisonAttr, rewriter, converter);
if (const auto globalAttr = mlir::dyn_cast<cir::GlobalViewAttr>(attr))
return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter);
if (const auto vtableAttr = mlir::dyn_cast<cir::VTableAttr>(attr))
return lowerCirAttrAsValue(parentOp, vtableAttr, rewriter, converter);
if (const auto typeinfoAttr = mlir::dyn_cast<cir::TypeInfoAttr>(attr))
return lowerCirAttrAsValue(parentOp, typeinfoAttr, rewriter, converter);
#define LOWER_CIR_ATTR(type) \
if (const auto castedAttr = mlir::dyn_cast<type>(attr)) \
return lowerCirAttrAsValue(parentOp, castedAttr, rewriter, converter);

LOWER_CIR_ATTR(cir::BoolAttr)
LOWER_CIR_ATTR(cir::ConstArrayAttr)
LOWER_CIR_ATTR(cir::ConstPtrAttr)
LOWER_CIR_ATTR(cir::ConstStructAttr)
LOWER_CIR_ATTR(cir::ConstVectorAttr)
LOWER_CIR_ATTR(cir::FPAttr)
LOWER_CIR_ATTR(cir::GlobalAddrAttr)
LOWER_CIR_ATTR(cir::GlobalViewAttr)
LOWER_CIR_ATTR(cir::IntAttr)
LOWER_CIR_ATTR(cir::PoisonAttr)
LOWER_CIR_ATTR(cir::TypeInfoAttr)
LOWER_CIR_ATTR(cir::UndefAttr)
LOWER_CIR_ATTR(cir::VTableAttr)
LOWER_CIR_ATTR(cir::ZeroAttr)

#undef LOWER_CIR_ATTR

llvm_unreachable("unhandled attribute type");
}
Expand Down Expand Up @@ -1663,6 +1690,13 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()),
value);
} else if (mlir::isa<cir::IntType>(op.getType())) {
// Lower GlobalAddrAttr to llvm.mlir.addressof + llvm.mlir.ptrtoint
if (auto ga = mlir::dyn_cast<cir::GlobalAddrAttr>(op.getValue())) {
auto newOp = lowerCirAttrAsValue(op, ga, rewriter, getTypeConverter());
rewriter.replaceOp(op, newOp);
return mlir::success();
}

attr = rewriter.getIntegerAttr(
typeConverter->convertType(op.getType()),
mlir::cast<cir::IntAttr>(op.getValue()).getValue());
Expand Down Expand Up @@ -2348,6 +2382,11 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
rewriter.create<mlir::LLVM::ReturnOp>(
loc, lowerCirAttrAsValue(op, attr, rewriter, typeConverter));
return mlir::success();
} else if (auto attr = mlir::dyn_cast<cir::GlobalAddrAttr>(init.value())) {
setupRegionInitializedLLVMGlobalOp(op, rewriter);
rewriter.create<mlir::LLVM::ReturnOp>(
loc, lowerCirAttrAsValue(op, attr, rewriter, typeConverter));
return mlir::success();
} else if (const auto vtableAttr =
mlir::dyn_cast<cir::VTableAttr>(init.value())) {
setupRegionInitializedLLVMGlobalOp(op, rewriter);
Expand Down
20 changes: 20 additions & 0 deletions clang/test/CIR/Lowering/globals.cir
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,22 @@ module {
cir.global external @alpha = #cir.const_array<[#cir.int<97> : !s8i, #cir.int<98> : !s8i, #cir.int<99> : !s8i, #cir.int<0> : !s8i]> : !cir.array<!s8i x 4>
cir.global "private" constant internal @".str" = #cir.const_array<"example\00" : !cir.array<!s8i x 8>> : !cir.array<!s8i x 8> {alignment = 1 : i64}
cir.global external @s = #cir.global_view<@".str"> : !cir.ptr<!s8i>
cir.global external @s_addr = #cir.global_addr<@".str"> : !u64i
// MLIR: llvm.mlir.global internal constant @".str"("example\00")
// MLIR-SAME: {addr_space = 0 : i32, alignment = 1 : i64}
// MLIR: llvm.mlir.global external @s() {addr_space = 0 : i32} : !llvm.ptr {
// MLIR: %0 = llvm.mlir.addressof @".str" : !llvm.ptr
// MLIR: %1 = llvm.bitcast %0 : !llvm.ptr to !llvm.ptr
// MLIR: llvm.return %1 : !llvm.ptr
// MLIR: }
// MLIR: llvm.mlir.global external @s_addr() {addr_space = 0 : i32} : i64 {
// MLIR: %0 = llvm.mlir.addressof @".str" : !llvm.ptr
// MLIR: %1 = llvm.ptrtoint %0 : !llvm.ptr to i64
// MLIR: llvm.return %1 : i64
// MLIR: }
// LLVM: @.str = internal constant [8 x i8] c"example\00"
// LLVM: @s = global ptr @.str
// LLVM: @s_addr = global i64 ptrtoint (ptr @.str to i64)
cir.global external @aPtr = #cir.global_view<@a> : !cir.ptr<!s32i>
// MLIR: llvm.mlir.global external @aPtr() {addr_space = 0 : i32} : !llvm.ptr {
// MLIR: %0 = llvm.mlir.addressof @a : !llvm.ptr
Expand Down Expand Up @@ -198,4 +205,17 @@ module {
}
// MLIR: %0 = llvm.mlir.addressof @zero_array

cir.func @const_global_addr() -> !u64i {
%0 = cir.const #cir.global_addr<@".str"> : !u64i
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like this is equivalent to a global view with zero index? Why can't this be modeled ash cir.const #cir.global_view<@".str"> ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cir.global_addr is of integer types, and I'm not sure if it's appropriate to make cir.global_view an integer type. If it's OK I could just update cir.global_view instead of introducing a new one.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My take is that we want here is a constant pointer coming out of cir.const, followed up by a cast to integer?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My take is that we want here is a constant pointer coming out of cir.const, followed up by a cast to integer?

In this case, yes. But I'm actually proposing this attribute for the following scenario where you don't have space for additional operations:

%0 = cir.const #cir.const_struct {
  #cir.global_addr @aaa,
  #cir.global_addr @bbb
}

And I'm looking at this scenario because recently I'm working on the LLVM lowering of member function pointers. Upon ABI lowering, a member function pointer is lowered to a struct with two fields of type ptrdiff_t. When the member function pointer represents a non-virtual member function, the first field stores the address of the target function as an integer. Thus to represent constant member function pointers I need an attribute that works like #cir.global_addr.

cir.return %0 : !u64i
}
// MLIR-LABEL: @const_global_addr
// MLIR-NEXT: %0 = llvm.mlir.addressof @".str" : !llvm.ptr
// MLIR-NEXT: %1 = llvm.ptrtoint %0 : !llvm.ptr to i64
// MLIR-NEXT: llvm.return %1 : i64
// MLIR-NEXT: }
// LLVM-LABEL: @const_global_addr
// LLVM-NEXT: ret i64 ptrtoint (ptr @.str to i64)
// LLVM-NEXT: }

}
Loading