def _save(trainer, checkpoint_dir, state_dict_key_name):
    """Saves the ORTTrainer checkpoint and the complete state dictionary to the given checkpoint_dir directory"""

    # save current model parameters as a checkpoint
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint.experimental_save_checkpoint(trainer, checkpoint_dir)
    state_dict = checkpoint.experimental_state_dict(trainer)
    with open(os.path.join(checkpoint_dir, state_dict_key_name + '.pkl'), "wb") as f:
        pickle.dump({state_dict_key_name: state_dict}, f)
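
# Usage sketch (illustrative, not part of the original test module): build the same
# toy BERT trainer used in the test below and persist it with _save(). The
# directory name and state-dict key are placeholders.
def _save_usage_example():
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim.LambConfig(),
                                    options=orttrainer.ORTTrainerOptions(
                                        {'debug': {'deterministic_compute': True}}))
    _save(trainer, checkpoint_dir='checkpoint_dir_test', state_dict_key_name='state_dict_1')
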
def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions(
        {'debug': {
            'deterministic_compute': True
        }})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    sd = checkpoint.experimental_state_dict(trainer)

    ## All initializers must be present in the state_dict
    ##  when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd

    ## Modify one of the state values and load into ORTTrainer
    sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    checkpoint.experimental_load_state_dict(trainer, sd)

    ## Save a checkpoint
    ckpt_dir = 'testdata'
    checkpoint.experimental_save_checkpoint(trainer, ckpt_dir,
                                            'bert_toy_save_test')
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2,
                                     model_desc2,
                                     optim_config,
                                     options=opts)
    checkpoint.experimental_load_checkpoint(trainer2, ckpt_dir,
                                            'bert_toy_save_test')
    loaded_sd = checkpoint.experimental_state_dict(trainer2)

    # Assert that the original state and the one loaded from the checkpoint match
    for k, v in loaded_sd.items():
        assert torch.all(torch.eq(v, sd[k]))
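
# Illustrative helper (not in the original tests): the element-wise comparison at
# the end of the test above could be factored out and reused by other checkpoint
# round-trip tests; the helper name is an assumption.
def _assert_state_dicts_match(expected_sd, loaded_sd):
    # every entry loaded from the checkpoint must equal the original tensor
    for k, v in loaded_sd.items():
        assert k in expected_sd
        assert torch.all(torch.eq(v, expected_sd[k]))
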
def do_pretrain(args):
    if is_main_process(args) and args.tensorboard_dir:
        tb_writer = SummaryWriter(log_dir=args.tensorboard_dir)
        tb_writer.add_text("args", args.to_json_string())
        tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
    else:
        tb_writer = None

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    ort.set_seed(args.seed)

    device, args = setup_training(args)

    model = prepare_model(args, device)

    logger.info("Running training: Batch size = %d, initial LR = %f",
                args.train_batch_size, args.learning_rate)

    most_recent_ckpts_paths = []
    average_loss = 0.0
    epoch = 0
    training_steps = 0

    # single background worker used to prepare the next data shard while training on the current one
    pool = ProcessPoolExecutor(1)
    while True:
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if os.path.isfile(os.path.join(args.input_dir, f))
            and 'training' in f
        ]
        files.sort()
        random.shuffle(files)

        f_id = 0
        train_dataloader, data_file = create_pretraining_dataset(
            get_data_file(f_id, args.world_rank, args.world_size, files),
            args.max_predictions_per_seq, args)

        for f_id in range(1, len(files)):
            logger.info("data file %s" % (data_file))

            # asynchronously load and convert the next data file while the current one is trained on
            dataset_future = pool.submit(
                create_pretraining_dataset,
                get_data_file(f_id, args.world_rank, args.world_size, files),
                args.max_predictions_per_seq, args)

            train_iter = tqdm(train_dataloader, desc="Iteration"
                              ) if is_main_process(args) else train_dataloader
            for step, batch in enumerate(train_iter):
                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch

                loss, _, _ = model.train_step(input_ids, input_mask,
                                              segment_ids, masked_lm_labels,
                                              next_sentence_labels)
                average_loss += loss.item()

                # optimization_step counts optimizer updates, i.e. training steps divided by gradient_accumulation_steps
                global_step = model._train_step_info.optimization_step
                if training_steps % (args.log_freq *
                                     args.gradient_accumulation_steps) == 0:
                    if is_main_process(args):
                        divisor = args.log_freq * args.gradient_accumulation_steps
                        if tb_writer:
                            lr = model.options.lr_scheduler.get_last_lr()[0]
                            tb_writer.add_scalar(
                                'train/summary/scalar/Learning_Rate', lr,
                                global_step)
                            if args.fp16:
                                tb_writer.add_scalar(
                                    'train/summary/scalar/loss_scale_25', loss,
                                    global_step)
                                # TODO: ORTTrainer to expose all_finite
                                # tb_writer.add_scalar('train/summary/scalar/all_fp16_gradients_finite_859', all_finite, global_step)
                            tb_writer.add_scalar('train/summary/total_loss',
                                                 average_loss / divisor,
                                                 global_step)

                        print("Step:{} Average Loss = {}".format(
                            global_step, average_loss / divisor))

                    if global_step >= args.max_steps or global_step >= force_to_stop_max_steps:
                        if tb_writer:
                            tb_writer.close()

                    if global_step >= args.max_steps:
                        if args.save_checkpoint:
                            experimental_save_checkpoint(
                                model, args.output_dir)
                        final_loss = average_loss / (
                            args.log_freq * args.gradient_accumulation_steps)
                        return final_loss

                    average_loss = 0

            del train_dataloader

            # swap in the dataloader that was prepared in the background
            train_dataloader, data_file = dataset_future.result(timeout=None)

        epoch += 1
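
# Illustrative sketch (not part of the original script): the fields below are the
# attributes do_pretrain() reads directly from `args`. The original code builds this
# object from command-line parsing, and setup_training()/prepare_model() may read
# additional fields; the default values here are placeholders only. When
# tensorboard_dir is set, the object must also provide to_json_string() and
# to_sanitized_dict() for SummaryWriter logging.
from dataclasses import dataclass

@dataclass
class PretrainArgs:
    input_dir: str = 'data/pretrain'       # directory containing '*training*' data shards
    output_dir: str = 'results'            # where experimental_save_checkpoint writes checkpoints
    tensorboard_dir: str = ''              # empty string disables SummaryWriter logging
    seed: int = 42
    train_batch_size: int = 32
    learning_rate: float = 5e-5
    gradient_accumulation_steps: int = 1
    log_freq: int = 10                     # log every log_freq optimizer steps
    max_steps: int = 1000
    max_predictions_per_seq: int = 20
    world_rank: int = 0
    world_size: int = 1
    fp16: bool = False
    save_checkpoint: bool = False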