Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Prevent invalid IR from being passed outside of RemoveDeadValues #121079

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

parsifal-47
Copy link
Contributor

This is a follow-up for #119110 and a fix for #118450

RemoveDeadValues used to delete Values and analyzing the IR at the same time, because of that, isMemoryEffectFree got invalid IR with half-deleted linalg.generic operation. This PR separates analysis and cleanup to prevent such situation.

Thank you!

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Dec 25, 2024
@llvmbot
Copy link
Member

llvmbot commented Dec 25, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Renat Idrisov (parsifal-47)

Changes

This is a follow-up for #119110 and a fix for #118450

RemoveDeadValues used to delete Values and analyzing the IR at the same time, because of that, isMemoryEffectFree got invalid IR with half-deleted linalg.generic operation. This PR separates analysis and cleanup to prevent such situation.

Thank you!


Full diff: https://github.com/llvm/llvm-project/pull/121079.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/RemoveDeadValues.cpp (+146-60)
  • (modified) mlir/test/Transforms/remove-dead-values.mlir (+26)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 3429008b2f241a..5d4ec66d6905a4 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -72,15 +72,54 @@ using namespace mlir::dataflow;
 
 namespace {
 
+// Set of structures below to be filled with operations and arguments to erase.
+// This is done to separate analysis and tree modification phases,
+// otherwise analysis is operating on half-deleted tree which is incorrect.
+
+struct CleanupFunction {
+  FunctionOpInterface funcOp;
+  BitVector nonLiveArgs;
+  BitVector nonLiveRets;
+};
+
+struct CleanupOperands {
+  Operation *op;
+  BitVector nonLiveOperands;
+};
+
+struct CleanupResults {
+  Operation *op;
+  BitVector nonLiveResults;
+};
+
+struct CleanupBlockArgs {
+  Block *b;
+  BitVector nonLiveArgs;
+};
+
+struct CleanupSuccessorOperands {
+  BranchOpInterface branch;
+  unsigned index;
+  BitVector nonLiveOperands;
+};
+
+struct CleanupList {
+  SmallVector<Operation *> operations;
+  SmallVector<Value> values;
+  SmallVector<CleanupFunction> functions;
+  SmallVector<CleanupOperands> operands;
+  SmallVector<CleanupResults> results;
+  SmallVector<CleanupBlockArgs> blocks;
+  SmallVector<CleanupSuccessorOperands> successorOperands;
+};
+
 // Some helper functions...
 
 /// Return true iff at least one value in `values` is live, given the liveness
 /// information in `la`.
-static bool hasLive(ValueRange values, RunLivenessAnalysis &la) {
+static bool hasLive(ValueRange values, const DenseSet<Value> &deletionSet, RunLivenessAnalysis &la) {
   for (Value value : values) {
-    // If there is a null value, it implies that it was dropped during the
-    // execution of this pass, implying that it was non-live.
-    if (!value)
+    if (deletionSet.contains(value))
       continue;
 
     const Liveness *liveness = la.getLiveness(value);
@@ -92,11 +131,11 @@ static bool hasLive(ValueRange values, RunLivenessAnalysis &la) {
 
 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
 /// i-th value in `values` is live, given the liveness information in `la`.
-static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) {
+static BitVector markLives(ValueRange values, const DenseSet<Value> &deletionSet, RunLivenessAnalysis &la) {
   BitVector lives(values.size(), true);
 
   for (auto [index, value] : llvm::enumerate(values)) {
-    if (!value) {
+    if (deletionSet.contains(value)) {
       lives.reset(index);
       continue;
     }
@@ -115,6 +154,18 @@ static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) {
   return lives;
 }
 
+// DeletionSet is used to track the Values that are scheduled for removal
+void updateDeletionSet(DenseSet<Value> &deletionSet, ValueRange range, const BitVector &nonLive) {
+  for (auto [index, result] : llvm::enumerate(range)) {
+    if (!nonLive[index]) continue;
+    deletionSet.insert(result);
+  }
+}
+
+void updateDeletionSet(DenseSet<Value> &deletionSet, Operation *op, const BitVector &nonLive) {
+  updateDeletionSet(deletionSet, op->getResults(), nonLive);
+}
+
 /// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
 /// is 1.
 static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
@@ -174,43 +225,43 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 /// It is assumed that `op` is simple. Here, a simple op is one which isn't a
 /// function-like op, a call-like op, a region branch op, a branch op, a region
 /// branch terminator op, or return-like.
-static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) {
-  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la))
+static void cleanSimpleOp(CleanupList &cl, DenseSet<Value> &deletionSet, Operation *op, RunLivenessAnalysis &la) {
+  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), deletionSet, la))
     return;
 
-  op->dropAllUses();
-  op->erase();
+  cl.operations.push_back(op);
+  updateDeletionSet(deletionSet, op, BitVector(op->getNumResults(), true));
 }
 
 /// Clean a function-like op `funcOp`, given the liveness information in `la`
 /// and the IR in `module`. Here, cleaning means:
 ///   (1) Dropping the uses of its unnecessary (non-live) arguments,
-///   (2) Erasing these arguments,
-///   (3) Erasing their corresponding operands from its callers,
+///   (2) Erasing their corresponding operands from its callers,
+///   (3) Erasing these arguments,
 ///   (4) Erasing its unnecessary terminator operands (return values that are
 ///   non-live across all callers),
 ///   (5) Dropping the uses of these return values from its callers, AND
 ///   (6) Erasing these return values
 /// iff it is not public or external.
-static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
+static void cleanFuncOp(CleanupList &cl, DenseSet<Value> &deletionSet,
+                        FunctionOpInterface funcOp, Operation *module,
                         RunLivenessAnalysis &la) {
   if (funcOp.isPublic() || funcOp.isExternal())
     return;
 
   // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
   SmallVector<Value> arguments(funcOp.getArguments());
-  BitVector nonLiveArgs = markLives(arguments, la);
+  BitVector nonLiveArgs = markLives(arguments, deletionSet, la);
   nonLiveArgs = nonLiveArgs.flip();
 
   // Do (1).
   for (auto [index, arg] : llvm::enumerate(arguments))
-    if (arg && nonLiveArgs[index])
-      arg.dropAllUses();
+    if (arg && nonLiveArgs[index]) {
+      cl.values.push_back(arg);
+      deletionSet.insert(arg);
+    }
 
   // Do (2).
-  funcOp.eraseArguments(nonLiveArgs);
-
-  // Do (3).
   SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
@@ -222,7 +273,7 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
         operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
     for (int index : nonLiveArgs.set_bits())
       nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
-    callOp->eraseOperands(nonLiveCallOperands);
+    cl.operands.push_back({callOp, nonLiveCallOperands});
   }
 
   // Get the list of unnecessary terminator operands (return values that are
@@ -253,26 +304,27 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    BitVector liveCallRets = markLives(callOp->getResults(), la);
+    BitVector liveCallRets = markLives(callOp->getResults(), deletionSet, la);
     nonLiveRets &= liveCallRets.flip();
   }
 
-  // Do (4).
+  // Do (3).
   // Note that in the absence of control flow ops forcing the control to go from
   // the entry (first) block to the other blocks, the control never reaches any
   // block other than the entry block, because every block has a terminator.
   for (Block &block : funcOp.getBlocks()) {
     Operation *returnOp = block.getTerminator();
     if (returnOp && returnOp->getNumOperands() == numReturns)
-      returnOp->eraseOperands(nonLiveRets);
+      cl.operands.push_back({returnOp, nonLiveRets});
   }
-  funcOp.eraseResults(nonLiveRets);
+  cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
 
   // Do (5) and (6).
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    dropUsesAndEraseResults(callOp, nonLiveRets);
+    cl.results.push_back({callOp, nonLiveRets});
+    updateDeletionSet(deletionSet, callOp, nonLiveRets);
   }
 }
 
