Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[mlir][sparse_tensor] Migrate SparseIterationToScf.cpp to dialect conversion #121054

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

matthias-springer
Copy link
Member

Use the regular dialect conversion driver instead of the 1:N dialect conversion driver. The 1:N dialect conversion driver will be removed soon.

@llvmbot
Copy link
Member

llvmbot commented Dec 24, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Use the regular dialect conversion driver instead of the 1:N dialect conversion driver. The 1:N dialect conversion driver will be removed soon.


Full diff: https://github.com/llvm/llvm-project/pull/121054.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+84-50)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (+8-3)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index e8a40b1e033dd5..7ff148dffbb1b1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -7,11 +7,17 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+  assert(values.size() == 1 && "expected single value");
+  return values.front();
+}
+
 static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
                              SmallVectorImpl<Type> &fields) {
   // Position and coordinate buffer in the sparse structure.
@@ -54,14 +60,17 @@ static ValueRange
 genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
                        Value loopCrd,
                        ArrayRef<std::unique_ptr<SparseIterator>> iters,
-                       ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
-  if (subCases.empty())
+                       ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
+                       ArrayRef<Value> userReduc) {
+  if (newBlocks.empty())
     return userReduc;
 
   // The current branch that we are handling.
-  Region *b = subCases.front();
+  Block *newBlock = newBlocks.front();
+  Block *oldBlock = oldBlocks.front();
   Value casePred = constantI1(rewriter, loc, true);
-  I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
+  I64BitSet caseBits =
+      op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber());
   for (unsigned i : caseBits.bits()) {
     SparseIterator *it = iters[i].get();
     Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
@@ -80,16 +89,20 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
   for (unsigned idx : caseBits.bits())
     llvm::append_range(blockArgs, iters[idx]->getCursor());
 
+  // Map the old block arguments, because the dialect conversion driver does
+  // not immediately perform SSA value replacements. This function is still
+  // seeing the old uses.
   IRMapping mapping;
-  for (auto [from, to] :
-       llvm::zip_equal(b->front().getArguments(), blockArgs)) {
+  for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) {
     mapping.map(from, to);
   }
 
   // Clone the region, we can not erase the region now because the same region
   // might be a subcase for multiple lattice point.
-  rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
+  rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(),
                              ifOp.getThenRegion().begin(), mapping);
+  // Remove the block arguments, they were already replaced via `mapping`.
+  ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());
 
   // replace sparse_tensor::YieldOp -> scf::YieldOp
   auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
@@ -101,7 +114,8 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
   // Generates remaining case recursively.
   rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
   ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
-                                          subCases.drop_front(), userReduc);
+                                          newBlocks.drop_front(),
+                                          oldBlocks.drop_front(), userReduc);
   if (!res.empty())
     rewriter.create<scf::YieldOp>(loc, res);
 
@@ -119,15 +133,13 @@ static ValueRange genLoopWithIterator(
   if (it->iteratableByFor()) {
     auto [lo, hi] = it->genForCond(rewriter, loc);
     Value step = constantIndex(rewriter, loc, 1);
-    scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
+    scf::ForOp forOp = rewriter.create<scf::ForOp>(
+        loc, lo, hi, step, reduc,
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
+          // Empty builder function to ensure that no terminator is created.
+        });
     {
       OpBuilder::InsertionGuard guard(rewriter);
-      // Erase the implicit yield operation created by ForOp when there is no
-      // yielding values.
-      if (!forOp.getBody()->empty())
-        rewriter.eraseOp(&forOp.getBody()->front());
-      assert(forOp.getBody()->empty());
-
       it->linkNewScope(forOp.getInductionVar());
       rewriter.setInsertionPointToStart(forOp.getBody());
       SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
@@ -178,46 +190,47 @@ namespace {
 
 /// Sparse codegen rule for number of entries operator.
 class ExtractIterSpaceConverter
-    : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
+    : public OpConversionPattern<ExtractIterSpaceOp> {
 public:
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
 
     // Construct the iteration space.
-    SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
+    SparseIterationSpace space(loc, rewriter,
+                               getSingleValue(adaptor.getTensor()), 0,
                                op.getLvlRange(), adaptor.getParentIter());
 
     SmallVector<Value> result = space.toValues();
-    rewriter.replaceOp(op, result, resultMapping);
+    rewriter.replaceOpWithMultiple(op, {result});
     return success();
   }
 };
 
 /// Sparse codegen rule for number of entries operator.
-class ExtractValOpConverter : public OneToNOpConversionPattern<ExtractValOp> {
+class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
 public:
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ExtractValOp op, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value pos = adaptor.getIterator().back();
-    Value valBuf = rewriter.create<ToValuesOp>(loc, op.getTensor());
+    Value valBuf =
+        rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
     return success();
   }
 };
 
-class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
+class SparseIterateOpConverter : public OpConversionPattern<IterateOp> {
 public:
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(IterateOp op, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     if (!op.getCrdUsedLvls().empty())
       return rewriter.notifyMatchFailure(
           op, "non-empty coordinates list not implemented.");
@@ -235,14 +248,15 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
       llvm::append_range(ivs, inits);
 
     // Type conversion on iterate op block.
-    OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+    unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
+    TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
     if (failed(typeConverter->convertSignatureArgs(
-            op.getBody()->getArgumentTypes(), blockTypeMapping)))
+            op.getBody()->getArgumentTypes(), signatureConversion)))
       return rewriter.notifyMatchFailure(
           op, "failed to convert iterate region argurment types");
