def setup_model(args):
    """Setup model and optimizer."""

    model = get_model(args)
    if DEEPSPEED_WRAP and args.deepspeed:
        print_rank_0("DeepSpeed is enabled.")

        model, optimizer, _, lr_scheduler = DEEPSPEED_WRAP.deepspeed.initialize(
            model=model,
            optimizer=None,
            args=args,
            lr_scheduler=None,
            mpu=mpu,
            dist_init_required=False)

    print("Load checkpoint from " + args.load)
    _ = load_checkpoint(model,
                        None,
                        None,
                        args,
                        deepspeed=DEEPSPEED_WRAP and args.deepspeed)
    model.eval()
    print("Loaded")
    if args.export_huggingface is not None:
        export_to_huggingface_model(model, args.export_huggingface)
        print(f"Exported in huggingface format to {args.export_huggingface}")

    return model
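
# Hedged usage sketch (attribute values are illustrative, not from the original source):
# setup_model is the inference-time counterpart of setup_model_and_optimizer below. It
# builds the model, optionally wraps it with DeepSpeed (without an optimizer), loads a
# checkpoint, switches to eval mode, and can re-export the weights in Hugging Face format:
#
#   args.load = "/path/to/checkpoint_dir"          # checkpoint directory to load
#   args.export_huggingface = "/path/to/hf_out"    # optional: export after loading
#   model = setup_model(args)
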
def get_train_val_test_data(args):
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""

    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        (train_data, val_data, test_data), num_tokens, eod_token, tokenizer = make_gpt3_dataloaders(args)
        before = num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0(
            '> padded vocab (size: {}) with {} dummy tokens (new size: {})'.format(before, after - before, after))
        print_rank_0('> end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor(
            [after, eod_token, int(args.do_train), int(args.do_valid), int(args.do_test)])
    else:
        tokenizer = None
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    eod_token = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, eod_token, tokenizer
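
# Hedged sketch (illustrative numbers, not part of the original source): the padding loop
# above rounds the raw vocab size up to the nearest multiple of
# make_vocab_size_divisible_by * model_parallel_world_size. Closed-form equivalent:
def _padded_vocab_size(before, divisible_by, model_parallel_world_size):
    # e.g. before=50257, divisible_by=128, model_parallel_world_size=2
    #      -> multiple=256 -> returns 50432 (175 dummy tokens added)
    multiple = divisible_by * model_parallel_world_size
    return ((before + multiple - 1) // multiple) * multiple
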
def get_model(args):
    """Build the model."""

    print_rank_0('building GPT3 model ...')
    model = GPT3Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=False)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])),
              flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    model = DDP(model)

    return model
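
# Hedged usage sketch (field values are illustrative, not from the original source):
# get_model expects an argparse-style namespace carrying at least the fields read above.
#
#   from argparse import Namespace
#   toy_args = Namespace(num_layers=24, vocab_size=50432, hidden_size=1024,
#                        num_attention_heads=16, hidden_dropout=0.1,
#                        attention_dropout=0.1, max_position_embeddings=2048,
#                        checkpoint_activations=False, checkpoint_num_layers=1,
#                        fp16=False)
#   model = get_model(toy_args)   # requires mpu / CUDA / distributed init beforehand
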
def get_model(args):
    """Build the model."""

    print_rank_0('building GPT3 model ...')
    assert args.num_attention_heads % args.model_parallel_size == 0
    num_local_heads = args.num_attention_heads // args.model_parallel_size
    deepspeed_sparsity_config = None
    if DEEPSPEED_WRAP and args.deepspeed:
        deepspeed_sparsity_config = get_sparse_attention_config(args, num_local_heads)
    if deepspeed_sparsity_config is not None:
        print_rank_0(f"Use sparse attention with mode {args.sparse_mode}")
    model = GPT3Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True,
                      deepspeed_sparsity_config=deepspeed_sparsity_config,
                      sparse_mode=args.sparse_mode)

    if args.load_huggingface is not None:
        model = load_huggingface_model(model, args.load_huggingface, args.huggingface_double_pos_embeddings)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision
    if DEEPSPEED_WRAP and args.deepspeed and args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if USE_TORCH_DDP:
        i = torch.cuda.current_device()
        model = DDP(model, device_ids=[i], output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        model = DDP(model)

    return model
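
# Hedged worked example (illustrative numbers, not from the original source): the
# head-partitioning assert above.
#   num_attention_heads=16, model_parallel_size=4  ->  num_local_heads = 16 // 4 = 4
# Each model-parallel rank owns an equal slice of the attention heads, which is why
# num_attention_heads must divide evenly by model_parallel_size.
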
def setup_model_and_optimizer(args):
    """Setup model and optimizer."""

    model = get_model(args)
    optimizer = get_optimizer(model, args)
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if DEEPSPEED_WRAP and args.deepspeed:
        print_rank_0("DeepSpeed is enabled.")

        model, optimizer, _, lr_scheduler = DEEPSPEED_WRAP.deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            args=args,
            lr_scheduler=lr_scheduler,
            mpu=mpu,
            dist_init_required=False
        )

    if args.load is not None:
        print_rank_0("Load checkpoint from " + args.load)
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args, deepspeed=DEEPSPEED_WRAP and args.deepspeed)
        print_rank_0("Checkpoint loaded")
    else:
        args.iteration = 0

    return model, optimizer, lr_scheduler
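
# Hedged usage note (cross-reference, not from the original source): the args.iteration
# value set here is what main() later uses to resume the data loader, roughly:
#   train_data.batch_sampler.start_iter = args.iteration % len(train_data)
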
Example #6
def evaluate(data_iterator, model, args, timers, verbose=False):
    """Evaluation."""

    print (f"evaluate got passed data_iterator {data_iterator}")
    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_lm_loss = 0
    eval_len = args.eval_iters or len(data_iterator)

    with torch.no_grad():
        # stop = False
        iteration = 0
        while iteration < eval_len:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration, eval_len))
            # Forward evaluation.
            sample = next(data_iterator) if (data_iterator is not None) else None
            lm_loss = forward_step(sample, model, args, timers)

            '''When contiguous memory optimizations are enabled, the buffers
            allocated by the optimizations are deallocated during the backward pass.
            In the absence of a backward pass, the buffers should be reset after
            each forward pass.'''
            if DEEPSPEED_WRAP and args.deepspeed and args.deepspeed_activation_checkpointing:
                DEEPSPEED_WRAP.deepspeed.checkpointing.reset()

            # Reduce across processes.
            if isinstance(model, DDP):
                torch.distributed.all_reduce(lm_loss.data)
                lm_loss.data = lm_loss.data / args.world_size

            total_lm_loss += lm_loss.data.detach().float().item()

    # Move model back to the train mode.
    model.train()

    total_lm_loss /= eval_len
    return total_lm_loss
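
# Hedged sketch (assumed helper name, not in the original source): the reduction above
# turns a per-rank loss into a data-parallel mean, i.e. an all-reduce SUM divided by
# world_size. torch is assumed imported at module scope, as elsewhere in this file.
def _allreduce_mean(loss, world_size):
    # loss: scalar torch.Tensor on the current device; all_reduce modifies it in place.
    torch.distributed.all_reduce(loss)   # default op is SUM across all ranks
    return loss / world_size             # convert the summed loss into a mean
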
Example #7
def setup_model_and_optimizer(args):
    """Setup model and optimizer."""

    print ("setting up model...")
    model = get_model(args)
    print ("setting up optimizer...")
    optimizer = get_optimizer(model, args)
    print ("setting up lr scheduler...")
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    
    if DEEPSPEED_WRAP and args.deepspeed:
        print_rank_0("DeepSpeed is enabled.")

        print ("Calling deepspeed.initialize with our model, optimizer and scheduler")
        model, optimizer, _, lr_scheduler = DEEPSPEED_WRAP.deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            args=args,
            lr_scheduler=lr_scheduler,
            mpu=mpu,
            dist_init_required=False
        )
        print ("We've wrapped our model, optimizer and scheduler in DeepSpeed")

    if args.load is not None:
        print_rank_0("Load checkpoint from " + args.load)
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args, deepspeed=DEEPSPEED_WRAP and args.deepspeed)
        print_rank_0("Checkpoint loaded")
#         input ("This was all it took? Mother...")
    else:
        args.iteration = 0

    print ("returning our model, optimizer and scheduler")    
    return model, optimizer, lr_scheduler
Example #8
    def make_data_loader_(data_path, dataset_args):
        print_rank_0(
            f'Load RuGPT3 Dataset from {data_path}, {dataset_args.max_files_load} files per process'
        )
        dataset = RuGpt3TextDataset(
            tokenizer=tokenizer,
            args=dataset_args,
            rank=rank,
            world_size=world_size,
            file_path=data_path,
            # cache_prefix=args.cache_prefix
            all_args=args)
        # Use a simple sampler with distributed batch sampler.
        sampler = torch.utils.data.SequentialSampler(dataset)
        batch_sampler = ResumableBatchSampler(sampler=sampler,
                                              batch_size=args.batch_size,
                                              drop_last=True)

        return InfiniteDataLoader(dataset,
                                  batch_sampler=batch_sampler,
                                  num_workers=num_workers,
                                  pin_memory=True)
Example #9
def get_model(args):
    """Build the model."""

    print_rank_0('building GPT3 model ...')
    print ("Calling GPT3Model constructor...")  
    model = GPT3Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=False)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # GPU allocation.
    print (f"placing the model on device {torch.cuda.current_device()}")
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        rint ("we have NOT halfed the model before, and now we're wrapping it into a fp16_module. For...some reason...")
        model = FP16_Module(model)

    # Wrap model for distributed training.
    print ("Setting up distributed training...")
    print ("No classic pytorch DDP this time; \nUsing sberbank magic DDP")
    model = DDP(model)

    input ("ready to return model")
    return model
def evaluate_and_print_results(prefix, data_iterator, model,
                               args, timers, verbose=False):
    """Helper function to evaluate and dump results on screen."""
    if args.load_tag:
        prefix = 'checkpoint {}'.format(args.load_tag)
    lm_loss = evaluate(data_iterator, model, args, timers, verbose)
    lm_ppl = math.exp(min(20, lm_loss))
    string = ' validation loss at {} | '.format(prefix)
    string += 'LM loss: {:.4f} | '.format(lm_loss)
    string += 'LM PPL: {:.3f}'.format(lm_ppl)
    length = len(string) + 1
    print_rank_0('-' * length)
    print_rank_0(string)
    print_rank_0('-' * length)

    return lm_loss, lm_ppl
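
# Hedged worked example (illustrative numbers, not from the original source): the
# perplexity above is exp(loss), with the loss clamped at 20 so the reported
# perplexity stays bounded early in training.
#   lm_loss = 2.0   ->  lm_ppl = math.exp(2.0)          ~ 7.389
#   lm_loss = 35.0  ->  lm_ppl = math.exp(min(20, 35))  = math.exp(20) ~ 4.85e8 (clamped)
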
Example #11
def make_gpt3_dataloaders(args):
    # Data parallel arguments
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    # global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # data_dir = args.train_data_path if args.train_data_path else os.path.dirname(args.test_data_path)
    tokenizer_path = args.load_huggingface if args.load_huggingface else \
        (args.tokenizer_path if args.tokenizer_path else os.path.join(os.path.dirname(args.train_data_path),
                                                                      '_tokenizer/'))
    print_rank_0('Load tokenizer from ' + tokenizer_path)
    tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
    tokenizer.add_special_tokens({"bos_token": "<s>"})
    tokenizer.add_special_tokens({"eos_token": "</s>"})

    print("Add answer_sep:", args.answer_sep)
    tokenizer.add_tokens(args.answer_sep)

    print("Add start_sep", args.start_sep)
    tokenizer.add_tokens(args.start_sep)

    print("Add start_sep", args.end_sep)
    tokenizer.add_tokens(args.end_sep)

    eod_token = tokenizer.encoder['<pad>']
    num_tokens = len(tokenizer)

    train_dataset_args = RuGpt3DatasetArguments(
        block_size=args.seq_length,
        max_files_load=args.max_files_per_process,
        overwrite_cache=args.overwrite_cache,
        tqdm=False)
    eval_dataset_args = RuGpt3DatasetArguments(
        block_size=args.seq_length,
        max_files_load=args.max_files_per_process,
        overwrite_cache=args.overwrite_cache,
        tqdm=True)

    def make_data_loader_(data_path, dataset_args):
        print_rank_0(
            f'Load RuGPT3 Dataset from {data_path}, {dataset_args.max_files_load} files per process'
        )
        dataset = RuGpt3TextDataset(
            tokenizer=tokenizer,
            args=dataset_args,
            rank=rank,
            world_size=world_size,
            file_path=data_path,
            # cache_prefix=args.cache_prefix
            all_args=args)
        # Use a simple sampler with distributed batch sampler.
        sampler = torch.utils.data.SequentialSampler(dataset)
        batch_sampler = ResumableBatchSampler(sampler=sampler,
                                              batch_size=args.batch_size,
                                              drop_last=True)

        return InfiniteDataLoader(dataset,
                                  batch_sampler=batch_sampler,
                                  num_workers=num_workers,
                                  pin_memory=True)

    train = make_data_loader_(
        args.train_data_path,
        train_dataset_args) if args.train_data_path else None
    valid = make_data_loader_(
        args.val_data_path, eval_dataset_args) if args.val_data_path else None
    test = make_data_loader_(
        args.test_data_path,
        eval_dataset_args) if args.test_data_path else None

    args.do_train = train is not None
    args.do_valid = valid is not None
    args.do_test = test is not None

    return (train, valid, test), num_tokens, eod_token, tokenizer
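
# Hedged sketch (assumed tokenizer path and separator strings, not from the original
# source): each add_tokens / add_special_tokens call above can grow the vocabulary, so
# num_tokens = len(tokenizer) is taken only after all additions.
#
#   tokenizer = GPT2Tokenizer.from_pretrained("/path/to/_tokenizer/")
#   base = len(tokenizer)
#   tokenizer.add_tokens("<answer>")        # hypothetical answer_sep value
#   assert len(tokenizer) == base + 1       # holds only if "<answer>" was not present
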
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    #     if args.load_huggingface:
    #         args.make_vocab_size_divisible_by = 1

    # Pytorch distributed.
    initialize_distributed(args)
    if torch.distributed.get_rank() == 0:
        print('Pretrain GPT3 model')
        print_args(args)

    # Random seeds for reproducability.
    set_random_seed(args.seed)

    # Data stuff.
    train_data, val_data, test_data, args.vocab_size, args.eod_token, tokenizer = get_train_val_test_data(args)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)

    # Resume data loader if necessary.
    if args.resume_dataloader:
        if train_data is not None:
            train_data.batch_sampler.start_iter = args.iteration % len(train_data)
            print_rank_0(f"Resume train set from iteration {train_data.batch_sampler.start_iter}")
        if val_data is not None:
            start_iter_val = (args.train_iters // args.save_interval) * args.eval_interval
            val_data.batch_sampler.start_iter = start_iter_val % len(val_data)
    if train_data is not None:
        train_data_iterator = iter(train_data)
    else:
        train_data_iterator = None

    iteration = 0
    if args.train_iters > 0:
        if args.do_train:
            iteration, skipped = train(model, optimizer,
                                       lr_scheduler,
                                       train_data_iterator,
                                       val_data,
                                       timers,
                                       args,
                                       tokenizer)

        if args.do_valid:
            prefix = 'the end of training for val data'
            # val_loss, val_ppl
            _ = evaluate_and_print_results(prefix, iter(val_data) if val_data else None,
                                           model, args, timers, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler, args, deepspeed=DEEPSPEED_WRAP and args.deepspeed)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, iter(test_data) if test_data else None,
                                   model, args, timers, True)
def train(model, optimizer, lr_scheduler,
          train_data_iterator, val_data, timers, args, tokenizer):
    """Train the model."""

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_lm_loss = 0.0

    # Iterations.
    iteration = args.iteration
    skipped_iters = 0
    tb_writer = None
    if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
        tb_writer = SummaryWriter(log_dir=args.logging_dir)

    timers('interval time').start()
    report_memory_flag = True
    is_master = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    print('--Start training loop--')
    train_start = True
    # avg_lm_loss = 1e6
    while iteration < args.train_iters:
        timers('data loader').start()
        sample = next(train_data_iterator) if (train_data_iterator is not None) else None
        timers('data loader').stop()

        if train_start and is_master:
            batch_text = f"\n\Iteration {iteration} start sample: {tokenizer.decode(sample[0, :200])}"
            tb_writer.add_text('train_start', batch_text, iteration)

        lm_loss, skipped_iter = train_step(sample,
                                           model,
                                           optimizer,
                                           lr_scheduler,
                                           args, timers, tokenizer, iteration, tb_writer)
        skipped_iters += skipped_iter
        iteration += 1
        train_start = False

        # Update losses.
        total_lm_loss += lm_loss.data.detach().float()

        # Logging.
        if is_master and iteration % args.log_interval == 0:
            learning_rate = optimizer.param_groups[0]['lr']
            avg_lm_loss = total_lm_loss.item() / args.log_interval
            ppl = math.exp(avg_lm_loss)
            elapsed_time = timers('interval time').elapsed()
            samples = args.log_interval * mpu.get_data_parallel_world_size() * args.batch_size
            tokens = samples * args.seq_length
            log_string = ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters)
            log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time * 1000.0 / args.log_interval)
            log_string += ' learning rate {:.3E} |'.format(learning_rate)
            log_string += ' lm loss {:.4f} |'.format(avg_lm_loss)
            log_string += ' perplexity {:.4f} |'.format(ppl)
            scalars = {
                'Loss/loss': avg_lm_loss,
                'Loss/perplexity': ppl,
                'learning_rate': learning_rate,
                'Speed/iteration_time_ms': (elapsed_time * 1000.0 / args.log_interval),
                'Speed/samples_per_sec': (samples / elapsed_time),
                'Speed/tokens_per_sec': (tokens / elapsed_time),
                'Speed/tokens_per_step': (tokens / args.log_interval),
                'Speed/seen_tokens': iteration * (tokens / args.log_interval)
            }
            if args.fp16:
                lscale = optimizer.cur_scale if DEEPSPEED_WRAP and args.deepspeed else optimizer.loss_scale
                log_string += ' loss scale {:.1f} |'.format(lscale)
                scalars['lscale'] = lscale
            print_rank_0(log_string)
            for k, v in scalars.items():
                tb_writer.add_scalar(k, v, iteration)

            if ppl < 3:
                # generate only when the model is relatively good
                # Prompt (Russian): "Brazilian scientists have discovered a rare species
                # of dwarf unicorns living in western Jutland"
                prefix = 'Бразильские ученые открыли редкий вид карликовых единорогов, обитающих на западе Ютландии'
                model.eval()
                with torch.no_grad():
                    text = generate(model, tokenizer, prefix, 128)
                model.train()
                tb_writer.add_text('sample', text, iteration)

            if args.log_memory:
                log_memory_usage(tb_writer, iteration)
            total_lm_loss = 0.0
            if report_memory_flag:
                report_memory('after {} iterations'.format(iteration))
                report_memory_flag = False
            if USE_TORCH_DDP:
                timers.log(['forward', 'backward', 'optimizer', 'data loader'], normalizer=args.log_interval)
            else:
                timers.log(['forward', 'backward', 'allreduce', 'optimizer', 'data loader'],
                           normalizer=args.log_interval)
        # Checkpointing
        if args.save and args.save_interval and iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler, args, deepspeed=DEEPSPEED_WRAP and args.deepspeed)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            val_loss, val_ppl = evaluate_and_print_results(
                prefix, iter(val_data) if val_data else None, model, args, timers, False)
            if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
                scalars = {'val_loss': val_loss, 'val_perplexity': val_ppl}
                for k, v in scalars.items():
                    tb_writer.add_scalar(k, v, iteration)

        if args.exit_interval and iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print('rank: {} | time: {} | exiting the program at iteration {}'.
                  format(rank, time_str, iteration), flush=True)
            exit()

    return iteration, skipped_iters
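
# Hedged worked example (illustrative numbers, not from the original source): the
# throughput scalars logged above.
#   log_interval=100, data_parallel_world_size=8, batch_size=4, seq_length=2048
#   samples = 100 * 8 * 4     = 3,200 samples per logging window
#   tokens  = samples * 2048  = 6,553,600 tokens per window
#   Speed/samples_per_sec = samples / elapsed_time
#   Speed/tokens_per_step = tokens / log_interval = 65,536 tokens per iteration
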
Example #14
def get_model(args):
    """Build the model."""

    print_rank_0('building GPT3 model ...')
    print ("asserting we have a correct number of attention heads...")
    assert args.num_attention_heads % args.model_parallel_size == 0
    num_local_heads = args.num_attention_heads // args.model_parallel_size
    deepspeed_sparsity_config = None
    if DEEPSPEED_WRAP and args.deepspeed:
        print ("we're using deepspeed, and so we're getting a sparse attention config")
        deepspeed_sparsity_config = get_sparse_attention_config(args, num_local_heads)
    if deepspeed_sparsity_config is not None:
        print_rank_0(f"Using sparse attention with mode {args.sparse_mode}")
    print ("Calling GPT3Model constructor...")    
    model = GPT3Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True,
                      deepspeed_sparsity_config=deepspeed_sparsity_config,
                      sparse_mode=args.sparse_mode)

    if args.load_huggingface is not None:
        print ("Loading huggingface model...")
        model = load_huggingface_model(model, args.load_huggingface, args.huggingface_double_pos_embeddings)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision
    if DEEPSPEED_WRAP and args.deepspeed and args.fp16:
        print ("We've had deepspeed AND fp16, so we're halfing the model...")
        model.half()

    # GPU allocation.
    print (f"placing the model on device {torch.cuda.current_device()}")
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        print ("we've halfed the model before, but now we're wrapping it into a fp16_module. For...some reason...")
        model = FP16_Module(model)

    # Wrap model for distributed training.
    print ("Setting up distributed training...")
    if USE_TORCH_DDP:
        i = torch.cuda.current_device()
        print (f"Using classic pytorch DDP with device {i}")
        model = DDP(model, device_ids=[i], output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        print ("Using sberbank magic DDP")
        model = DDP(model)

#     input ("ready to return model")
    print ("ready to return model")
    return model