@@ -297,18 +349,19 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
 /// It is important to note that values in this op flow from operands and
 /// terminator operands (successor operands) to arguments and results (successor
 /// inputs).
-static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
+static void cleanRegionBranchOp(CleanupList &cl, DenseSet<Value> &deletionSet,
+                                RegionBranchOpInterface regionBranchOp,
                                 RunLivenessAnalysis &la) {
   // Mark live results of `regionBranchOp` in `liveResults`.
   auto markLiveResults = [&](BitVector &liveResults) {
-    liveResults = markLives(regionBranchOp->getResults(), la);
+    liveResults = markLives(regionBranchOp->getResults(), deletionSet, la);
   };
 
   // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
   auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
     for (Region &region : regionBranchOp->getRegions()) {
       SmallVector<Value> arguments(region.front().getArguments());
-      BitVector regionLiveArgs = markLives(arguments, la);
+      BitVector regionLiveArgs = markLives(arguments, deletionSet, la);
       liveArgs[&region] = regionLiveArgs;
     }
   };
@@ -497,9 +550,8 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
   // It could never be live because of this op but its liveness could have been
   // attributed to something else.
   if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
-      !hasLive(regionBranchOp->getResults(), la)) {
-    regionBranchOp->dropAllUses();
-    regionBranchOp->erase();
+      !hasLive(regionBranchOp->getResults(), deletionSet, la)) {
+    cl.operations.push_back(regionBranchOp.getOperation());
     return;
   }
 