-    rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
 
-    Block *block = op.getBody();
+    Block *block = rewriter.applySignatureConversion(
+        op.getBody(), signatureConversion, getTypeConverter());
     ValueRange ret = genLoopWithIterator(
         rewriter, loc, it.get(), ivs,
         [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
@@ -263,19 +277,28 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
           return result;
         });
 
-    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-    rewriter.replaceOp(op, ret, resultMapping);
+    /*
+        SmallVector<ValueRange> repl;
+        for (unsigned i = 0; i < numOrigArgs; ++i) {
+          auto mapping = signatureConversion.getInputMapping(i);
+          assert(mapping.has_value());
+          llvm::errs() << "start = " << mapping->inputNo << ", num = " <<
+       mapping->size << "\n"; llvm::errs() << "range size = " << ret.size() <<
+       "\n"; repl.push_back(ret.slice(mapping->inputNo, mapping->size));
+        }
+        rewriter.replaceOpWithMultiple(op, repl);
+      */
+    rewriter.replaceOp(op, ret);
     return success();
   }
 };
 
-class SparseCoIterateOpConverter
-    : public OneToNOpConversionPattern<CoIterateOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     assert(op.getSpaceDim() == 1 && "Not implemented");
     Location loc = op.getLoc();
 
@@ -299,18 +322,23 @@ class SparseCoIterateOpConverter
     assert(!needUniv && "Not implemented");
     (void)needUniv;
 
+    SmallVector<Block *> newBlocks;
+    DenseMap<Block *, Block *> newToOldBlockMap;
     for (Region &region : op.getCaseRegions()) {
       // Do a one-shot type conversion on all region blocks, since the same
       // region might be used multiple time.
       Block *block = &region.getBlocks().front();
-      OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
+      TypeConverter::SignatureConversion blockTypeMapping(
+          block->getArgumentTypes().size());
       if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
                                                      blockTypeMapping))) {
         return rewriter.notifyMatchFailure(
             op, "failed to convert coiterate region argurment types");
       }
 
-      rewriter.applySignatureConversion(block, blockTypeMapping);
+      newBlocks.push_back(rewriter.applySignatureConversion(
+          block, blockTypeMapping, getTypeConverter()));
+      newToOldBlockMap[newBlocks.back()] = block;
     }
 
     SmallVector<SparseIterationSpace> spaces;
@@ -343,7 +371,7 @@ class SparseCoIterateOpConverter
 
     // Generates a loop sequence, one loop per case.
     for (auto [r, caseBits] :
-         llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
+         llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
       assert(caseBits.count() > 0 && "Complement space not implemented");
 
       // Retrives a vector of pointers to the iterators used in the case.
@@ -359,11 +387,17 @@ class SparseCoIterateOpConverter
         // The subcases are never empty, it must contains at least the current
         // region itself.
         // TODO: these cases should be sorted.
-        SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
+        SmallVector<Region *> subCases =
+            op.getSubCasesOf(r->getParent()->getRegionNumber());
+        SmallVector<Block *> newBlocks, oldBlocks;
+        for (Region *r : subCases) {
+          newBlocks.push_back(&r->front());
+          oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
+        }
         assert(!subCases.empty());
 
-        ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
-                                                iters, subCases, userReduc);
+        ValueRange res = genCoIterateBranchNest(
+            rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);
 
         SmallVector<Value> nextIterYields(res);
         // 2nd. foward the loop.
@@ -388,7 +422,7 @@ class SparseCoIterateOpConverter
         // This is a simple iteration loop.
         assert(caseBits.count() == 1);
 
-        Block *block = &r.getBlocks().front();
+        Block *block = r;
         ValueRange curResult = genLoopWithIterator(
             rewriter, loc, validIters.front(), userReduc,
             /*bodyBuilder=*/
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 1cac949b68c79d..153b9b170e5d34 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -172,11 +172,16 @@ struct LowerSparseIterationToSCFPass
     ConversionTarget target(*ctx);
 
     // The actual conversion.
-    target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
+    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
+                           memref::MemRefDialect, scf::SCFDialect,
+                           sparse_tensor::SparseTensorDialect>();
+    target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
+                        IterateOp>();
+    target.addLegalOp<UnrealizedConversionCastOp>();
     populateLowerSparseIterationToSCFPatterns(converter, patterns);
 
-    if (failed(applyPartialOneToNConversion(getOperation(), converter,
-                                            std::move(patterns))))
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
       signalPassFailure();
   }
 };

@matthias-springer matthias-springer force-pushed the users/matthias-springer/sparse_iteration branch from be9948f to 8547cba Compare December 24, 2024 14:07
@matthias-springer
Copy link
Member Author

Note: I don't know what this code is doing in detail. There may be easier/simpler ways to migrate to the regular dialect conversion driver.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants