-
Notifications
You must be signed in to change notification settings - Fork 12.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][sparse_tensor] Migrate SparseIterationToScf.cpp
to dialect conversion
#121054
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesUse the regular dialect conversion driver instead of the 1:N dialect conversion driver. The 1:N dialect conversion driver will be removed soon. Full diff: https://github.com/llvm/llvm-project/pull/121054.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index e8a40b1e033dd5..7ff148dffbb1b1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -7,11 +7,17 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+ assert(values.size() == 1 && "expected single value");
+ return values.front();
+}
+
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
SmallVectorImpl<Type> &fields) {
// Position and coordinate buffer in the sparse structure.
@@ -54,14 +60,17 @@ static ValueRange
genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
Value loopCrd,
ArrayRef<std::unique_ptr<SparseIterator>> iters,
- ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
- if (subCases.empty())
+ ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
+ ArrayRef<Value> userReduc) {
+ if (newBlocks.empty())
return userReduc;
// The current branch that we are handling.
- Region *b = subCases.front();
+ Block *newBlock = newBlocks.front();
+ Block *oldBlock = oldBlocks.front();
Value casePred = constantI1(rewriter, loc, true);
- I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
+ I64BitSet caseBits =
+ op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber());
for (unsigned i : caseBits.bits()) {
SparseIterator *it = iters[i].get();
Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
@@ -80,16 +89,20 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
for (unsigned idx : caseBits.bits())
llvm::append_range(blockArgs, iters[idx]->getCursor());
+ // Map the old block arguments, because the dialect conversion driver does
+ // not immediately perform SSA value replacements. This function is still
+ // seeing the old uses.
IRMapping mapping;
- for (auto [from, to] :
- llvm::zip_equal(b->front().getArguments(), blockArgs)) {
+ for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) {
mapping.map(from, to);
}
// Clone the region, we can not erase the region now because the same region
// might be a subcase for multiple lattice point.
- rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
+ rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(),
ifOp.getThenRegion().begin(), mapping);
+ // Remove the block arguments, they were already replaced via `mapping`.
+ ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());
// replace sparse_tensor::YieldOp -> scf::YieldOp
auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
@@ -101,7 +114,8 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
// Generates remaining case recursively.
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
- subCases.drop_front(), userReduc);
+ newBlocks.drop_front(),
+ oldBlocks.drop_front(), userReduc);
if (!res.empty())
rewriter.create<scf::YieldOp>(loc, res);
@@ -119,15 +133,13 @@ static ValueRange genLoopWithIterator(
if (it->iteratableByFor()) {
auto [lo, hi] = it->genForCond(rewriter, loc);
Value step = constantIndex(rewriter, loc, 1);
- scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
+ scf::ForOp forOp = rewriter.create<scf::ForOp>(
+ loc, lo, hi, step, reduc,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
+ // Empty builder function to ensure that no terminator is created.
+ });
{
OpBuilder::InsertionGuard guard(rewriter);
- // Erase the implicit yield operation created by ForOp when there is no
- // yielding values.
- if (!forOp.getBody()->empty())
- rewriter.eraseOp(&forOp.getBody()->front());
- assert(forOp.getBody()->empty());
-
it->linkNewScope(forOp.getInductionVar());
rewriter.setInsertionPointToStart(forOp.getBody());
SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
@@ -178,46 +190,47 @@ namespace {
/// Sparse codegen rule for number of entries operator.
class ExtractIterSpaceConverter
- : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
+ : public OpConversionPattern<ExtractIterSpaceOp> {
public:
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
// Construct the iteration space.
- SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
+ SparseIterationSpace space(loc, rewriter,
+ getSingleValue(adaptor.getTensor()), 0,
op.getLvlRange(), adaptor.getParentIter());
SmallVector<Value> result = space.toValues();
- rewriter.replaceOp(op, result, resultMapping);
+ rewriter.replaceOpWithMultiple(op, {result});
return success();
}
};
/// Sparse codegen rule for number of entries operator.
-class ExtractValOpConverter : public OneToNOpConversionPattern<ExtractValOp> {
+class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
public:
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ExtractValOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value pos = adaptor.getIterator().back();
- Value valBuf = rewriter.create<ToValuesOp>(loc, op.getTensor());
+ Value valBuf =
+ rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
return success();
}
};
-class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
+class SparseIterateOpConverter : public OpConversionPattern<IterateOp> {
public:
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(IterateOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
if (!op.getCrdUsedLvls().empty())
return rewriter.notifyMatchFailure(
op, "non-empty coordinates list not implemented.");
@@ -235,14 +248,15 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
llvm::append_range(ivs, inits);
// Type conversion on iterate op block.
- OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+ unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
+ TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
if (failed(typeConverter->convertSignatureArgs(
- op.getBody()->getArgumentTypes(), blockTypeMapping)))
+ op.getBody()->getArgumentTypes(), signatureConversion)))
return rewriter.notifyMatchFailure(
op, "failed to convert iterate region argurment types");
- rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
- Block *block = op.getBody();
+ Block *block = rewriter.applySignatureConversion(
+ op.getBody(), signatureConversion, getTypeConverter());
ValueRange ret = genLoopWithIterator(
rewriter, loc, it.get(), ivs,
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
@@ -263,19 +277,28 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
return result;
});
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
- rewriter.replaceOp(op, ret, resultMapping);
+ /*
+ SmallVector<ValueRange> repl;
+ for (unsigned i = 0; i < numOrigArgs; ++i) {
+ auto mapping = signatureConversion.getInputMapping(i);
+ assert(mapping.has_value());
+ llvm::errs() << "start = " << mapping->inputNo << ", num = " <<
+ mapping->size << "\n"; llvm::errs() << "range size = " << ret.size() <<
+ "\n"; repl.push_back(ret.slice(mapping->inputNo, mapping->size));
+ }
+ rewriter.replaceOpWithMultiple(op, repl);
+ */
+ rewriter.replaceOp(op, ret);
return success();
}
};
-class SparseCoIterateOpConverter
- : public OneToNOpConversionPattern<CoIterateOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
assert(op.getSpaceDim() == 1 && "Not implemented");
Location loc = op.getLoc();
@@ -299,18 +322,23 @@ class SparseCoIterateOpConverter
assert(!needUniv && "Not implemented");
(void)needUniv;
+ SmallVector<Block *> newBlocks;
+ DenseMap<Block *, Block *> newToOldBlockMap;
for (Region ®ion : op.getCaseRegions()) {
// Do a one-shot type conversion on all region blocks, since the same
// region might be used multiple time.
Block *block = ®ion.getBlocks().front();
- OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
+ TypeConverter::SignatureConversion blockTypeMapping(
+ block->getArgumentTypes().size());
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
blockTypeMapping))) {
return rewriter.notifyMatchFailure(
op, "failed to convert coiterate region argurment types");
}
- rewriter.applySignatureConversion(block, blockTypeMapping);
+ newBlocks.push_back(rewriter.applySignatureConversion(
+ block, blockTypeMapping, getTypeConverter()));
+ newToOldBlockMap[newBlocks.back()] = block;
}
SmallVector<SparseIterationSpace> spaces;
@@ -343,7 +371,7 @@ class SparseCoIterateOpConverter
// Generates a loop sequence, one loop per case.
for (auto [r, caseBits] :
- llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
+ llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
assert(caseBits.count() > 0 && "Complement space not implemented");
// Retrives a vector of pointers to the iterators used in the case.
@@ -359,11 +387,17 @@ class SparseCoIterateOpConverter
// The subcases are never empty, it must contains at least the current
// region itself.
// TODO: these cases should be sorted.
- SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
+ SmallVector<Region *> subCases =
+ op.getSubCasesOf(r->getParent()->getRegionNumber());
+ SmallVector<Block *> newBlocks, oldBlocks;
+ for (Region *r : subCases) {
+ newBlocks.push_back(&r->front());
+ oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
+ }
assert(!subCases.empty());
- ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
- iters, subCases, userReduc);
+ ValueRange res = genCoIterateBranchNest(
+ rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);
SmallVector<Value> nextIterYields(res);
// 2nd. foward the loop.
@@ -388,7 +422,7 @@ class SparseCoIterateOpConverter
// This is a simple iteration loop.
assert(caseBits.count() == 1);
- Block *block = &r.getBlocks().front();
+ Block *block = r;
ValueRange curResult = genLoopWithIterator(
rewriter, loc, validIters.front(), userReduc,
/*bodyBuilder=*/
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 1cac949b68c79d..153b9b170e5d34 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -172,11 +172,16 @@ struct LowerSparseIterationToSCFPass
ConversionTarget target(*ctx);
// The actual conversion.
- target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
+ target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
+ memref::MemRefDialect, scf::SCFDialect,
+ sparse_tensor::SparseTensorDialect>();
+ target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
+ IterateOp>();
+ target.addLegalOp<UnrealizedConversionCastOp>();
populateLowerSparseIterationToSCFPatterns(converter, patterns);
- if (failed(applyPartialOneToNConversion(getOperation(), converter,
- std::move(patterns))))
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
}
};
|
be9948f
to
8547cba
Compare
Note: I don't know what this code is doing in detail. There may be easier/simpler ways to migrate to the regular dialect conversion driver. |
Use the regular dialect conversion driver instead of the 1:N dialect conversion driver. The 1:N dialect conversion driver will be removed soon.