@@ -538,29 +590,27 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
                    terminatorOperandsToKeep);
 
   // Do (1).
-  regionBranchOp->eraseOperands(operandsToKeep.flip());
+  cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
 
   // Do (2.a) and (2.b).
   for (Region &region : regionBranchOp->getRegions()) {
     assert(!region.empty() && "expected a non-empty region in an op "
                               "implementing `RegionBranchOpInterface`");
-    for (auto [index, arg] : llvm::enumerate(region.front().getArguments())) {
-      if (argsToKeep[&region][index])
-        continue;
-      if (arg)
-        arg.dropAllUses();
-    }
-    region.front().eraseArguments(argsToKeep[&region].flip());
+    BitVector argsToRemove = argsToKeep[&region].flip();
+    cl.blocks.push_back({&region.front(), argsToRemove});
+    updateDeletionSet(deletionSet, region.front().getArguments(), argsToRemove);
   }
 
   // Do (2.c).
   for (Region &region : regionBranchOp->getRegions()) {
     Operation *terminator = region.front().getTerminator();
-    terminator->eraseOperands(terminatorOperandsToKeep[terminator].flip());
+    cl.operands.push_back({terminator, terminatorOperandsToKeep[terminator].flip()});
   }
 
   // Do (3) and (4).
-  dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
+  BitVector resultsToRemove = resultsToKeep.flip();
+  updateDeletionSet(deletionSet, regionBranchOp.getOperation(), resultsToRemove);
+  cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
 }
 
 // 1. Iterate over each successor block of the given BranchOpInterface
@@ -572,7 +622,8 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
 //    c. Mark each operand as live or dead based on the analysis.
 // 3. Remove dead operands from the branch operation and arguments accordingly
 
-static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
+static void cleanBranchOp(CleanupList &cl, DenseSet<Value> &deletionSet,
+                          BranchOpInterface branchOp, RunLivenessAnalysis &la) {
   unsigned numSuccessors = branchOp->getNumSuccessors();
 
   // Do (1)
@@ -588,22 +639,53 @@ static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
       operandValues.push_back(successorOperands[operandIdx]);
     }
 
-    BitVector successorLiveOperands = markLives(operandValues, la);
-
     // Do (3)
