def initialize_distributed(args):
    """Initialize torch.distributed."""

    # Manually set the device ids.
    device = args.rank % torch.cuda.device_count()
    if args.local_rank is not None:
        device = args.local_rank
    torch.cuda.set_device(device)

    # Call the init process.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(backend=args.distributed_backend,
                                         world_size=args.world_size,
                                         rank=args.rank,
                                         init_method=init_method)

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)
    groups.initialize(ep_size=args.expert_parallel_size, mpu=mpu)

    # Optional DeepSpeed Activation Checkpointing Features.
    if args.deepspeed and args.deepspeed_activation_checkpointing:
        set_deepspeed_activation_checkpointing(args)

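# --- Usage sketch (illustrative, not from this file) ------------------------
# A minimal example of driving initialize_distributed() from a launcher. The
# field names on the args namespace are the ones referenced by the function
# above; the single-process default values, the 'nccl' backend choice, and
# the _example_launch helper itself are assumptions for illustration only.
# MASTER_ADDR / MASTER_PORT are the standard rendezvous environment variables
# read by the function.
import argparse
import os


def _example_launch():
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '6000')
    args = argparse.Namespace(
        rank=int(os.getenv('RANK', '0')),
        local_rank=int(os.getenv('LOCAL_RANK', '0')),
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        distributed_backend='nccl',
        model_parallel_size=1,
        expert_parallel_size=1,
        deepspeed=False,
        deepspeed_activation_checkpointing=False,
    )
    initialize_distributed(args)
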
def _test_moe(args, hidden_dim, ep_size):
    # E+D    -- ep_size = 2
    # E only -- ep_size = 4
    # groups.initialize_model_parallel(1)
    # groups.initialize_expert_parallel(2)
    groups.initialize(ep_size=ep_size)
    model = SimpleMoEModel(hidden_dim)
    optimizer = torch.optim.AdamW(params=model.parameters())
    model, _, _, _ = deepspeed.initialize(args=args,
                                          model=model,
                                          optimizer=optimizer,
                                          dist_init_required=False)
    # dist_init_required=False -- parameterize to True/False?
    assert dist.get_world_size() == groups.get_data_parallel_world_size(), \
        "incorrect data parallel world size"
    assert ep_size == groups.get_expert_parallel_world_size(
        groups.get_max_expert_size_name()), "incorrect expert parallel world size"

    data_loader = sequence_dataloader(model=model,
                                      total_samples=50,
                                      hidden_dim=hidden_dim,
                                      device=model.device)

    for n, batch in enumerate(data_loader):
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()

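# --- Test-driver sketch (illustrative, not from this file) ------------------
# The comments inside _test_moe suggest it is exercised with different expert
# parallel sizes: ep_size=2 (experts + data parallel) and ep_size=4 (experts
# only) on a 4-GPU world. One plausible pytest-style driver is sketched below;
# `args_from_dict` and the `distributed_test` decorator are assumed to come
# from the surrounding unit-test harness and are not defined here, and the
# config values are placeholders.
import pytest


@pytest.mark.parametrize("ep_size", [2, 4])
def test_moe(tmpdir, ep_size):
    config_dict = {
        "train_batch_size": 8,
        "steps_per_print": 1,
        "fp16": {"enabled": True},
    }
    args = args_from_dict(tmpdir, config_dict)  # assumed harness helper
    hidden_dim = 16

    @distributed_test(world_size=[4])  # assumed decorator: spawns 4 ranks
    def _run():
        _test_moe(args=args, hidden_dim=hidden_dim, ep_size=ep_size)

    _run()
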
def _helper(args):
    # ep_size, hidden_dim, tmpdir and config_dict are free variables resolved
    # from the enclosing scope (typically the surrounding test).
    groups.initialize(ep_size=ep_size)
    models = [SimpleMoEModel(hidden_dim=hidden_dim) for _ in range(2)]
    optimizers = [torch.optim.AdamW(params=model.parameters()) for model in models]
    checkpoint_correctness_verification(args,
                                        models=models,
                                        hidden_dim=hidden_dim,
                                        tmpdir=tmpdir,
                                        load_optimizer_states=True,
                                        load_lr_scheduler_states=False,
                                        fp16=config_dict["fp16"]["enabled"],
                                        empty_tag=True,
                                        base_optimizers=optimizers,
                                        seq_dataloader=True)

def _helper(args):
    # ep_size, hidden_dim, tmpdir, config_dict and load_optim_states are free
    # variables resolved from the enclosing scope (typically the surrounding
    # test).
    groups.initialize(ep_size=ep_size)
    models = [
        SimpleMoEModel(hidden_dim=hidden_dim, num_experts=ep_size)
        for _ in range(2)
    ]
    # Split each model's parameters into separate MoE / non-MoE optimizer
    # groups before building the optimizers.
    params = [
        split_params_into_different_moe_groups_for_optimizer(create_param_groups(model))
        for model in models
    ]
    optimizers = [torch.optim.AdamW(params=param) for param in params]
    checkpoint_correctness_verification(args,
                                        models=models,
                                        hidden_dim=hidden_dim,
                                        tmpdir=tmpdir,
                                        load_optimizer_states=load_optim_states,
                                        load_lr_scheduler_states=False,
                                        fp16=config_dict["fp16"]["enabled"],
                                        empty_tag=True,
                                        base_optimizers=optimizers,
                                        seq_dataloader=True)