[Bug] Step1: Not mask source part #660

Open
HuggingLLM opened this issue Aug 1, 2023 · 1 comment · May be fixed by #762
@HuggingLLM
Description: In the DeepSpeed-Chat step 1 `PromptDataset`, the labels are a verbatim copy of `input_ids`. According to my understanding, SFT needs to mask the source (prompt) part of the labels so that the source tokens do not participate in the loss calculation; only the target (response) tokens should contribute.

from torch.utils.data import Dataset


class PromptDataset(Dataset):

    def __init__(self, prompt_dataset, chosen_dataset, reject_dataset,
                 pad_token_id, train_phase) -> None:
        super().__init__()
        self.prompt_dataset = prompt_dataset
        self.chosen_dataset = chosen_dataset
        self.reject_dataset = reject_dataset
        self.pad_token_id = pad_token_id
        self.train_phase = train_phase

    def __len__(self):
        length = len(self.chosen_dataset)
        if self.train_phase == 3:
            length = len(self.prompt_dataset)
        return length

    def __getitem__(self, idx):
        if self.train_phase == 1:
            # SFT: labels are a verbatim copy of input_ids, so the source
            # (prompt) tokens are included in the loss.
            return {
                "input_ids": self.chosen_dataset[idx]["input_ids"],
                "attention_mask": self.chosen_dataset[idx]["attention_mask"],
                # maybe: [*[-100] * source_len, *target_ids]
                "labels": self.chosen_dataset[idx]["input_ids"]
            }
        elif self.train_phase == 2:
            # Reward model: return chosen/rejected pairs.
            return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \
                self.reject_dataset[idx]["input_ids"], self.reject_dataset[idx]["attention_mask"]
        elif self.train_phase == 3:
            # RLHF: return only the prompt plus the pad token id.
            return self.prompt_dataset[idx]["input_ids"], self.prompt_dataset[idx]["attention_mask"], \
                self.pad_token_id
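
For illustration, a minimal sketch of how the masked labels could be built. The helper name `build_sft_labels` and the fixed `source_len` are assumptions for this example; -100 is the default `ignore_index` of PyTorch's cross-entropy loss, so masked positions are skipped:

import torch

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_sft_labels(input_ids: torch.Tensor, source_len: int) -> torch.Tensor:
    # Hypothetical helper: copy input_ids and mask the first source_len
    # (prompt) tokens so only the response tokens contribute to the loss.
    labels = input_ids.clone()
    labels[:source_len] = IGNORE_INDEX
    return labels


# Example: a 4-token prompt followed by a 3-token response.
input_ids = torch.tensor([101, 2054, 2003, 2009, 3437, 2182, 102])
print(build_sft_labels(input_ids, source_len=4))
# tensor([-100, -100, -100, -100, 3437, 2182,  102])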

manestay commented Aug 4, 2023

+1, I also believe this is a bug. Since SFT is the first of the three steps, the bug might not affect the final result much, but nevertheless, without masking the source part it is not correct.

Maybe adapting the data collator to perform masking of the labels, as described here, would work: https://huggingface.co/docs/trl/main/en/sft_trainer#train-on-completions-only
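
Roughly, a minimal sketch of that approach from the linked TRL docs. The model name, dataset, and prompt format below are illustrative assumptions; `DataCollatorForCompletionOnlyLM` replaces every label token up to and including the response template with -100, so only the completion is trained on:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")


def formatting_prompts_func(example):
    # Illustrative prompt format; it must contain response_template below.
    return [
        f"### Question: {q}\n ### Answer: {a}"
        for q, a in zip(example["instruction"], example["output"])
    ]


# Label tokens before (and including) this template are set to -100.
response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)
trainer.train()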