-    for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) {
-      if (!successorLiveOperands[argIdx]) {
-        if (successorBlock->getNumArguments() < successorOperands.size()) {
-          // if block was cleaned through a different code path
-          // we only need to remove operands from the invokation
-          successorOperands.erase(argIdx);
-          continue;
-        }
+    BitVector successorNonLive = markLives(operandValues, deletionSet, la).flip();
+    updateDeletionSet(deletionSet, successorBlock->getArguments(), successorNonLive);
+    cl.blocks.push_back({successorBlock, successorNonLive});
+    cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
+  }
+}
+
+void cleanup(CleanupList &cl) {
+  for (auto &op: cl.operations) {
+    op->dropAllUses();
+    op->erase();
+  }
+
+  for (auto &v: cl.values) {
+    v.dropAllUses();
+  }
+
+  for (auto &f: cl.functions) {
+    f.funcOp.eraseArguments(f.nonLiveArgs);
+    f.funcOp.eraseResults(f.nonLiveRets);
+  }
+
+  for (auto &o: cl.operands) {
+    o.op->eraseOperands(o.nonLiveOperands);  }
+
+  for (auto &r: cl.results) {
+    dropUsesAndEraseResults(r.op, r.nonLiveResults);
+  }
 
-        successorBlock->getArgument(argIdx).dropAllUses();
-        successorOperands.erase(argIdx);
-        successorBlock->eraseArgument(argIdx);
-      }
+  for (auto &b: cl.blocks) {
+    // blocks that are accessed via multiple codepaths processed once
+    if (b.b->getNumArguments() != b.nonLiveArgs.size()) continue;
+    for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
+      if (!b.nonLiveArgs[i]) continue;
+      b.b->getArgument(i).dropAllUses();
+      b.b->eraseArgument(i);
+    }
+  }
+  for (auto &op: cl.successorOperands) {
+    SuccessorOperands successorOperands =
+            op.branch.getSuccessorOperands(op.index);
+    // blocks that are accessed via multiple codepaths processed once
+    if (successorOperands.size() != op.nonLiveOperands.size()) continue;
+    for (int i = successorOperands.size() - 1; i >= 0; --i) {
+      if (!op.nonLiveOperands[i]) continue;
+      successorOperands.erase(i);
     }
   }
 }
@@ -616,14 +698,16 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
 void RemoveDeadValues::runOnOperation() {
   auto &la = getAnalysis<RunLivenessAnalysis>();
   Operation *module = getOperation();
+  DenseSet<Value> deletionSet;
+  CleanupList cl;
 
   module->walk([&](Operation *op) {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
-      cleanFuncOp(funcOp, module, la);
+      cleanFuncOp(cl, deletionSet, funcOp, module, la);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
-      cleanRegionBranchOp(regionBranchOp, la);
+      cleanRegionBranchOp(cl, deletionSet, regionBranchOp, la);
     } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
-      cleanBranchOp(branchOp, la);
+      cleanBranchOp(cl, deletionSet, branchOp, la);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
       // Nothing to do here because this is a terminator op and it should be
       // honored with respect to its parent
@@ -631,9 +715,11 @@ void RemoveDeadValues::runOnOperation() {
       // Nothing to do because this op is associated with a function op and gets
       // cleaned when the latter is cleaned.
     } else {
-      cleanSimpleOp(op, la);
+      cleanSimpleOp(cl, deletionSet, op, la);
     }
   });
+
+  cleanup(cl);
 }
 
 std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 9273ac01e7ccec..fe7bcbc7c490b6 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -73,6 +73,32 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
 
 // -----
 
+// Checking that the arguments of linalg.generic are properly handled
+// All code below is removed as a result of the pass
+//
+#map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+module {
+  func.func @main() {
+    %cst_3 = arith.constant dense<54> : tensor<1x25x13xi32>
+    %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32>
+    // CHECK-NOT: arith.constant
+    %0 = tensor.empty() : tensor<1x25x13xi32>
+    // CHECK-NOT: tensor
+    %1 = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_3, %cst_7 : tensor<1x25x13xi32>, tensor<1x25x13xi32>) outs(%0 : tensor<1x25x13xi32>) {
+    // CHECK-NOT: linalg.generic
+    ^bb0(%in: i32, %in_15: i32, %out: i32):
+      %29 = arith.xori %in, %in_15 : i32
+      // CHECK-NOT: arith.xori
+      linalg.yield %29 : i32
+      // CHECK-NOT: linalg.yield
+    } -> tensor<1x25x13xi32>
+    return
+  }
+}
+
+// -----
+
 // Note that this cleanup cannot be done by the `canonicalize` pass.
 //
 // CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {

@parsifal-47
Copy link
Contributor Author

@banach-space please take a look when you get a chance, thank you!

@parsifal-47 parsifal-47 changed the title Prevent invalid IR from being passed outside of RemoveDeadValues [MLIR] Prevent invalid IR from being passed outside of RemoveDeadValues Dec 25, 2024
Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

Copy link

github-actions bot commented Dec 25, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants