Code Example #1
def prepare_model():
    """Main training program."""

    #print('Generate Samples')

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = set_args()
    #print(args)
    args.mem_length = args.seq_length + args.mem_length - 1

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    args.seed = random.randint(0, 1000000)
    set_random_seed(args.seed)

    #get the tokenizer
    tokenizer = prepare_tokenizer(args)

    # Model, optimizer, and learning rate.
    model = setup_model(args)
    #args.load="../ckp/txl-2.8b11-20-15-10"
    #model2=setup_model(args)
    #setting default batch size to 1
    args.batch_size = 1

    # return everything needed for sample generation
    return model, tokenizer, args
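
A minimal usage sketch of the objects returned above, mirroring the generate_samples call from Code Example #2; prepare_model and generate_samples are assumed to come from the same project:

# Sketch only, assuming the helpers above are importable.
def run_sampling():
    model, tokenizer, args = prepare_model()
    generate_samples(model, tokenizer, args, torch.cuda.current_device())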
Code Example #2
def main():
    """Main training program."""

    print('Generate Samples')

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    #get the tokenizer
    tokenizer = prepare_tokenizer(args)

    # Model, optimizer, and learning rate.
    model = setup_model(args)

    #setting default batch size to 1
    args.batch_size = 1

    #generate samples
    generate_samples(model, tokenizer, args, torch.cuda.current_device())
Code Example #3
def main():
    """Main training program."""

    print('Generate Samples')

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    #get the tokenizer
    tokenizer = GPT2Tokenizer(
        os.path.join(args.tokenizer_path, 'vocab.json'),
        os.path.join(args.tokenizer_path, 'chinese_vocab.model'))

    # Model
    model = setup_model(args)

    #setting default batch size to 1
    args.batch_size = 1

    #generate samples
    generate_samples(model, tokenizer, args, torch.cuda.current_device())
Code Example #4
def main():
    """Main training program."""

    num_of_gpus = 8
    num_of_layers = 24
    hp = 1024 // num_of_gpus
    d_binglr = torch.load(
        '/relevance2-nfs/romittel/binglr_pretrained_model/pytorch_model.bin')
    emb_per_gpu = d_binglr['bert.embeddings.word_embeddings.weight'].size(
    )[0] // num_of_gpus
    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()
    with open(args.valid_data[0], 'r', encoding='utf-8') as f:
        file_len = sum(1 for _ in f)
    print("file_len= ", file_len)
    # Pytorch distributed.
    initialize_distributed(args)
    if torch.distributed.get_rank() == 0:
        print('Pretrain GPT2 model')
        print_args(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Data stuff.
    train_data, val_data, test_data, args.vocab_size, \
        args.eod_token = get_train_val_test_data(args)

    # Model, optimizer, and learning rate.

    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    #model.optimizer.dynamic_loss_scale=True
    j = 0
    if j == torch.distributed.get_rank():
        # word_embeddings
        #num_embeddings_per_partition = model.module.module.module.word_embeddings.num_embeddings_per_partition
        #embedding_dim = model.module.module.module.word_embeddings.embedding_dim
        print(model.module.module.module.input_layernorm.bias.size())
        print(d_binglr['bert.embeddings.LayerNorm.bias'].size())
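
The hp and emb_per_gpu arithmetic above splits a 1024-unit hidden size and the pretrained word-embedding rows evenly across 8 GPUs. A minimal sketch of that row-partitioning idea follows; partition_rows is a hypothetical helper, not part of the example:

import torch

def partition_rows(weight: torch.Tensor, num_partitions: int, rank: int) -> torch.Tensor:
    """Return the contiguous block of rows owned by `rank` when rows are split evenly."""
    rows_per_rank = weight.size(0) // num_partitions
    return weight.narrow(0, rank * rows_per_rank, rows_per_rank).clone()

# e.g. the emb_per_gpu embedding rows belonging to this rank:
# shard = partition_rows(d_binglr['bert.embeddings.word_embeddings.weight'], num_of_gpus, rank)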
Code Example #5
def main():
    """Main training program."""

    print('Generate Samples')

    # Disable CuDNN.
    # torch.backends.cudnn.enabled = False
    torch.backends.cudnn.enabled = True

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # get the tokenizer
    tokenizer = prepare_tokenizer(args)

    # Model, optimizer, and learning rate.
    model = setup_model(args)

    # setting default batch size to 1
    # args.batch_size = 1

    args.device = torch.cuda.current_device()

    # generate samples
    if args.num_samples == 0:
        args.batch_size = 1
        if args.sample_input_file != "":
            generate_samples_input_from_file(model, tokenizer, args)
        else:
            generate_samples_interactive(model, tokenizer, args)
    else:
        write_and_generate_samples_unconditional(model, tokenizer, args)
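
A sketch of how the flags read by this branch (num_samples and sample_input_file) might be declared, assuming an argparse-based get_args; the defaults shown are assumptions, not the project's actual values:

import argparse

def add_sampling_args(parser: argparse.ArgumentParser) -> None:
    # Assumed declarations matching the attributes read above.
    parser.add_argument('--num-samples', type=int, default=0,
                        help='0 selects file/interactive generation; >0 selects unconditional sampling')
    parser.add_argument('--sample-input-file', type=str, default='',
                        help='prompt file for generate_samples_input_from_file; empty string means interactive mode')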
Code Example #6
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()
    with open(args.valid_data[0], 'r', encoding='utf-8') as f:
        file_len = sum(1 for _ in f)
    print("file_len= ", file_len)
    # Pytorch distributed.
    initialize_distributed(args)
    if torch.distributed.get_rank() == 0:
        print('Pretrain GPT2 model')
        print_args(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Data stuff.
    train_data, val_data, test_data, args.vocab_size, \
        args.eod_token = get_train_val_test_data(args)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    #model.optimizer.dynamic_loss_scale=True

    if val_data is not None:
        val_data_iterator = iter(val_data)
    else:
        val_data_iterator = None

    #TODO: figure out how to properly set this especially when resuming training
    evaluate(val_data_iterator, model, args, timers, file_len, verbose=False)
Code Example #7
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)
    if torch.distributed.get_rank() == 0:
        print('Pretrain GPT2 model')
        print_args(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    if torch.distributed.get_rank() == 0:
        print(args.iteration)

    train_data_iterator, val_data_iterator, test_data_iterator = \
            build_train_valid_test_data_iterators(
                    train_valid_test_dataset_provider, args)

    # Resume data loader if necessary.
    # if args.resume_dataloader:
    #    if train_data is not None:
    #        train_data.batch_sampler.start_iter = args.iteration % \
    #                                              len(train_data)
    #    if val_data is not None:
    #        start_iter_val = (args.train_iters // args.save_interval) * \
    #                         args.eval_interval
    #        val_data.batch_sampler.start_iter = start_iter_val % \
    #                                            len(val_data)
    # if train_data is not None:
    #    train_data_iterator = iter(train_data)
    # else:
    #    train_data_iterator = None
    # if val_data is not None:
    #    val_data_iterator = iter(val_data)
    # else:
    #    val_data_iterator = None

    # TODO: figure out how to properly set this especially when resuming training
    iteration = 0
    if args.train_iters > 0:
        iteration, skipped = train(model, optimizer,
                                   lr_scheduler,
                                   train_data_iterator,
                                   val_data_iterator,
                                   timers, args)

        prefix = 'the end of training for val data'
        val_loss = evaluate_and_print_results(prefix, val_data_iterator,
                                                  model, args, timers, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

    # if test_data is not None:
    #    test_data_iterator = iter(test_data)
    # else:
    #    test_data_iterator = None

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, test_data_iterator,
                                   model, args, timers, True)
Code Example #8
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # get the tokenizer
    tokenizer = GPT2Tokenizer(os.path.join(args.tokenizer_path, 'vocab.json'), os.path.join(args.tokenizer_path, 'chinese_vocab.model'))

    # load train data
    if args.do_train:
        train_dataloader, _ = load_data(args, 'train', tokenizer, 1)
        dev_dataloader, dev_dataset = load_data(args, 'dev', tokenizer, 1)

        with open(args.deepspeed_config, "r") as f:
            deepspeed_conf = json.load(f)

        epoch = args.epoch
        grad_acc = deepspeed_conf["gradient_accumulation_steps"]
        args.train_iters = len(train_dataloader) * epoch // grad_acc

        # Model, optimizer, and learning rate.
        # TODO: maybe need to reinitialize optimizer
    elif args.do_eval:
        # Set an arbitrary positive integer since the optimizer and the scheduler will not be used when doing eval.
        args.train_iters = 1

    model, optimizer, lr_scheduler = setup_model_and_optimizer_C(args)
    device = torch.cuda.current_device()

    # give a time stamp to the model
    cur_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
    results_dir = os.path.join(args.results_dir, "{}-{}".format(args.model_name, cur_time))
    os.makedirs(results_dir, exist_ok=True)

    if args.do_train and torch.distributed.get_rank() == 0:

        with open(os.path.join(results_dir, "train_log.txt"), "w") as f:
            f.write("Train losses:\n")

        with open(os.path.join(results_dir, "dev_log.txt"), "w") as f:
            f.write("Dev accs:\n")

    torch.distributed.barrier()

    if args.do_train:
        # cand_ids = torch.tensor(dev_dataset.cand_ids).to(device)
        total_loss, logging_loss, best_acc = 0.0, 0.0, 0.0
        global_step, total_step, best_step = 0, 0, 0
        
        for e in range(epoch):
            model.train()
            for batch, no_model_batch in tqdm(train_dataloader, disable=(torch.distributed.get_rank() != 0)):
                for k in batch:
                    batch[k] = batch[k].to(device)
                for k in no_model_batch:
                    no_model_batch[k] = no_model_batch[k].to(device)

                output = model(**batch)
                # get the loss of the last token
                output = torch.sum(output * no_model_batch["loss_mask"].unsqueeze(-1), 1) / torch.sum(no_model_batch["loss_mask"], -1).unsqueeze(-1)
                # get the label of the last token
                # labels = no_model_batch["labels"].float()
                labels = no_model_batch["truth"].float()
                # labels = (torch.sum(labels * no_model_batch["loss_mask"], 1) / torch.sum(no_model_batch["loss_mask"], -1)).long()
                # cross_entropy loss
                # losses = mpu.vocab_parallel_cross_entropy(output.unsqueeze(1).contiguous().float(), labels.unsqueeze(1))
                # CrossEntropyLoss is a module: instantiate it, then apply it to (logits, integer class labels).
                losses = CrossEntropyLoss(reduction="none")(output.contiguous().float(), labels.long())
                loss = torch.mean(losses)

                model.backward(loss)
                model.step()

                torch.distributed.all_reduce(loss.data, group=mpu.get_data_parallel_group())
                loss.data = loss.data / mpu.get_data_parallel_world_size()
                total_loss += loss.item() / grad_acc

                if total_step % grad_acc == 0:
                    global_step += 1
                    if global_step != 0 and global_step % args.log_interval == 0:
                        # logging
                        if torch.distributed.get_rank() == 0:
                            train_log = "Epoch {}, global step {}, total step {}, train lm loss: {}".format(e, global_step, epoch * len(train_dataloader), (total_loss - logging_loss) / args.log_interval)
                            yprint(train_log)
                            with open(os.path.join(results_dir, "train_log.txt"), "a") as f:
                                f.write(train_log + "\n")

                        logging_loss = total_loss
    
                    if global_step != 0 and global_step % args.eval_interval == 0:
                        # evaluate on the dev
                        acc, _, _ = evaluate_tnews(args, model, dev_dataloader, device, mode="dev")
                        dev_results_dir = os.path.join(results_dir, "dev_step-{}".format(global_step))

                        if acc > best_acc:
                            best_acc = acc
                            best_step = global_step

                        if torch.distributed.get_rank() == 0:
                            # we will only write the log file once
                            dev_log = "Epoch: {}, Global step: {}, Acc: {}".format(e, global_step, acc)
                            yprint(dev_log)
                            os.makedirs(dev_results_dir, exist_ok=True)
                            with open(os.path.join(dev_results_dir, "dev_result.txt"), "w") as f:
                                f.write(dev_log + "\n")
                            with open(os.path.join(results_dir, "dev_log.txt"), "a") as f:
                                f.write(dev_log + "\n")

                        torch.distributed.barrier()
                        
                        args.save = dev_results_dir
                        save_checkpoint(global_step, model, optimizer, lr_scheduler, args)

                total_step += 1

        if torch.distributed.get_rank() == 0:
            with open(os.path.join(results_dir, "dev_log.txt"), "a") as f:
                f.write("Best acc: {} Best step: {}\n".format(best_acc, best_step))

    if args.do_eval:
        # evaluate on the test
        test_dataloader, test_dataset = load_data(args, 'test', tokenizer, 1)
        cand_ids = torch.tensor(test_dataset.cand_ids).to(device)

        if args.do_train:
            # if do training, then evaluate the one with the max acc on dev set.
            eval_ckpt_path = os.path.join(results_dir, "dev_step-{}".format(best_step))
            args.load = eval_ckpt_path
        else:
            # if only do eval, then evaluate the one specified by the user.
            args.load = args.eval_ckpt_path            
        
        load_checkpoint(model=model, optimizer=None, lr_scheduler=None, args=args)
        acc, _, _ = evaluate(args, model, test_dataloader, cand_ids, device, mode="test")

        if torch.distributed.get_rank() == 0:
            eval_log = "Checkpoint from {}: Acc: {}".format(args.load, acc)
            yprint(eval_log)
            with open(os.path.join(results_dir, "eval_log"), "w") as f:
                f.write(eval_log + "\n")

        torch.distributed.barrier()
Code Example #9
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)
    if torch.distributed.get_rank() == 0:
        print('Pretrain BERT model')
        print_args(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Data stuff.
    train_data, val_data, test_data, args.tokenizer_num_tokens, \
        args.tokenizer_num_type_tokens = get_train_val_test_data(args)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)

    if args.resume_dataloader:
        if train_data is not None:
            train_data.batch_sampler.start_iter = args.iteration % \
                                                  len(train_data)
        if val_data is not None:
            start_iter_val = (args.train_iters // args.save_interval) * \
                             args.eval_interval
            val_data.batch_sampler.start_iter = start_iter_val % \
                                                len(val_data)

    if train_data is not None:
        train_data_iterator = iter(train_data)
    else:
        train_data_iterator = None
    if val_data is not None:
        val_data_iterator = iter(val_data)
    else:
        val_data_iterator = None

    iteration = 0
    if args.train_iters > 0:
        if args.do_train:
            iteration, skipped = train(model, optimizer, lr_scheduler,
                                       train_data_iterator, val_data_iterator,
                                       timers, args)
        if args.do_valid:
            prefix = 'the end of training for val data'
            val_loss = evaluate_and_print_results(prefix, val_data_iterator,
                                                  model, args, timers, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

    if test_data is not None:
        test_data_iterator = iter(test_data)
    else:
        test_data_iterator = None

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, test_data_iterator, model, args,
                                   timers, True)
Code Example #10
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False
    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()
    args.mem_length = args.mem_length if args.transformer_xl else 0
    if args.load and not args.new_save_directory:
        args.experiment_name = os.path.basename(os.path.normpath(args.load))
    else:
        args.experiment_name = args.experiment_name + datetime.now().strftime(
            "%m-%d-%H-%M")
    if args.save:
        args.save = os.path.join(args.save, args.experiment_name)
    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Data stuff.
    global tokenizer
    tokenizer = prepare_tokenizer(args)
    train_data, val_data, test_data, = get_train_val_test_data(args, tokenizer)
    multi_train_data, multi_val_data = None, None
    if args.multi_task_ratio > 0.0:
        multi_train_data, multi_val_data = build_multi_task_dataset(
            args, tokenizer)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)

    if args.load is not None:
        with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"),
                      timeout=-1):
            args.iteration = load_checkpoint(model, optimizer, lr_scheduler,
                                             args)
    else:
        args.iteration = 0
    torch.distributed.barrier()
    if args.switch_linear:
        lr_scheduler.switch_linear(args)

    summary_writer = None
    if torch.distributed.get_rank() == 0:
        print('Pretrain GPT2 model')
        args.log_dir = None
        if args.train_iters > 0:
            args.log_dir = get_log_dir(base=args.summary_dir,
                                       name=args.experiment_name)
            summary_writer = get_sample_writer(log_dir=args.log_dir,
                                               iteration=args.iteration)
        print_and_save_args(args, verbose=True, log_dir=args.log_dir)

    # Resume data loader if necessary.
    if args.resume_dataloader:
        print_rank_0("Resume dataloader")
        if train_data is not None:
            train_data.batch_sampler.start_iter = args.iteration % len(
                train_data)
        if val_data is not None:
            start_iter_val = (args.iteration //
                              args.eval_interval) * args.eval_iters
            val_data.batch_sampler.start_iter = start_iter_val % len(val_data)
        if multi_train_data is not None:
            multi_train_data.batch_sampler.start_iter = int(
                args.iteration * args.multi_task_ratio) % len(multi_train_data)
        if multi_val_data is not None:
            start_iter_val = (args.iteration // args.eval_interval
                              ) * args.eval_iters * args.multi_task_ratio
            multi_val_data.batch_sampler.start_iter = start_iter_val % len(
                multi_val_data)
    if train_data is not None:
        train_data_iterator = iter(train_data)
    else:
        train_data_iterator = None
    if multi_train_data is not None:
        multi_train_iterator = iter(multi_train_data)
    else:
        multi_train_iterator = None
    if val_data is not None:
        val_data_iterator = iter(val_data)
    else:
        val_data_iterator = None
    if multi_val_data is not None:
        multi_val_iterator = iter(multi_val_data)
    else:
        multi_val_iterator = None

    # TODO: figure out how to properly set this especially when resuming training
    iteration = 0
    if args.train_iters > 0:
        if args.do_train:
            with ExitStack() as stack:

                def save_on_exit(args_, model_, optimizer_, lr_scheduler_):
                    save_checkpoint(args_.iteration, model_, optimizer_,
                                    lr_scheduler_, args_)

                # stack.callback(save_on_exit, args, model, optimizer, lr_scheduler)
                iteration, skipped = train(
                    model,
                    optimizer,
                    lr_scheduler, (train_data_iterator, multi_train_iterator),
                    (val_data_iterator, multi_val_iterator),
                    timers,
                    args,
                    summary_writer=summary_writer)

        if args.do_valid:
            prefix = 'the end of training for val data'
            val_loss = evaluate_and_print_results(
                prefix,
                val_data_iterator,
                model,
                args,
                timers,
                verbose=False,
                forward_step_func=forward_step)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

    if test_data is not None:
        test_data_iterator = iter(test_data)
    else:
        test_data_iterator = None

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, (test_data_iterator, None),
                                   model,
                                   args,
                                   timers,
                                   verbose=True,
                                   forward_step_func=forward_step)
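
To make the dataloader-resume arithmetic above concrete, a small worked example with made-up numbers (none of these values come from the project):

# Illustrative numbers only.
iteration, eval_interval, eval_iters = 12_500, 1_000, 10
start_iter_val = (iteration // eval_interval) * eval_iters  # 12 completed eval rounds -> 120 validation iterations consumed
# val_data.batch_sampler.start_iter = start_iter_val % len(val_data)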
Code Example #11
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False
    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()
    args.mem_length = args.mem_length if args.transformer_xl else 0
    if args.load:
        args.experiment_name = os.path.basename(os.path.normpath(args.load))
    else:
        args.experiment_name = args.experiment_name + datetime.now().strftime(
            "%m-%d-%H-%M")
    if args.save:
        args.save = os.path.join(args.save, args.experiment_name)
    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Data stuff.
    train_data, val_data, test_data, args.vocab_size, \
        args.eod_token = get_train_val_test_data(args)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)

    if args.load is not None:
        with FileLock("/root/checkpoint_lock", timeout=-1):
            args.iteration = load_checkpoint(model, optimizer, lr_scheduler,
                                             args)
    else:
        args.iteration = 0
    torch.distributed.barrier()

    summary_writer = None
    if torch.distributed.get_rank() == 0:
        print('Pretrain GPT2 model')
        print_args(args)
        summary_writer = get_sample_writer(base=args.summary_dir,
                                           name=args.experiment_name,
                                           iteration=args.iteration)

    # Resume data loader if necessary.
    if args.resume_dataloader:
        if train_data is not None:
            train_data.batch_sampler.start_iter = args.iteration % \
                                                  len(train_data)
        if val_data is not None:
            start_iter_val = (args.train_iters // args.save_interval) * \
                             args.eval_interval
            val_data.batch_sampler.start_iter = start_iter_val % \
                                                len(val_data)
    if train_data is not None:
        train_data_iterator = iter(train_data)
    else:
        train_data_iterator = None
    if val_data is not None:
        val_data_iterator = iter(val_data)
    else:
        val_data_iterator = None

    # TODO: figure out how to properly set this especially when resuming training
    iteration = 0
    if args.train_iters > 0:
        if args.do_train:
            with ExitStack() as stack:

                def save_on_exit(args_, model_, optimizer_, lr_scheduler_):
                    save_checkpoint(args_.iteration, model_, optimizer_,
                                    lr_scheduler_, args_)

                # stack.callback(save_on_exit, args, model, optimizer, lr_scheduler)
                iteration, skipped = train(model,
                                           optimizer,
                                           lr_scheduler,
                                           train_data_iterator,
                                           val_data_iterator,
                                           timers,
                                           args,
                                           summary_writer=summary_writer)

        if args.do_valid:
            prefix = 'the end of training for val data'
            val_loss = evaluate_and_print_results(prefix, val_data_iterator,
                                                  model, args, timers, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

    if test_data is not None:
        test_data_iterator = iter(test_data)
    else:
        test_data_iterator = None

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, test_data_iterator, model, args,
                                   timers, True)
Code Example #12
def finetune(args,
             train_valid_datasets_provider,
             model_kwargs,
             forward_step=finetune_forward_step,
             end_of_epoch_callback_provider=None):
    """Main finetune function used across all tasks."""
    global tokenizer
    timers = Timers()
    tokenizer = prepare_tokenizer(args)
    pretrain_glm.tokenizer = tokenizer
    if args.save:
        args.save = os.path.join(args.save, args.experiment_name)
    # Train and validation data loaders.
    timers('train/valid/test dataset/dataloder').start()
    train_dataloader, valid_dataloader = None, None
    train_block_dataloader, valid_block_dataloader = None, None
    if train_valid_datasets_provider is not None and args.epochs > 0:
        if mpu.get_model_parallel_rank() == 0:
            train_dataset, valid_dataset = train_valid_datasets_provider(
                args, tokenizer)
            train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
                train_dataset, valid_dataset, args)
            if args.no_validation:
                valid_dataloader = None
            train_iters = torch.cuda.LongTensor([len(train_dataloader)])
        else:
            train_iters = torch.cuda.LongTensor([0])
        torch.distributed.broadcast(train_iters,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        if mpu.get_model_parallel_rank() != 0:
            args.train_iters_per_epoch = train_iters[0].item()
            args.train_iters = args.epochs * args.train_iters_per_epoch

            train_dataloader = FakeDataloader(args.train_iters_per_epoch)
            if args.no_validation:
                valid_dataloader = None
            else:
                valid_dataloader = FakeDataloader(None)
        if args.block_lm_ratio > 0.0:
            if mpu.get_model_parallel_rank() == 0:
                train_block_dataset, valid_block_dataset = train_valid_datasets_provider(
                    args, tokenizer, pattern_text=True)
                train_block_dataloader = make_data_loader(
                    train_block_dataset,
                    tokenizer,
                    args.batch_size * mpu.get_data_parallel_world_size(),
                    args.train_iters,
                    args,
                    shuffle=True,
                    block_collate=True)
                valid_block_dataloader = make_data_loader(
                    valid_block_dataset,
                    tokenizer,
                    args.batch_size * mpu.get_data_parallel_world_size(),
                    (args.train_iters // args.eval_interval + 1) *
                    args.eval_iters,
                    args,
                    shuffle=True,
                    block_collate=True)
            else:
                train_block_dataloader = FakeDataloader(args.train_iters)
                valid_block_dataloader = FakeDataloader(None)
            train_block_dataloader, valid_block_dataloader = iter(
                train_block_dataloader), iter(valid_block_dataloader)

    timers('train/valid/test dataset/dataloder').stop()
    # Build callback function.
    timers('callback function').start()
    end_of_epoch_callback, end_of_train_callback = None, None
    if end_of_epoch_callback_provider is not None:
        if train_valid_datasets_provider is not None and args.epochs > 0 and not args.no_validation:
            end_of_epoch_callback = end_of_epoch_callback_provider(
                args, tokenizer, is_test=False)
        end_of_train_callback = end_of_epoch_callback_provider(args,
                                                               tokenizer,
                                                               is_test=True)
    timers('callback function').stop()

    # Build model, optimizer and learning rate scheduler.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
        args, **model_kwargs)
    timers('model and optimizer').stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers('pretrained checkpoint').start()
    if args.load_pretrained is not None and not args.pretrained_bert:
        task_tokens = None
        if args.continuous_prompt and args.prompt_init:
            if mpu.get_model_parallel_rank() == 0:
                dataset = train_dataloader.dataset
                processor, pvp = dataset.processor, dataset.pvp
                task_tokens = []
                for label in processor.get_labels():
                    verbalizer = pvp.verbalize(label)[0]
                    verbalizer_ids = tokenizer.EncodeAsIds(
                        verbalizer).tokenization
                    task_tokens += verbalizer_ids
                print_rank_0("Task tokens: " +
                             tokenizer.DecodeIds(task_tokens))
                num_task_tokens = len(task_tokens)
            else:
                num_task_tokens, task_tokens = 0, []
            num_task_tokens = torch.cuda.LongTensor([num_task_tokens])
            torch.distributed.broadcast(num_task_tokens,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            num_task_tokens = num_task_tokens.item()
            if num_task_tokens > 0:
                if mpu.get_model_parallel_rank() == 0:
                    task_tokens = torch.cuda.LongTensor(task_tokens)
                else:
                    task_tokens = torch.empty(
                        num_task_tokens,
                        device=torch.cuda.current_device(),
                        dtype=torch.long)
                torch.distributed.broadcast(
                    task_tokens,
                    mpu.get_model_parallel_src_rank(),
                    group=mpu.get_model_parallel_group())
                task_tokens = task_tokens.tolist()
        with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"),
                      timeout=-1):
            load_pretrained(model,
                            args.load_pretrained,
                            args,
                            task_tokens=task_tokens)
        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        if args.fp16 and optimizer is not None:
            if args.deepspeed:
                optimizer.refresh_fp32_params()
            else:
                optimizer._model_params_to_master_params()
    if args.load is not None:
        with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"),
                      timeout=-1):
            load_checkpoint(model,
                            optimizer,
                            lr_scheduler,
                            args,
                            no_deepspeed=args.no_deepspeed_load)
        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        if args.fp16 and optimizer is not None:
            if args.deepspeed:
                optimizer.refresh_fp32_params()
            else:
                optimizer._model_params_to_master_params()
    torch.distributed.barrier()
    timers('pretrained checkpoint').stop()
    args.iteration = 0
    summary_writer = None
    if torch.distributed.get_rank() == 0:
        args.log_dir = get_log_dir(base=args.summary_dir,
                                   name=args.experiment_name)
        if os.path.exists(os.path.join(args.log_dir, "test_results.json")
                          ) and args.load is None and not args.overwrite:
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.log_dir))
        summary_writer = get_sample_writer(log_dir=args.log_dir,
                                           iteration=args.iteration)
        print_and_save_args(args, verbose=True, log_dir=args.log_dir)

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log([
        'train/valid/test dataset/dataloder', 'callback function',
        'model and optimizer', 'pretrained checkpoint'
    ])
    print_rank_0('training ...')

    # Finetune the model.
    score_dict = None
    if train_dataloader is not None and args.epochs > 0:
        if args.block_lm_ratio > 0.0:
            forward_step = mix_forward_step
        best_iteration = _train(model,
                                optimizer,
                                lr_scheduler,
                                forward_step,
                                (train_dataloader, train_block_dataloader),
                                (valid_dataloader, valid_block_dataloader),
                                end_of_epoch_callback,
                                args,
                                timers,
                                summary_writer=summary_writer)
        if end_of_train_callback is not None and best_iteration is not None:
            with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"),
                          timeout=-1):
                args.load = os.path.join(args.save, "best")
                load_checkpoint(model,
                                optimizer,
                                lr_scheduler,
                                args,
                                no_load_optim=True,
                                no_deepspeed=True)
                args.load = None
        torch.distributed.barrier()
        if end_of_train_callback is not None:
            score_dict = end_of_train_callback(model,
                                               epoch=-1,
                                               output_predictions=True)
    # Or just evaluate.
    else:
        if end_of_train_callback is not None:
            print_rank_0('evaluation only mode, setting epoch to -1')
            score_dict = end_of_train_callback(model,
                                               epoch=-1,
                                               output_predictions=True)
    if score_dict is not None and torch.distributed.get_rank() == 0:
        score_dict.update({"type": "test"})
        with open(os.path.join(args.log_dir, "test_results.json"),
                  "w") as output:
            output.write(json.dumps(score_dict) + "\n")

    print_rank_0('done :-)')
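
A minimal sketch of how a task script might call this finetune entry point; the provider functions below are hypothetical placeholders, not functions from the project:

# Sketch only: hypothetical task wiring for finetune().
def my_train_valid_datasets_provider(args, tokenizer):
    raise NotImplementedError  # should return (train_dataset, valid_dataset) for the task

def my_metrics_callback_provider(args, tokenizer, is_test=False):
    raise NotImplementedError  # should return a callable(model, epoch, output_predictions=...) -> score dict

if __name__ == "__main__":
    args = get_args()  # assumed to populate experiment_name, save, epochs, etc.
    finetune(args,
             my_train_valid_datasets_provider,
             model_kwargs={},
             end_of_epoch_callback_provider=my_metrics_callback_provider)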
Code Example #13
File: pretrain_bert.py  Project: qhduan/gpt-lm
def main():
    """Main training program."""

    print('Pretrain BERT model')

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False
    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    set_random_seed(args.seed)
    print(args)
    # Data stuff.
    data_config = configure_data()
    data_config.set_defaults(data_set_type='BERT', transpose=False)
    (train_data, val_data), tokenizer = data_config.apply(args)

    args.train_iters = len(train_data)
    evaluate.best_val_loss = float("inf")

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler, criterion = setup_model_and_optimizer(
        args, tokenizer)
    # evaluate(val_data, model, tokenizer, criterion, args)
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        total_iters = 0
        skipped_iters = 0
        start_epoch = 1
        best_val_loss = float('inf')
        # Resume data loader if necessary.
        if args.resume_dataloader:
            start_epoch = args.epoch
            total_iters = args.total_iters
        # For all epochs.
        for epoch in range(start_epoch, args.epochs + 1):
            timers = Timers()
            # if args.shuffle:
            #     train_data.batch_sampler.sampler.set_epoch(epoch + args.seed)
            timers('epoch time').start()
            iteration, skipped = train_epoch(epoch, model, tokenizer,
                                             optimizer, train_data, val_data,
                                             lr_scheduler, criterion, timers,
                                             args)
            elapsed_time = timers('epoch time').elapsed()
            total_iters += iteration
            skipped_iters += skipped
            lm_loss, nsp_loss = evaluate(val_data, model, tokenizer, criterion,
                                         args)
            val_loss = lm_loss + nsp_loss
            print('-' * 100)
            print(
                '| end of epoch {:3d} | time: {:.3f}s | valid loss {:.3f} | '
                'valid LM Loss {:.3f} | valid LM PPL {:.3f} | valid NSP Loss {:.3f}'
                .format(epoch, elapsed_time, val_loss, lm_loss,
                        math.exp(lm_loss), nsp_loss))
            print('-' * 100)
            if val_loss < evaluate.best_val_loss:
                evaluate.best_val_loss = val_loss
                if args.save:
                    best_path = 'checkpoints-best.pt'
                    print('saving best model to:',
                          os.path.join(args.save, best_path))
                    save_checkpoint(best_path, epoch + 1, 0, model, optimizer,
                                    lr_scheduler, args)
    except KeyboardInterrupt:
        print('-' * 100)
        print('Exiting from training early')
        if args.save:
            cur_path = 'checkpoints-last.pt'
            print('saving current model to:',
                  os.path.join(args.save, cur_path))
            save_checkpoint(cur_path, epoch, args.cur_iteration, model,
                            optimizer, lr_scheduler, args)
        exit()
Code Example #14
def main():
    """Main training program."""

    global global_example_count, global_token_count, event_writer, logdir, train_step, train_loss, best_val_loss, eval_start_time, log_start_time, epoch

    global_token_count = 0

    # Arguments.
    args = get_args()

    # global global_example_count, global_token_count, event_writer, logdir
    logdir = f'{args.logdir}'
    os.makedirs(logdir, exist_ok=True)

    event_writer = SummaryWriter(logdir)
    log_tb("first", time.time())
    print('Pretrain BERT model')

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Data stuff.
    data_config = configure_data()
    data_config.set_defaults(data_set_type='BERT', transpose=False)
    (train_data, val_data, test_data), tokenizer = data_config.apply(args)
    args.data_size = tokenizer.num_tokens

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler, criterion = setup_model_and_optimizer(
        args, tokenizer)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        total_iters = 0
        skipped_iters = 0
        start_epoch = 1
        best_val_loss = float('inf')
        # Resume data loader if necessary.
        if args.resume_dataloader:
            start_epoch = args.epoch
            total_iters = args.total_iters
            train_data.batch_sampler.start_iter = total_iters % len(train_data)
        # For all epochs.
        for epoch in range(start_epoch, args.epochs + 1):
            timers('epoch time').start()
            iteration, skipped = train_epoch(epoch, model, optimizer,
                                             train_data, lr_scheduler,
                                             criterion, timers, args)
            elapsed_time = timers('epoch time').elapsed()
            total_iters += iteration
            skipped_iters += skipped
            lm_loss, nsp_loss = evaluate(val_data, model, criterion, args)
            val_loss = lm_loss + nsp_loss
            print('-' * 100)
            print(
                '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:.4E} | '
                'valid LM Loss {:.4E} | valid NSP Loss {:.4E}'.format(
                    epoch, elapsed_time, val_loss, lm_loss, nsp_loss))
            print('-' * 100)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                if args.save:
                    best_path = 'best/model.pt'
                    print('saving best model to:',
                          os.path.join(args.save, best_path))
                    save_checkpoint(best_path, epoch + 1, total_iters, model,
                                    optimizer, lr_scheduler, args)

    except KeyboardInterrupt:
        print('-' * 100)
        print('Exiting from training early')
        if args.save:
            cur_path = 'current/model.pt'
            print('saving current model to:',
                  os.path.join(args.save, cur_path))
            save_checkpoint(cur_path, epoch, total_iters, model, optimizer,
                            lr_scheduler, args)
        exit()

    if args.save:
        final_path = 'final/model.pt'
        print('saving final model to:', os.path.join(args.save, final_path))
        save_checkpoint(final_path, args.epochs, total_iters, model, optimizer,
                        lr_scheduler, args)

    if test_data is not None:
        # Run on test data.
        print('entering test')
        lm_loss, nsp_loss = evaluate(test_data, model, criterion, args)
        test_loss = lm_loss + nsp_loss
        print('=' * 100)
        print('| End of training | test loss {:5.4f} | valid LM Loss {:.4E} |'
              ' valid NSP Loss {:.4E}'.format(test_loss, lm_loss, nsp_loss))
        print('=' * 100)
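
The log_tb helper used above is not shown in the excerpt; a plausible minimal version, assuming it only forwards scalars to the module-level SummaryWriter, would be:

def log_tb(tag, val):
    """Assumed sketch of the helper used above: write a scalar at the current token count."""
    event_writer.add_scalar(tag, val, global_token_count)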
Code Example #15
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)
    if torch.distributed.get_rank() == 0:
        print('Pretrain GPT2 model')
        print_args(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # prepare log file
    os.makedirs(args.save, exist_ok=True)
    with open(args.log_file, "w") as f:
        f.write("Logging:\n")

    # Model, optimizer, and learning rate.
    with open(args.student_config_path, "r") as f:
        student_config = json.load(f)

    student_model, optimizer, lr_scheduler, student_iteration = setup_model_and_optimizer(
        args,
        student_config,
        need_optim=True,
        ckpt_path=args.student_load,
        do_fp16=args.fp16)

    args.iteration = student_iteration

    teacher_model = None
    if args.teacher_config_path is not None:
        with open(args.teacher_config_path, "r") as f:
            teacher_config = json.load(f)
        teacher_model, _, _, _ = setup_model_and_optimizer(
            args,
            teacher_config,
            need_optim=True,
            ckpt_path=args.teacher_load,
            do_fp16=(args.fp16 or args.teacher_fp16))

    if torch.distributed.get_rank() == 0:
        print(student_iteration)

    train_data_iterator, val_data_iterator, test_data_iterator = \
        build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider, args)

    iteration = 0
    if args.do_train:
        iteration, skipped = train(student_model, teacher_model, optimizer,
                                   lr_scheduler, train_data_iterator,
                                   val_data_iterator, timers, args)

        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, val_data_iterator, student_model,
                                   teacher_model, args, timers, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, student_model, optimizer, lr_scheduler,
                        args)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, test_data_iterator, student_model,
                                   teacher_model, args, timers, True)
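
The excerpt passes both a student and a teacher model to train without showing the loss; a generic sketch of the kind of distillation objective such a setup typically uses (purely illustrative, not the project's actual loss):

import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, temperature=1.0):
    # Soft-target KL divergence between teacher and student distributions.
    s = F.log_softmax(student_logits / temperature, dim=-1)
    t = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(s, t, reduction="batchmean") * temperature ** 2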
Code Example #16
File: finetune_glm.py  Project: puraminy/GLM
def finetune(args,
             train_valid_datasets_provider,
             model_kwargs,
             forward_step=finetune_forward_step,
             end_of_epoch_callback_provider=None):
    """Main finetune function used across all tasks."""
    global tokenizer
    timers = Timers()
    tokenizer = prepare_tokenizer(args)
    if args.save:
        args.save = os.path.join(args.save, args.experiment_name)
    # Train and validation data loaders.
    timers('train/valid/test dataset/dataloder').start()
    train_dataloader, valid_dataloader = None, None
    if train_valid_datasets_provider is not None and args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider(
            args, tokenizer)
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset, args)
    timers('train/valid/test dataset/dataloder').stop()
    # Build callback function.
    timers('callback function').start()
    end_of_epoch_callback, end_of_train_callback = None, None
    if end_of_epoch_callback_provider is not None:
        if train_valid_datasets_provider is not None and args.epochs > 0:
            end_of_epoch_callback = end_of_epoch_callback_provider(
                args, tokenizer, is_test=False)
        end_of_train_callback = end_of_epoch_callback_provider(args,
                                                               tokenizer,
                                                               is_test=True)
    timers('callback function').stop()

    # Build model, optimizer and learning rate scheduler.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
        args, **model_kwargs)
    timers('model and optimizer').stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers('pretrained checkpoint').start()
    if args.load_pretrained is not None and not args.pretrained_bert and not args.load:
        module = model
        if isinstance(module, (LocalDDP, TorchDDP)):
            module = module.module
        if isinstance(module, FP16_Module):
            module = module.module
        if not isinstance(module, GLMModel):
            module = module.model
        args.load = args.load_pretrained
        load_checkpoint(module, optimizer, lr_scheduler, args)
        args.load = None
        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        if args.fp16:
            optimizer._model_params_to_master_params()
    if args.load is not None:
        load_checkpoint(model, optimizer, lr_scheduler, args)
        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        if args.fp16:
            optimizer._model_params_to_master_params()
    timers('pretrained checkpoint').stop()
    args.iteration = 0
    summary_writer = None
    if torch.distributed.get_rank() == 0:
        args.log_dir = get_log_dir(base=args.summary_dir,
                                   name=args.experiment_name)
        if os.path.exists(os.path.join(args.log_dir, "test_results.json")
                          ) and args.load is None and not args.overwrite:
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.log_dir))
        summary_writer = get_sample_writer(log_dir=args.log_dir,
                                           iteration=args.iteration)
        print_and_save_args(args, verbose=False, log_dir=args.log_dir)

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log([
        'train/valid/test dataset/dataloder', 'callback function',
        'model and optimizer', 'pretrained checkpoint'
    ])
    print_rank_0('training ...')

    # Finetune the model.
    score_dict = None
    if train_dataloader is not None and args.epochs > 0:
        best_iteration = _train(model,
                                optimizer,
                                lr_scheduler,
                                forward_step,
                                train_dataloader,
                                valid_dataloader,
                                end_of_epoch_callback,
                                args,
                                timers,
                                summary_writer=summary_writer)
        if best_iteration is not None and end_of_train_callback is not None:
            args.load = os.path.join(args.save, "best")
            load_checkpoint(model, optimizer, lr_scheduler, args)
            args.load = None
        if end_of_train_callback is not None:
            score_dict = end_of_train_callback(model,
                                               epoch=-1,
                                               output_predictions=True)
    # Or just evaluate.
    else:
        if end_of_train_callback is not None:
            print_rank_0('evaluation only mode, setting epoch to -1')
            score_dict = end_of_train_callback(model,
                                               epoch=-1,
                                               output_predictions=True)
    if score_dict is not None and torch.distributed.get_rank() == 0:
        score_dict.update({"type": "test"})
        with open(os.path.join(args.log_dir, "test_results.json"),
                  "w") as output:
            output.write(json.dumps(score_dict) + "\n")

    print_rank_0('done :-)')
Code Example #17
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)
    if torch.distributed.get_rank() == 0:
        print('Pretrain ruGPT3Large model')
        print_args(args)

    # TensorBoard writer (created after torch.distributed is initialized so the
    # rank check below is meaningful).
    writer = None
    if args.tensorboard_dir and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
        try:
            from torch.utils.tensorboard import SummaryWriter
            writer = SummaryWriter(log_dir=args.tensorboard_dir)
        except ModuleNotFoundError:
            print_rank_0('WARNING: TensorBoard writing requested but is not '
                         'available (are you using PyTorch 1.1.0 or later?), '
                         'no TensorBoard logs will be written.')
            writer = None

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Data stuff.
    train_data, val_data, test_data, args.vocab_size, \
    args.eod_token = get_train_val_test_data(args)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)

    # Resume data loader if necessary.
    if args.resume_dataloader:
        if train_data is not None:
            train_data.batch_sampler.start_iter = args.iteration % \
                                                  len(train_data)
        if val_data is not None:
            start_iter_val = (args.train_iters // args.save_interval) * \
                             args.eval_interval
            val_data.batch_sampler.start_iter = start_iter_val % \
                                                len(val_data)
    if train_data is not None:
        train_data_iterator = iter(train_data)
    else:
        train_data_iterator = None
    if val_data is not None:
        val_data_iterator = iter(val_data)
    else:
        val_data_iterator = None

    # TODO: figure out how to properly set this especially when resuming training
    iteration = 0
    if args.train_iters > 0:
        if args.do_train:
            iteration, skipped = train(model, optimizer,
                                       lr_scheduler,
                                       train_data_iterator,
                                       val_data_iterator,
                                       timers, args, writer=writer)

        if args.do_valid:
            prefix = 'the end of training for val data'
            val_loss = evaluate_and_print_results(prefix, val_data_iterator,
                                                  model, args, timers, False, writer=writer)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

    if test_data is not None:
        test_data_iterator = iter(test_data)
    else:
        test_data_iterator = None

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, test_data_iterator,
                                   model, args, timers, True, writer=writer)
Code Example #18
def main():
    """Main training program."""

    print('Evaluate GPT2 model')

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Data stuff.
    eval_data = get_eval_data(args)

    # Model, optimizer, and learning rate.
    if args.eval_hf:
        from pytorch_pretrained_bert import GPT2LMHeadModel
        from pytorch_pretrained_bert import GPT2Model as HFGPT2Model
        if args.num_layers == 24:
            model_path = args.load
            #model_path = '/home/universal-lm-data.cosmos549/repos/gpt2_mp/models/345M'
            hfmodel = HFGPT2Model.from_pretrained(model_path, cache_dir='gpt2_weights', from_tf=True).cuda()
            model = GPT2LMHeadModel(hfmodel.config)
            model.transformer.load_state_dict(hfmodel.state_dict())
            model.cuda()
        else:
            model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='gpt2_weights').cuda()
    else:
        if args.load_openai:
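            # Build the Megatron model with fresh weights (args.load is cleared so no
            # checkpoint is read), then copy the OpenAI GPT-2 weights into it below.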
            from utils import move_weights
            model_path = args.load
            args.load = None
            model = setup_model(args)
            from pytorch_pretrained_bert import GPT2LMHeadModel
            from pytorch_pretrained_bert import GPT2Model as HFGPT2Model

            model_path = 'gpt2'
            from_tf = False
            print('loading openai weights')
            model.cpu()
            if args.num_layers == 24:
                #model_path = '/home/universal-lm-data.cosmos549/repos/gpt2_mp/models/345M'
                hfmodel = HFGPT2Model.from_pretrained(model_path, cache_dir='gpt2_weights', from_tf=True)
                gpt2model = GPT2LMHeadModel(hfmodel.config)
                gpt2model.transformer.load_state_dict(hfmodel.state_dict())
            else:
                gpt2model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='gpt2_weights')
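            # Unwrap the DDP / FP16 containers so the weights are copied into the bare Megatron module.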
            model2fill = model
            while isinstance(model2fill, (DDP, FP16_Module)):
                model2fill = model2fill.module
            move_weights(model2fill, gpt2model)
            model.cuda()
        else:
            model = setup_model(args)

    # Run on the evaluation data.
    prefix = "wiki" #os.path.basename(args.valid_data)
    evaluate_and_print_results(prefix, eval_data,
                               model, args, timers)
コード例 #19
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # get the tokenizer
    tokenizer = GPT2Tokenizer(
        os.path.join(args.tokenizer_path, 'vocab.json'),
        os.path.join(args.tokenizer_path, 'chinese_vocab.model'))

    # load data
    test_dataloader, test_dataset = load_data(args, 'test', tokenizer, 1)
    # Set an arbitrary positive integer, since the optimizer and scheduler are not used during evaluation.
    args.train_iters = 1

    # Model
    model, _, _ = setup_model_and_optimizer(args)

    device = torch.cuda.current_device()

    # Give a timestamp to the results directory.
    cur_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
    results_dir = os.path.join(args.results_dir,
                               "{}-{}".format(args.model_name, cur_time))

    if torch.distributed.get_rank() == 0:
        os.makedirs(results_dir, exist_ok=True)

    model.eval()
    all_sids = []
    all_cids = []
    all_losses = []
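    # For every (sentence, candidate) pair, compute the candidate's average per-token
    # loss; the lowest-loss candidate becomes the zero-shot prediction.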
    with torch.no_grad():
        for batch, no_model_batch in tqdm(
                test_dataloader,
                desc="Evaluating",
                disable=(torch.distributed.get_rank() != 0)):
            for k in batch:
                batch[k] = batch[k].to(device)
            for k in no_model_batch:
                no_model_batch[k] = no_model_batch[k].to(device)

            output = model(**batch)
            losses = mpu.vocab_parallel_cross_entropy(
                output.contiguous().float(), no_model_batch["labels"])
            loss_mask = no_model_batch["loss_mask"]
            loss = torch.sum(losses * loss_mask,
                             dim=-1) / loss_mask.sum(dim=-1)

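            # Gather the per-example losses from every data-parallel rank so rank 0
            # can later score all candidates.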
            loss_tensor_list = [
                torch.zeros_like(loss).to(device)
                for _ in range(mpu.get_data_parallel_world_size())
            ]
            torch.distributed.all_gather(loss_tensor_list,
                                         loss.data,
                                         group=mpu.get_data_parallel_group())
            all_losses.extend(loss_tensor_list)

            sids = no_model_batch["sids"]
            sid_tensor_list = [
                torch.zeros_like(sids)
                for _ in range(mpu.get_data_parallel_world_size())
            ]
            torch.distributed.all_gather(sid_tensor_list,
                                         sids.data,
                                         group=mpu.get_data_parallel_group())
            all_sids.extend(sid_tensor_list)

            cids = no_model_batch["cids"]
            cid_tensor_list = [
                torch.zeros_like(cids)
                for _ in range(mpu.get_data_parallel_world_size())
            ]
            torch.distributed.all_gather(cid_tensor_list,
                                         cids.data,
                                         group=mpu.get_data_parallel_group())
            all_cids.extend(cid_tensor_list)

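    # Only rank 0 flattens the gathered tensors and reports zero-shot accuracy.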
    if torch.distributed.get_rank() == 0:
        all_losses = torch.stack(all_losses).view(-1).cpu().detach().numpy()
        all_sids = torch.stack(all_sids).view(-1).cpu().detach().numpy()
        all_cids = torch.stack(all_cids).view(-1).cpu().detach().numpy()

        truth_labels = test_dataset.truth_labels
        preds = [[] for _ in truth_labels]

        for sid, cid, loss in zip(all_sids, all_cids, all_losses):
            preds[sid].append((cid, loss))

        preds = [min(p, key=lambda x: x[1])[0] for p in preds if len(p) > 0]

        yprint("Acc: {}".format(
            sum([int(p == l)
                 for p, l in zip(preds, truth_labels)]) / len(truth_labels)))
        with open(os.path.join(results_dir, "zero-shot_result.txt"), "w") as f:
            f.write("Acc: {}\n".format(
                sum([int(p == l) for p, l in zip(preds, truth_labels)]) /
                len(truth_labels)))

    torch.distributed.barrier()
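
A quick way to see the prediction rule above in isolation: for every sentence, the candidate whose continuation has the lowest average loss wins. The following is a minimal, self-contained sketch with made-up losses and labels, not part of the original script:

# Gathered (sentence id, candidate id, mean loss) triples, as produced above.
scored = [
    (0, 0, 2.31), (0, 1, 1.07), (0, 2, 3.58),  # sentence 0: candidate 1 has the lowest loss
    (1, 0, 0.92), (1, 1, 1.44),                # sentence 1: candidate 0 has the lowest loss
]
truth_labels = [1, 0]

preds = [[] for _ in truth_labels]
for sid, cid, loss in scored:
    preds[sid].append((cid, loss))

# For each sentence, predict the candidate with the minimal loss.
preds = [min(p, key=lambda x: x[1])[0] for p in preds if len(p) > 0]

acc = sum(int(p == l) for p, l in zip(preds, truth_labels)) / len(truth_labels)
print("Acc: {}".format(acc))  # -> Acc: 1.0
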
コード例 #20
def main():
    """Main training program."""

    num_of_gpus = 8
    num_of_layers = 24
    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()
    file_len = 0
    with open(args.valid_data[0], 'r', encoding='utf-8') as f:
        for _ in f:
            file_len += 1
    print("file_len= ", file_len)
    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Data stuff.
    train_data, val_data, test_data, args.vocab_size, \
        args.eod_token = get_train_val_test_data(args)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer2(args)
    args2 = copy.deepcopy(args)
    args2.load = "/relevance2-nfs/romittel/DeepSpeedExamples-amawa-moe/Megatron-LM-base-iterator/checkpoints_mlm/"
    if torch.distributed.get_rank() == 0:
        print('Pretrain GPT2 model')
        print_args(args)
        print_args(args2)
    if torch.distributed.get_rank() == 0:
        print("args.load=", args.load)
        print("args2.load=", args2.load)
    model2, optimizer2, lr_scheduler2 = setup_model_and_optimizer(args2)
    #model.optimizer.dynamic_loss_scale=True
    j = torch.distributed.get_rank()
    # word_embeddings
    model.module.module.word_embeddings.weight.data.copy_(
        model2.module.module.module.word_embeddings.weight.data)

    # position_embeddings
    model.module.module.token_type_embeddings.weight.data.copy_(
        model2.module.module.module.token_type_embeddings.weight.data)
    model.module.module.position_embeddings.weight.data.copy_(
        model2.module.module.module.position_embeddings.weight.data)

    # input_layernorm
    model.module.module.input_layernorm.weight.data.copy_(
        model2.module.module.module.input_layernorm.weight.data)
    model.module.module.input_layernorm.bias.data.copy_(
        model2.module.module.module.input_layernorm.bias.data)
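    # Copy every transformer layer. Odd-numbered layers use a dense MLP; even-numbered
    # layers use a DeepSpeed MoE MLP whose gate and 32 expert FFNs are copied in the
    # else branch below.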
    for i in range(num_of_layers):
        # attention.query_key_value.bias
        model.module.module.transformer.layers[
            i].attention.query_key_value.weight.data.copy_(
                model2.module.module.module.transformer.layers[i].attention.
                query_key_value.weight.data)

        model.module.module.transformer.layers[
            i].attention.query_key_value.bias.data.copy_(
                model2.module.module.module.transformer.layers[i].attention.
                query_key_value.bias.data)

        # self_output.dense
        model.module.module.transformer.layers[
            i].self_output.dense.weight.data.copy_(
                model2.module.module.module.transformer.layers[i].self_output.
                dense.weight.data)
        model.module.module.transformer.layers[
            i].self_output.dense.bias.data.copy_(
                model2.module.module.module.transformer.layers[i].self_output.
                dense.bias.data)

        #self_output.layernorm
        model.module.module.transformer.layers[
            i].self_output.layernorm.weight.data.copy_(
                model2.module.module.module.transformer.layers[i].self_output.
                layernorm.weight.data)
        model.module.module.transformer.layers[
            i].self_output.layernorm.bias.data.copy_(
                model2.module.module.module.transformer.layers[i].self_output.
                layernorm.bias.data)

        #layernorm
        model.module.module.transformer.layers[i].layernorm.weight.data.copy_(
            model2.module.module.module.transformer.layers[i].layernorm.weight.
            data)
        model.module.module.transformer.layers[i].layernorm.bias.data.copy_(
            model2.module.module.module.transformer.layers[i].layernorm.bias.
            data)

        # mlp
        if i % 2 == 1:
            model.module.module.transformer.layers[
                i].mlp.dense_h_to_4h.weight.data.copy_(
                    model2.module.module.module.transformer.layers[i].mlp.
                    dense_h_to_4h.weight.data)
            model.module.module.transformer.layers[
                i].mlp.dense_h_to_4h.bias.data.copy_(
                    model2.module.module.module.transformer.layers[i].mlp.
                    dense_h_to_4h.bias.data)

            model.module.module.transformer.layers[
                i].mlp.dense_4h_to_h.weight.data.copy_(
                    model2.module.module.module.transformer.layers[i].mlp.
                    dense_4h_to_h.weight.data)
            model.module.module.transformer.layers[
                i].mlp.dense_4h_to_h.bias.data.copy_(
                    model2.module.module.module.transformer.layers[i].mlp.
                    dense_4h_to_h.bias.data)
        else:
            model.module.module.transformer.layers[
                i].mlp.deepspeed_moe.gate.wg.weight.data.copy_(
                    model2.module.module.module.transformer.layers[i].mlp.
                    deepspeed_moe.gate.wg.weight.data)
            model.module.module.transformer.layers[
                i].mlp.deepspeed_moe.gate.wg.bias.data.copy_(
                    model2.module.module.module.transformer.layers[i].mlp.
                    deepspeed_moe.gate.wg.bias.data)
            for k in range(32):
                model.module.module.transformer.layers[
                    i].mlp.deepspeed_moe.experts.deepspeed_experts[
                        k].dense_h_to_4h.weight.data.copy_(
                            model2.module.module.module.transformer.layers[i].
                            mlp.deepspeed_moe.experts.deepspeed_experts[k].
                            dense_h_to_4h.weight.data)
                model.module.module.transformer.layers[
                    i].mlp.deepspeed_moe.experts.deepspeed_experts[
                        k].dense_h_to_4h.bias.data.copy_(
                            model2.module.module.module.transformer.layers[i].
                            mlp.deepspeed_moe.experts.deepspeed_experts[k].
                            dense_h_to_4h.bias.data)

                model.module.module.transformer.layers[
                    i].mlp.deepspeed_moe.experts.deepspeed_experts[
                        k].dense_4h_to_h.weight.data.copy_(
                            model2.module.module.module.transformer.layers[i].
                            mlp.deepspeed_moe.experts.deepspeed_experts[k].
                            dense_4h_to_h.weight.data)
                model.module.module.transformer.layers[
                    i].mlp.deepspeed_moe.experts.deepspeed_experts[
                        k].dense_4h_to_h.bias.data.copy_(
                            model2.module.module.module.transformer.layers[i].
                            mlp.deepspeed_moe.experts.deepspeed_experts[k].
                            dense_4h_to_h.bias.data)

    if args.deepspeed:
        print_rank_0("DeepSpeed is enabled.")
        model, optimizer, _, lr_scheduler = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            args=args,
            lr_scheduler=lr_scheduler,
            mpu=mpu,
            dist_init_required=False)
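        # Sanity check only: 'fp32_groups' holds the FP16 optimizer's fp32 master copies
        # and can be very large when printed.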
        print("Optimizer's state_dict:")
        print(optimizer.state_dict()['fp32_groups'])
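    # Save the merged weights as a regular checkpoint; the iteration value is an arbitrary placeholder.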
    iteration = 100
    save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
コード例 #21
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    writer = None
    if args.tensorboard_dir and args.rank == 0:
        try:
            from torch.utils.tensorboard import SummaryWriter
            writer = SummaryWriter(log_dir=args.tensorboard_dir)
        except ModuleNotFoundError:
            print_rank_0('WARNING: TensorBoard writing requested but is not '
                         'available (are you using PyTorch 1.1.0 or later?), '
                         'no TensorBoard logs will be written.')
            writer = None

    # Pytorch distributed.
    initialize_distributed(args)
    if torch.distributed.get_rank() == 0:
        print('Pretrain BERT model')
        print_args(args, writer)

    # Autoresume.
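    # ADLR autoresume lets the cluster scheduler checkpoint and requeue long runs
    # automatically (only active when args.adlr_autoresume is set).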
    torch.distributed.barrier()
    if args.adlr_autoresume:
        enable_adlr_autoresume(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Data stuff.
    train_data, val_data, test_data, args.tokenizer_num_tokens, \
        args.tokenizer_num_type_tokens = get_train_val_test_data(args)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)

    if args.resume_dataloader:
        if train_data is not None:
            train_data.batch_sampler.start_iter = args.iteration % \
                                                  len(train_data)
            print_rank_0('setting training data start iteration to {}'.format(
                train_data.batch_sampler.start_iter))
        if val_data is not None:
            start_iter_val = (args.iteration // args.eval_interval) * \
                             args.eval_iters
            val_data.batch_sampler.start_iter = start_iter_val % \
                                                len(val_data)
            print_rank_0(
                'setting validation data start iteration to {}'.format(
                    val_data.batch_sampler.start_iter))

    if train_data is not None:
        train_data_iterator = iter(train_data)
    else:
        train_data_iterator = None
    if val_data is not None:
        val_data_iterator = iter(val_data)
    else:
        val_data_iterator = None

    iteration = 0
    if args.train_iters > 0:
        if args.do_train:
            iteration, skipped = train(model, optimizer, lr_scheduler,
                                       train_data_iterator, val_data_iterator,
                                       timers, args, writer)
        if args.do_valid:
            prefix = 'the end of training for val data'
            val_loss = evaluate_and_print_results(prefix, val_data_iterator,
                                                  model, args, writer,
                                                  iteration, timers, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

    if test_data is not None:
        test_data_iterator = iter(test_data)
    else:
        test_data_iterator = None

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, test_data_iterator, model, args,
                                   None, 0, timers, True)
コード例 #22
def main():
    """Main training program."""

    num_of_gpus = 8
    num_of_layers = 24
    hp = 1024 // num_of_gpus
    d_binglr = torch.load(
        '/relevance2-nfs/romittel/binglr_pretrained_model/pytorch_model.bin')
    emb_per_gpu = d_binglr['bert.embeddings.word_embeddings.weight'].size(
    )[0] // num_of_gpus
    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()
    file_len = 0
    with open(args.valid_data[0], 'r', encoding='utf-8') as f:
        for _ in f:
            file_len += 1
    print("file_len= ", file_len)
    # Pytorch distributed.
    initialize_distributed(args)
    if torch.distributed.get_rank() == 0:
        print('Pretrain GPT2 model')
        print_args(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Data stuff.
    train_data, val_data, test_data, args.vocab_size, \
        args.eod_token = get_train_val_test_data(args)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    #model.optimizer.dynamic_loss_scale=True
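    # j selects this rank's slice of every partitioned weight: the hidden size of 1024
    # is split across 8 ranks, hp rows/columns per rank.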
    j = torch.distributed.get_rank()
    # word_embeddings
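    # The BingLR vocabulary slice is smaller than Megatron's padded partition,
    # so the remaining rows are zero-filled.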
    wts = model.module.module.word_embeddings.weight
    num_embeddings_per_partition = model.module.module.word_embeddings.num_embeddings_per_partition
    embedding_dim = model.module.module.word_embeddings.embedding_dim
    model.module.module.word_embeddings.weight.data.copy_(
        torch.nn.Parameter(torch.cat(
            (torch.tensor(d_binglr['bert.embeddings.word_embeddings.weight'][
                j * emb_per_gpu:(j + 1) * emb_per_gpu, :].clone().detach(),
                          device=wts.device,
                          dtype=wts.dtype),
             torch.zeros(
                 (num_embeddings_per_partition - emb_per_gpu, embedding_dim),
                 device=wts.device,
                 dtype=wts.dtype))),
                           requires_grad=True).data)

    # position_embeddings
    wts = model.module.module.position_embeddings.weight
    model.module.module.position_embeddings.weight.data.copy_(
        torch.nn.Parameter(torch.tensor(
            d_binglr['bert.embeddings.position_embeddings.weight'].clone(
            ).detach(),
            device=wts.device,
            dtype=wts.dtype),
                           requires_grad=True).data)

    # input_layernorm
    wts = model.module.module.input_layernorm.weight
    model.module.module.input_layernorm.weight.data.copy_(
        torch.nn.Parameter(torch.tensor(
            d_binglr['bert.embeddings.LayerNorm.weight'].clone().detach(),
            device=wts.device,
            dtype=wts.dtype),
                           requires_grad=True).data)
    wts = model.module.module.input_layernorm.bias
    model.module.module.input_layernorm.bias.data.copy_(
        torch.nn.Parameter(torch.tensor(
            d_binglr['bert.embeddings.LayerNorm.bias'].clone().detach(),
            device=wts.device,
            dtype=wts.dtype),
                           requires_grad=True))

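    # Copy each transformer layer, taking only this rank's slice of every partitioned
    # weight. Only odd-numbered layers receive the BERT MLP weights; even-numbered
    # layers (presumably the MoE layers, as in the previous example) keep their
    # initialization.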
    for i in range(num_of_layers):
        # attention.query_key_value.bias
        query_weight = d_binglr[
            'bert.encoder.layer.' + str(i) +
            '.attention.self.query.weight'].clone().detach()
        query_bias = d_binglr['bert.encoder.layer.' + str(i) +
                              '.attention.self.query.bias'].clone().detach()
        key_weight = d_binglr['bert.encoder.layer.' + str(i) +
                              '.attention.self.key.weight'].clone().detach()
        key_bias = d_binglr['bert.encoder.layer.' + str(i) +
                            '.attention.self.key.bias'].clone().detach()
        value_weight = d_binglr[
            'bert.encoder.layer.' + str(i) +
            '.attention.self.value.weight'].clone().detach()
        value_bias = d_binglr['bert.encoder.layer.' + str(i) +
                              '.attention.self.value.bias'].clone().detach()
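        # Megatron fuses Q, K and V into a single query_key_value projection, so this
        # rank's slices of the three matrices are concatenated below.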
        wts = model.module.module.transformer.layers[
            i].attention.query_key_value.weight
        model.module.module.transformer.layers[
            i].attention.query_key_value.weight.data.copy_(
                torch.nn.Parameter(torch.cat(
                    (torch.tensor(query_weight[j * hp:(j + 1) * hp, :],
                                  device=wts.device,
                                  dtype=wts.dtype),
                     torch.tensor(key_weight[j * hp:(j + 1) * hp, :],
                                  device=wts.device,
                                  dtype=wts.dtype),
                     torch.tensor(value_weight[j * hp:(j + 1) * hp, :],
                                  device=wts.device,
                                  dtype=wts.dtype))),
                                   requires_grad=True).data)

        wts = model.module.module.transformer.layers[
            i].attention.query_key_value.bias
        model.module.module.transformer.layers[
            i].attention.query_key_value.bias.data.copy_(
                torch.nn.Parameter(torch.cat(
                    (torch.tensor(query_bias[j * hp:(j + 1) * hp],
                                  device=wts.device,
                                  dtype=wts.dtype),
                     torch.tensor(key_bias[j * hp:(j + 1) * hp],
                                  device=wts.device,
                                  dtype=wts.dtype),
                     torch.tensor(value_bias[j * hp:(j + 1) * hp],
                                  device=wts.device,
                                  dtype=wts.dtype))),
                                   requires_grad=True).data)

        # self_output.dense
        wts = model.module.module.transformer.layers[
            i].self_output.dense.weight
        model.module.module.transformer.layers[
            i].self_output.dense.weight.data.copy_(
                torch.nn.Parameter(
                    torch.tensor(d_binglr['bert.encoder.layer.' + str(i) +
                                          '.attention.output.dense.weight']
                                 [:, j * hp:(j + 1) * hp].clone().detach(),
                                 device=wts.device,
                                 dtype=wts.dtype),
                    requires_grad=True).data)
        wts = model.module.module.transformer.layers[i].self_output.dense.bias
        model.module.module.transformer.layers[
            i].self_output.dense.bias.data.copy_(
                torch.nn.Parameter(torch.tensor(
                    d_binglr['bert.encoder.layer.' + str(i) +
                             '.attention.output.dense.bias'].clone().detach(),
                    device=wts.device,
                    dtype=wts.dtype),
                                   requires_grad=True).data)

        #self_output.layernorm
        wts = model.module.module.transformer.layers[
            i].self_output.layernorm.weight
        model.module.module.transformer.layers[
            i].self_output.layernorm.weight.data.copy_(
                torch.nn.Parameter(torch.tensor(d_binglr[
                    'bert.encoder.layer.' + str(i) +
                    '.attention.output.LayerNorm.weight'].clone().detach(),
                                                device=wts.device,
                                                dtype=wts.dtype),
                                   requires_grad=True).data)
        wts = model.module.module.transformer.layers[
            i].self_output.layernorm.bias
        model.module.module.transformer.layers[
            i].self_output.layernorm.bias.data.copy_(
                torch.nn.Parameter(torch.tensor(d_binglr[
                    'bert.encoder.layer.' + str(i) +
                    '.attention.output.LayerNorm.bias'].clone().detach(),
                                                device=wts.device,
                                                dtype=wts.dtype),
                                   requires_grad=True).data)

        #layernorm
        wts = model.module.module.transformer.layers[i].layernorm.weight
        model.module.module.transformer.layers[i].layernorm.weight.data.copy_(
            torch.nn.Parameter(torch.tensor(
                d_binglr['bert.encoder.layer.' + str(i) +
                         '.output.LayerNorm.weight'].clone().detach(),
                device=wts.device,
                dtype=wts.dtype),
                               requires_grad=True).data)
        wts = model.module.module.transformer.layers[i].layernorm.bias
        model.module.module.transformer.layers[i].layernorm.bias.data.copy_(
            torch.nn.Parameter(torch.tensor(
                d_binglr['bert.encoder.layer.' + str(i) +
                         '.output.LayerNorm.bias'].clone().detach(),
                device=wts.device,
                dtype=wts.dtype),
                               requires_grad=True).data)

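        # Dense MLP (odd layers only): each rank copies its block of dense_h_to_4h
        # output rows and the matching block of dense_4h_to_h input columns.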
        if i % 2 == 1:
            wts = model.module.module.transformer.layers[
                i].mlp.dense_h_to_4h.weight
            model.module.module.transformer.layers[
                i].mlp.dense_h_to_4h.weight.data.copy_(
                    torch.nn.Parameter(torch.tensor(d_binglr[
                        'bert.encoder.layer.' + str(i) +
                        '.intermediate.dense.weight'][j * hp * 4:(j + 1) * hp *
                                                      4, :].clone().detach(),
                                                    device=wts.device,
                                                    dtype=wts.dtype),
                                       requires_grad=True).data)
            wts = model.module.module.transformer.layers[
                i].mlp.dense_h_to_4h.bias
            model.module.module.transformer.layers[
                i].mlp.dense_h_to_4h.bias.data.copy_(
                    torch.nn.Parameter(torch.tensor(d_binglr[
                        'bert.encoder.layer.' + str(i) +
                        '.intermediate.dense.bias'][j * hp * 4:(j + 1) * hp *
                                                    4].clone().detach(),
                                                    device=wts.device,
                                                    dtype=wts.dtype),
                                       requires_grad=True).data)

            wts = model.module.module.transformer.layers[
                i].mlp.dense_4h_to_h.weight
            model.module.module.transformer.layers[
                i].mlp.dense_4h_to_h.weight.data.copy_(
                    torch.nn.Parameter(torch.tensor(
                        d_binglr['bert.encoder.layer.' + str(i) +
                                 '.output.dense.weight'][:, j * hp *
                                                         4:(j + 1) * hp *
                                                         4].clone().detach(),
                        device=wts.device,
                        dtype=wts.dtype),
                                       requires_grad=True).data)
            wts = model.module.module.transformer.layers[
                i].mlp.dense_4h_to_h.bias
            model.module.module.transformer.layers[
                i].mlp.dense_4h_to_h.bias.data.copy_(
                    torch.nn.Parameter(torch.tensor(
                        d_binglr['bert.encoder.layer.' + str(i) +
                                 '.output.dense.bias'].clone().detach(),
                        device=wts.device,
                        dtype=wts.dtype),
                                       requires_grad=True).data)

    if args.deepspeed:
        print_rank_0("DeepSpeed is enabled.")
        model, optimizer, _, lr_scheduler = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            args=args,
            lr_scheduler=lr_scheduler,
            mpu=mpu,
            dist_init_required=False)
        print("Optimizer's state_dict:")
        print(optimizer.state_dict()['fp32_groups'])
    iteration = 100
    save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
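
As a sanity check on the slicing pattern above, here is a minimal sketch (toy sizes, illustrative names, not part of the original script) of how a full weight matrix is split into per-rank row blocks and why concatenating the blocks in rank order recovers the original tensor:

import torch

N = 8                            # number of model-parallel ranks (num_of_gpus above)
H = 1024                         # hidden size of the source checkpoint
hp = H // N                      # rows owned by each rank (hp above)

full_weight = torch.randn(H, H)  # stand-in for one attention projection from the source model

# Each rank j copies only its row block, as in query_weight[j * hp:(j + 1) * hp, :] above.
shards = [full_weight[j * hp:(j + 1) * hp, :] for j in range(N)]

# Concatenating the shards in rank order reconstructs the full matrix.
assert torch.equal(torch.cat(shards, dim=0), full_weight)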