Add support for DS comms #50

Open · wants to merge 1 commit into base: main_before_rebase
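This PR's single commit swaps direct torch.distributed calls for their deepspeed.comm counterparts throughout Megatron. A minimal sketch of the pattern, assuming deepspeed.comm mirrors the torch.distributed collective API (is_initialized, get_rank, get_world_size, barrier, all_reduce):

# Sketch only, not part of the diff below: the substitution applied in each file.
import deepspeed

def print_rank_0(message):
    # Equivalent to megatron/__init__.py after this change: print only on global
    # rank 0 once the deepspeed.comm backend is up, otherwise print everywhere.
    if deepspeed.comm.is_initialized() and deepspeed.comm.get_rank() != 0:
        return
    print(message, flush=True)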
11 changes: 6 additions & 5 deletions megatron/__init__.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
+import deepspeed

from .package_info import (
__description__,
@@ -37,19 +38,19 @@

def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
-    if torch.distributed.is_initialized():
-        if torch.distributed.get_rank() == 0:
+    if deepspeed.comm.is_initialized():
+        if deepspeed.comm.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)

def is_last_rank():
-    return torch.distributed.get_rank() == (
-        torch.distributed.get_world_size() - 1)
+    return deepspeed.comm.get_rank() == (
+        deepspeed.comm.get_world_size() - 1)

def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
-    if torch.distributed.is_initialized():
+    if deepspeed.comm.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
21 changes: 11 additions & 10 deletions megatron/checkpointing.py
@@ -21,6 +21,7 @@
import numpy as np

import torch
+import deepspeed

from megatron import (get_args,
mpu,
@@ -118,7 +119,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))

-    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0 \
+    if not deepspeed.comm.is_initialized() or mpu.get_data_parallel_rank() == 0 \
or args.deepspeed:

# Arguments, iteration, and model.
@@ -177,21 +178,21 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
model[0].module.state_dict = original_state_dict

# Wait so everyone is done (necessary)
-    if torch.distributed.is_initialized():
-        torch.distributed.barrier()
+    if deepspeed.comm.is_initialized():
+        deepspeed.comm.barrier()

print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))

# And update the latest iteration
-    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+    if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))

# Wait so everyone is done (not necessary)
-    if torch.distributed.is_initialized():
-        torch.distributed.barrier()
+    if deepspeed.comm.is_initialized():
+        deepspeed.comm.barrier()

def _transpose_first_dim(t, num_splits, num_splits_first, model):
input_shape = t.size()
@@ -419,8 +420,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
sys.exit()

# Some utilities want to load a checkpoint without distributed being initialized
-    if torch.distributed.is_initialized():
-        torch.distributed.barrier()
+    if deepspeed.comm.is_initialized():
+        deepspeed.comm.barrier()

print_rank_0(f' successfully loaded checkpoint from {args.load} '
f'at iteration {iteration}')
@@ -448,7 +449,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
checkpoint_name = get_checkpoint_name(load_path, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
-            torch.distributed.get_rank(), checkpoint_name))
+            deepspeed.comm.get_rank(), checkpoint_name))

state_dict = torch.load(checkpoint_name, map_location='cpu')
ret_state_dict = state_dict['model']
@@ -460,7 +461,7 @@

assert len(model) == 1
model[0].load_state_dict(ret_state_dict)
-    torch.distributed.barrier()
+    deepspeed.comm.barrier()

if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
7 changes: 4 additions & 3 deletions megatron/data/biencoder_dataset_utils.py
@@ -3,6 +3,7 @@

import numpy as np
import torch
+import deepspeed

from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, \
@@ -157,7 +158,7 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
assert block_dataset.sizes.dtype == np.int32

# Build samples mapping
-        verbose = torch.distributed.get_rank() == 0
+        verbose = deepspeed.comm.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
@@ -188,8 +189,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
-    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    assert counts[0].item() == torch.distributed.get_world_size(
+    deepspeed.comm.all_reduce(counts, group=mpu.get_data_parallel_group())
+    assert counts[0].item() == deepspeed.comm.get_world_size(
group=mpu.get_data_parallel_group())

# Load indexed dataset.
3 changes: 2 additions & 1 deletion megatron/data/blendable_dataset.py
@@ -19,6 +19,7 @@

import numpy as np
import torch
+import deepspeed

from megatron import print_rank_0
from megatron import mpu
@@ -53,7 +54,7 @@ def __init__(self, datasets, weights):
helpers.build_blending_indices(self.dataset_index,
self.dataset_sample_index,
weights, num_datasets, self.size,
-                                       torch.distributed.get_rank() == 0)
+                                       deepspeed.comm.get_rank() == 0)
print_rank_0('> elapsed time for building blendable dataset indices: '
'{:.2f} (sec)'.format(time.time() - start_time))

13 changes: 7 additions & 6 deletions megatron/data/dataset_utils.py
@@ -25,6 +25,7 @@

import numpy as np
import torch
+import deepspeed

from megatron import (
get_args,
@@ -662,7 +663,7 @@ def get_samples_mapping(indexed_dataset,
indexmap_filename += '.npy'

# Build the indexed mapping if not exist.
-    if torch.distributed.get_rank() == 0 and \
+    if deepspeed.comm.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
@@ -672,7 +673,7 @@
assert indexed_dataset.sizes.dtype == np.int32

# Build samples mapping
-        verbose = torch.distributed.get_rank() == 0
+        verbose = deepspeed.comm.get_rank() == 0
start_time = time.time()
print_rank_0(' > building sapmles index mapping for {} ...'.format(
name))
@@ -700,11 +701,11 @@ def get_samples_mapping(indexed_dataset,
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
-    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
+    deepspeed.comm.all_reduce(counts, group=mpu.get_data_parallel_group())
+    deepspeed.comm.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
-        torch.distributed.get_world_size() //
-        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
+        deepspeed.comm.get_world_size() //
+        deepspeed.comm.get_world_size(group=mpu.get_tensor_model_parallel_group()))

# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
11 changes: 6 additions & 5 deletions megatron/data/gpt_dataset.py
@@ -20,6 +20,7 @@

import numpy as np
import torch
+import deepspeed

from megatron import mpu, print_rank_0
from megatron.data.blendable_dataset import BlendableDataset
@@ -212,7 +213,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
shuffle_idx_filename = _filename + '_shuffle_idx.npy'

# Build the indexed mapping if not exist.
-    if torch.distributed.get_rank() == 0:
+    if deepspeed.comm.get_rank() == 0:
if (not os.path.isfile(doc_idx_filename)) or \
(not os.path.isfile(sample_idx_filename)) or \
(not os.path.isfile(shuffle_idx_filename)):
@@ -297,11 +298,11 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
-    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
+    deepspeed.comm.all_reduce(counts, group=mpu.get_data_parallel_group())
+    deepspeed.comm.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
-        torch.distributed.get_world_size() //
-        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
+        deepspeed.comm.get_world_size() //
+        deepspeed.comm.get_world_size(group=mpu.get_tensor_model_parallel_group()))

# Load mappings.
start_time = time.time()
7 changes: 4 additions & 3 deletions megatron/data/realm_dataset_utils.py
@@ -3,6 +3,7 @@

import numpy as np
import torch
+import deepspeed

from megatron import mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
@@ -147,7 +148,7 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
assert block_dataset.sizes.dtype == np.int32

# Build samples mapping
-        verbose = torch.distributed.get_rank() == 0
+        verbose = deepspeed.comm.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
@@ -178,8 +179,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
-    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    assert counts[0].item() == torch.distributed.get_world_size(
+    deepspeed.comm.all_reduce(counts, group=mpu.get_data_parallel_group())
+    assert counts[0].item() == deepspeed.comm.get_world_size(
group=mpu.get_data_parallel_group())

# Load indexed dataset.
7 changes: 4 additions & 3 deletions megatron/global_vars.py
@@ -20,6 +20,7 @@
import time

import torch
+import deepspeed

from megatron.tokenizer import build_tokenizer
from .arguments import parse_args
@@ -254,9 +255,9 @@ def log(self, names, normalizer=1.0, reset=True):
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
-        if torch.distributed.is_initialized():
-            if torch.distributed.get_rank() == (
-                    torch.distributed.get_world_size() - 1):
+        if deepspeed.comm.is_initialized():
+            if deepspeed.comm.get_rank() == (
+                    deepspeed.comm.get_world_size() - 1):
print(string, flush=True)
else:
print(string, flush=True)
6 changes: 3 additions & 3 deletions megatron/indexer.py
@@ -1,6 +1,6 @@
import sys
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist

from megatron import get_args
from megatron import mpu
@@ -112,7 +112,7 @@ def build_and_save_index(self):
# This process signals to finalize its shard and then synchronize with
# the other processes
self.evidence_embedder_obj.save_shard()
-        torch.distributed.barrier()
+        dist.barrier()  # deepspeed.comm, aliased as dist in the import above
del self.model

# rank 0 process builds the final copy
@@ -124,4 +124,4 @@ def build_and_save_index(self):
self.evidence_embedder_obj.clear()

# complete building the final copy
-        torch.distributed.barrier()
+        dist.barrier()  # deepspeed.comm, aliased as dist in the import above
32 changes: 16 additions & 16 deletions megatron/initialize.py
@@ -54,10 +54,10 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
args_defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)

-    # torch.distributed initialization
+    # deepspeed.comm initialization
def finish_mpu_init():
args = get_args()
-        # Pytorch distributed.
+        # DeepSpeed communications (deepspeed.comm).
_initialize_distributed()

# Random seeds for reproducibility.
@@ -100,7 +100,7 @@ def _compile_dependencies():
# Compile dataset C++ code.
# =========================
# TODO: move this to ninja
-    if torch.distributed.get_rank() == 0:
+    if deepspeed.comm.get_rank() == 0:
start_time = time.time()
print('> compiling dataset index builder ...')
from megatron.data.dataset_utils import compile_helper
@@ -131,20 +131,20 @@ def _compile_dependencies():
' back to unfused kernel invocations.', flush=True)

# Always build on rank zero first.
-    if torch.distributed.get_rank() == 0:
+    if deepspeed.comm.get_rank() == 0:
start_time = time.time()
print('> compiling and loading fused kernels ...', flush=True)
fused_kernels.load(args)
-        torch.distributed.barrier()
+        deepspeed.comm.barrier()
else:
-        torch.distributed.barrier()
+        deepspeed.comm.barrier()
fused_kernels.load(args)
# Simple barrier to make sure all ranks have passed the
# compilation phase successfully before moving on to the
# rest of the program. We think this might ensure that
# the lock is released.
-    torch.distributed.barrier()
-    if torch.distributed.get_rank() == 0:
+    deepspeed.comm.barrier()
+    if deepspeed.comm.get_rank() == 0:
print('>>> done with compiling and loading fused kernels. '
'Compilation time: {:.3f} seconds'.format(
time.time() - start_time), flush=True)
@@ -182,20 +182,20 @@ def setup_deepspeed_random_and_activation_checkpointing(args):


def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
"""Initialize deepspeed.comm and mpu."""
args = get_args()
device_count = torch.cuda.device_count()
-    if torch.distributed.is_initialized():
+    if deepspeed.comm.is_initialized():

if args.rank == 0:
-            print('torch distributed is already initialized, '
+            print('deepspeed.comm is already initialized, '
'skipping initialization ...', flush=True)
-        args.rank = torch.distributed.get_rank()
-        args.world_size = torch.distributed.get_world_size()
+        args.rank = deepspeed.comm.get_rank()
+        args.world_size = deepspeed.comm.get_world_size()

else:
if args.rank == 0:
-            print('> initializing torch distributed ...', flush=True)
+            print('> initializing deepspeed.comm ...', flush=True)
# Manually set the device ids.
if device_count > 0:
device = args.rank % device_count
@@ -238,9 +238,9 @@ def _init_autoresume():
"""Set autoresume start time."""
autoresume = get_adlr_autoresume()
if autoresume:
-        torch.distributed.barrier()
+        deepspeed.comm.barrier()
autoresume.init()
-        torch.distributed.barrier()
+        deepspeed.comm.barrier()


def _set_random_seed(seed_):
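The _initialize_distributed() hunk above is truncated here. For context, a minimal sketch (not part of this PR) of how the deepspeed.comm backend it refers to is typically brought up, assuming the standard launcher environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) are set:

# Sketch only: once deepspeed.comm is initialized, the collective calls used in
# this diff (get_rank, get_world_size, barrier, all_reduce) become valid.
import deepspeed

deepspeed.init_distributed(dist_backend='nccl')  # stands in for torch.distributed.init_process_group
if deepspeed.comm.get_rank() == 0:
    print('world size: {}'.format(deepspeed.comm.get_world_size()), flush=True)
deepspeed.comm.barrier()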