DeepSpeed-Chat: prefetch of layers during reward model forward pass leads to error during sample generation #337
Comments
A second question that came up while looking at this... it seems like the Should it be
@adammoody, thanks for the detailed analysis of this bug. To answer your question, no, get_inactive_params should not include INFLIGHT params. This problem occurs because in RLHF we are context-switching 5 models in a rank to share GPU memory. And so, the recommended solution here is to ensure that there are no INFLIGHT params left over when we switch to another model.
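As a concrete illustration of that invariant, here is a minimal sketch (the helper name is hypothetical; ds_status and ZeroParamStatus are the ZeRO-3 attributes discussed throughout this thread) of a check that a model has nothing in flight before it is swapped in:

```python
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

def assert_no_inflight(model, tag=""):
    # Any parameter still INFLIGHT has a pending all-gather from the prefetcher;
    # generating with it would see an empty (0-sized) placeholder tensor.
    inflight = [name for name, p in model.named_parameters()
                if getattr(p, "ds_status", None) == ZeroParamStatus.INFLIGHT]
    assert not inflight, f"{tag}: params still INFLIGHT: {inflight[:5]}"
```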
@adammoody, can you please try this PR? Thanks!
Thanks for the explanation and quick reply, @tjruwase . Unfortunately, I'm still hitting the same problem with this PR. The problematic params seem to be from the first few layers of the actor_model, which have been prefetched due to a forward step of the critic_model. I thought maybe we could move those empty calls to the end to try to clear any INFLIGHT actor params that the critic started to prefetch:
However, with that change I get the following error on the
This is confusing to me. Each model has independent prefetchers, so the critic_model should not affect the actor_model. Can you share a stack trace that you get with the PR?
@adammoody, by the way, I was not able to repro your error on my 4xV100-16GB setup. This makes it harder to resolve.
By the way, this is a separate bug that needs to be addressed.
I think you have found a bug here. Do you mind opening a separate ticket for this?
Sure. I'll post that one to the main DeepSpeed repo.
As another clue, it seems like the following changes work around the problem. I defined a "wait on inflight" function in
I then call that from
And then I call
If I drop the recursion and change the call to:
then I still get this error:
It seems that it misses the necessary parameters without the recursion.
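A sketch of the recursion point being described, assuming the goal is simply to collect every ZeRO-3 parameter that is still INFLIGHT, including those owned by nested submodules (the helper name is hypothetical):

```python
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

def collect_inflight_params(model):
    # Walk all submodules recursively; parameters prefetched on behalf of deeper
    # layers are missed if only the top-level module's own parameters are checked.
    inflight = []
    for module in model.modules():
        for p in module.parameters(recurse=False):
            if getattr(p, "ds_status", None) == ZeroParamStatus.INFLIGHT:
                inflight.append(p)
    return inflight
```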
And if I keep
I still get the original error noted at the top of the issue:
In summary, it seems that I need to do all three of these:
Thanks for sharing these details. I agree that ... First, ...
The error message and stack trace when using the changes in the PR are the same as the original error report. Right, I haven't given up on figuring out this problem either. I have some more ideas to try to debug things. I'll keep posting updates.
@tjruwase , I still haven't cracked it, but here are some more clues... The first problematic layer corresponds to the vocab embedding layer of the actor model. I did verify that layer actually belongs to the actor model, and that it is not shared with the critic model or any other model. The stack trace for the prefetch of that layer is shown below. The line numbers will vary because I've added lots of debug statements.
Note that this happens from a prefetch of that layer. That trace was kicked off by a call to:
I did verify that this layer belongs to the actor model by matching its Python id(). I can see that it fails for me on what I think is the third training step. This appears to be the first step where it has completed a trace and thus has enabled prefetching for the model.
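That ownership check can be done with plain object identity, for example (a sketch; the model names are illustrative):

```python
def owning_models(module, named_models):
    # named_models might be {"actor": actor_model, "critic": critic_model, ...};
    # identity comparison ("is") confirms the exact module object is unique to
    # one model rather than shared.
    return [name for name, model in named_models.items()
            if any(m is module for m in model.modules())]
```

With all four models passed in, a layer that truly belongs only to the actor should come back as just ['actor'].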
At that point, prefetching kicks in. Since there seems to be some "model mixing" in this case, one area that caught my eye is the global
You can see I've added a print above. The vocab embedding layer is the 8th or 9th element at the time the problem occurs; I'd have to double-check which, if you need an exact value. That list is initialized with a base model.
In this case, I can see that the four base models correspond to the first four elements of that list. I'm not sure which 4 or so modules are stored as the elements in between the base models and the vocab embed layer at the point where I see the problem. I still haven't tracked down why invoking a function on the critic model could end up fetching params for the actor model, but I wondered if there might be some linkage here.
@adammoody, kudos on the intensive debugging. I think I know what might be wrong, but I need your help to confirm. I have updated my PR with some asserts to verify that
I added those new changes by hand, so my source file line numbers will be different. I have been editing DeepSpeed files in place within my python environment, so it takes some effort to set up a clean environment at this point. Anyway, if you trust that, I hit the tensor dimension mismatch here: DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py Line 77 in ce049be
If I add a second call to
then the assertion triggers:
Thanks for sharing these updates. Adding the second assert for the actor model cache is a really good idea. It is a mystery why it fails. This supports your suspicion of a leakage between the parameter partitioning of the actor and critic models. Can you confirm whether your critic model is 1.3b or 350m?
Also, can you try dropping
Yes, I'm actually using a 350m model for the critic. I had a cut-and-paste typo in the path name when I wrote out the checkpoint, so the path suggests it's a 1.3b param model, but it is really 350m. I tried dropping the
Here are some other workarounds that I found earlier but didn't list yet:
Thanks for the update.
@adammoody, FYI, I think this DeepSpeed PR from my colleague @HeyangQin might be relevant here. Please give him a bit more time to get it ready.
@tjruwase, I think I found the cause. I believe the problem is that all four models share the same ReLU module object. Each model registers a forward hook on that module when it sets up its ZeRO-3 hooks. I found this by adding the following code:
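A minimal sketch of that kind of check, assuming it just compares object identities and counts the forward hooks registered on each activation module (the helper name and placement are my own):

```python
def report_activation_sharing(name, model):
    # Print the identity of each ReLU submodule and how many forward hooks are
    # registered on it. If two models report the same id(), they share one object
    # and each DeepSpeed engine's hooks pile up on that single module.
    for mod_name, mod in model.named_modules():
        if mod.__class__.__name__ == "ReLU":
            print(f"{name}: {mod_name} relu_id={id(mod)} "
                  f"n_forward_hooks={len(mod._forward_hooks)}")
```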
With that, I get the following example output for the actor and reward models. You can see that the
As a test, I then found that I could work around the problem by modifying the OPT model to instantiate a unique ReLU object for each layer.
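For reference, a minimal sketch of both the sharing pattern and that workaround, assuming an older transformers-style setup where layers fetch their activation from a module-level ACT2FN-like mapping (all names here are illustrative):

```python
import torch.nn as nn

# Shared-object pattern: every layer that looks up "relu" in a module-level
# mapping receives the *same* nn.ReLU() instance, so forward hooks registered
# by one DeepSpeed engine fire for every model in the process.
ACT2FN_LIKE = {"relu": nn.ReLU()}   # one instance, handed out to all callers

class DecoderLayerShared(nn.Module):
    def __init__(self):
        super().__init__()
        self.activation_fn = ACT2FN_LIKE["relu"]   # same object in every layer/model

class DecoderLayerUnique(nn.Module):
    def __init__(self):
        super().__init__()
        self.activation_fn = nn.ReLU()             # workaround: fresh object per layer

a, b = DecoderLayerShared(), DecoderLayerShared()
print(a.activation_fn is b.activation_fn)          # True: hooks on one hit both models
c, d = DecoderLayerUnique(), DecoderLayerUnique()
print(c.activation_fn is d.activation_fn)          # False: hooks stay per-model
```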
Amazing debugging, @adammoody. Truly outstanding! @stas00, FYI. It seems ReLU objects are shared across models in the same transformers process. Do you have context for this behavior?
Yes, it's the same object; it's like a cache. The paradigm is shifting: clearly there was no need to create a new object before, because deepspeed didn't support more than one model, and there is no issue with reusing the same object across multiple models outside of the deepspeed world. We should probably file a feature request to create these on the fly rather than pre-create them, so that each instance is unique. There are quite a few changes that need to be made to support the multiple-deepspeed-models paradigm. Some possible workarounds:
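One way to realize the create-on-the-fly idea mentioned above is a mapping that instantiates a fresh activation on every lookup (a sketch; newer transformers releases use a similar ClassInstantier pattern):

```python
import torch.nn as nn

class LazyActivationMap(dict):
    # Store classes, not instances, and build a fresh module on every lookup so
    # no two layers (or models) ever share the same activation object.
    def __getitem__(self, key):
        return super().__getitem__(key)()

ACT2CLS = LazyActivationMap({"relu": nn.ReLU, "gelu": nn.GELU})
print(ACT2CLS["relu"] is ACT2CLS["relu"])   # False: each lookup is a new instance
```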
I encountered this bug after the previous bug (microsoft/DeepSpeed#3528) was solved. @HeyangQin |
I have the same issue. Have you resolved the problem?
Encountering the same error here. The issue persists even after updating DeepSpeed and PyTorch Lightning to the latest versions. |
When running step 3 with ZeRO stage 3 enabled for both the actor and critic models,
I get the following error (line numbers may be offset due to debug statements I've added):
This happens because the weight.data shape does not match the tensor shape resulting from the LoRA matmul operation. I am using a system with 4x 16GB V100 GPUs per node with DeepSpeed 0.9.1. I trained a 1.3b-param model in step 1 and a 350m-param model in step 2.
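Schematically, the failing step is a LoRA fuse of the form below (my own parameter layout for illustration, not DeepSpeed-Chat's exact code); if the weight is still a 0-sized ZeRO-3 placeholder, the addition raises exactly this kind of shape mismatch:

```python
import torch

def fuse_lora(weight, lora_left, lora_right, scaling=1.0):
    # weight: [out_dim, in_dim], lora_left: [out_dim, r], lora_right: [r, in_dim]
    return weight + scaling * (lora_left @ lora_right)

full = fuse_lora(torch.zeros(8, 16), torch.randn(8, 4), torch.randn(4, 16))  # ok
empty = torch.zeros(0)   # what a partitioned (ungathered) ZeRO-3 weight looks like
# fuse_lora(empty, torch.randn(8, 4), torch.randn(4, 16))  # RuntimeError: shape mismatch
```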
My step 3 run command launches 4 processes on one node, binding one process per GPU:
After some debugging, I found that the above error arises because the GatheredParameters context does not gather all layers. If I print the tensor shape for each parameter of each layer immediately after GatheredParameters like so:
https://github.com/microsoft/DeepSpeed/blob/050aee287d70157e29f242b3a629a4cb97b4e4e7/deepspeed/runtime/hybrid_engine.py#L238
then I see the following output on the step just before the error:
Note that the dimensions of the parameters in layer_id=0 are mostly all zero. On steps that complete without an error, those parameters have non-zero shapes as shown below. The count of non_active_layers is 962 below vs 931 above. By adding the following lines for further details:
https://github.com/microsoft/DeepSpeed/blob/050aee287d70157e29f242b3a629a4cb97b4e4e7/deepspeed/runtime/hybrid_engine.py#L234-L238
It seems that the 0-shape parameters are marked as "ds_status == ZeroParamStatus.INFLIGHT" before calling "GatheredParameters":
I think those parameters are marked as INFLIGHT because they have been prefetched.
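The per-layer instrumentation described above boils down to something like this (a hypothetical helper; the actual edits were made inline in hybrid_engine.py):

```python
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

def dump_layer_params(layer, layer_id):
    # A prefetched-but-unfinished parameter shows ds_status == INFLIGHT and a
    # 0-sized data tensor, matching the all-zero shapes in the output above.
    for name, p in layer.named_parameters():
        status = getattr(p, "ds_status", None)
        print(f"layer_id={layer_id} {name} shape={tuple(p.shape)} "
              f"inflight={status == ZeroParamStatus.INFLIGHT}")
```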
Adding some more debugging lines to print the stack at the point where the status is set to INFLIGHT:
https://github.com/microsoft/DeepSpeed/blob/050aee287d70157e29f242b3a629a4cb97b4e4e7/deepspeed/runtime/zero/partition_parameters.py#L873-L885
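Those extra lines amount to dumping the Python call stack whenever a parameter flips to INFLIGHT, roughly as follows (a sketch; the real change was made inline at the linked lines):

```python
import traceback

def log_inflight_transition(param):
    # Called where ds_status is set to INFLIGHT; the printed stack identifies
    # which model's forward pass triggered the prefetch.
    print(f"param ds_id={getattr(param, 'ds_id', '?')} -> INFLIGHT, requested from:")
    traceback.print_stack()
```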
I can see those layers are set to INFLIGHT here:
It seems that the layers are being prefetched during the call to the critic model forward pass:
DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py
Line 174 in 2aa7a31
They are still in INFLIGHT status when trying to generate a sample. The get_inactive_params function then only includes params marked as NOT_AVAILABLE:
https://github.com/microsoft/DeepSpeed/blob/48297c4841374776a3c4d9319a9b57378f598d65/deepspeed/runtime/utils.py#L972-L975
Later, GatheredParameters may only consider params whose state is NOT_AVAILABLE:
https://github.com/microsoft/DeepSpeed/blob/48297c4841374776a3c4d9319a9b57378f598d65/deepspeed/runtime/zero/partition_parameters.py#L1058
Assuming that diagnosis is correct, I'm not sure what the recommended fix would be. Should get_inactive_params include INFLIGHT params?