Example 1
def testToyBertCheckpointFrozenWeights():
    # Common setup
    seed = 1
    total_steps = 10
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'utils': {
            'frozen_weights':
            ['bert.encoder.layer.0.attention.self.value.weight']
        }
    })

    # Create ORTTrainer
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    optim_config = optim.LambConfig()
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)

    # Train for a few steps
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, seed)
        _ = trainer.train_step(*sample_input)
    sample_input = generate_random_input_from_model_desc(
        model_desc, seed + total_steps + 1)
    # Evaluate once to get a base loss
    loss = trainer.eval_step(*sample_input)
    # Save checkpoint
    state_dict = checkpoint.experimental_state_dict(trainer)

    # Load previous state into another instance of ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    optim_config2 = optim.LambConfig()
    trainer2 = orttrainer.ORTTrainer(model2,
                                     model_desc2,
                                     optim_config2,
                                     options=opts)
    checkpoint.experimental_load_state_dict(trainer2, state_dict)
    # Evaluate once to get a base loss
    ckpt_loss = trainer2.eval_step(*sample_input)

    # Losses must match since both trainers hold the same state
    assert_allclose(loss.cpu(), ckpt_loss.cpu())
    loaded_state_dict = checkpoint.experimental_state_dict(trainer2)
    assert state_dict.keys() == loaded_state_dict.keys()
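
A hedged follow-up sketch, not part of the test above: it reuses the same helpers (checkpoint.experimental_state_dict, generate_random_input_from_model_desc, trainer.train_step) and assumes the state_dict values are tensor-like. It reports which parameters actually changed after a few train steps, which is one way to observe the effect of the frozen_weights option.

import copy
import torch

def report_changed_weights(trainer, model_desc, steps=5, seed=1):
    # Snapshot the state before training; deepcopy so later train steps
    # cannot mutate the snapshot in place.
    before = copy.deepcopy(checkpoint.experimental_state_dict(trainer))
    for _ in range(steps):
        sample_input = generate_random_input_from_model_desc(model_desc, seed)
        trainer.train_step(*sample_input)
    after = checkpoint.experimental_state_dict(trainer)
    # Names whose values changed; a frozen weight is expected not to show up here.
    return [name for name, value in after.items()
            if name in before and not torch.equal(torch.as_tensor(before[name]),
                                                  torch.as_tensor(value))]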
Example 2
def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions(
        {'debug': {
            'deterministic_compute': True
        }})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    sd = checkpoint.experimental_state_dict(trainer)

    ## All initializers must be present in the state_dict
    ##  when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd

    ## Modify one of the state values and load into ORTTrainer
    sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    checkpoint.experimental_load_state_dict(trainer, sd)

    ## Save a checkpoint
    ckpt_dir = 'testdata'
    checkpoint.experimental_save_checkpoint(trainer, ckpt_dir,
                                            'bert_toy_save_test')
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2,
                                     model_desc2,
                                     optim_config,
                                     options=opts)
    checkpoint.experimental_load_checkpoint(trainer2, ckpt_dir,
                                            'bert_toy_save_test')
    loaded_sd = checkpoint.experimental_state_dict(trainer2)

    # Assert that the original state and the one loaded from the checkpoint match
    for k, v in loaded_sd.items():
        assert torch.all(torch.eq(v, sd[k]))
Example 3
def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir):
    """Instantiate and load checkpoint into trainer

    - Instantiates ORTTrainer with the given trainer_opts configuration for a simple transformer model
    - Loads the checkpoint from the directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer's ONNX graph is initialized
    - Returns the trainer state_dict and the PyTorch model
    """
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    # PyTorch transformer model setup
    learning_rate = 0.1
    optim_config = optim.LambConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint.experimental_load_checkpoint(trainer, checkpoint_dir)

    # run an eval step to initialize the graph
    torch.manual_seed(seed)
    set_seed(seed)
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return checkpoint.experimental_state_dict(trainer), model
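
A minimal usage sketch for the helper above; the device string, options dict, and checkpoint directory below are illustrative placeholders, and the directory is assumed to have been written earlier by checkpoint.experimental_save_checkpoint.

# Hypothetical call site for create_orttrainer_and_load_checkpoint.
trainer_opts = {'device': {'id': 'cuda'}, 'debug': {'deterministic_compute': True}}
state_dict, pytorch_model = create_orttrainer_and_load_checkpoint(
    'cuda', trainer_opts, 'path/to/checkpoint_dir')
# state_dict reflects the restored trainer state; pytorch_model is the underlying transformer.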
Example 4
def load_model_optim_state_and_eval(device, trainer_opts, use_lamb=True):
    learning_rate = 0.1
    seed = 1

    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(
        lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(
        device)
    trainer = orttrainer.ORTTrainer(
        model,
        model_desc,
        optim_config,
        loss_fn=loss_fn,
        options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load dummy state
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    checkpoint._experimental_load_optimizer_state(trainer, dummy_init_state)

    # run an eval step to initialize the graph
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return dummy_init_state, checkpoint.experimental_state_dict(trainer)
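
A hedged usage sketch; the options dict is a placeholder, and the comparison helpers referenced in the comments are the ones used in the BERT optimizer-state example further down this page.

# Hypothetical call site for load_model_optim_state_and_eval.
opts = {'device': {'id': 'cuda'}, 'debug': {'deterministic_compute': True}}
expected_optim_state, trainer_state_dict = load_model_optim_state_and_eval('cuda', opts, use_lamb=True)
# The dummy optimizer state can then be compared with what the trainer reports, e.g.:
#   actual = _test_commons.get_optim_state_from_state_dict(trainer_state_dict, optim.LambConfig(lr=0.1))
#   _test_helpers.assert_optim_state(expected_optim_state, actual)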
Example 5
def _save(trainer, checkpoint_dir, state_dict_key_name):
    """Saves the ORTTrainer checkpoint and the complete state dictionary to the given checkpoint_dir directory""" 

    # save current model parameters as a checkpoint
    makedir(checkpoint_dir)
    checkpoint.experimental_save_checkpoint(trainer, checkpoint_dir)
    state_dict = checkpoint.experimental_state_dict(trainer)
    with open(os.path.join(checkpoint_dir, state_dict_key_name + '.pkl'), "wb") as f:
        pickle.dump({state_dict_key_name: state_dict}, f)
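
A complementary sketch, assuming only the '<key>.pkl' naming convention used by _save above, for reading such a pickled state dictionary back:

import os
import pickle

def _load_state_dict_from_pickle(checkpoint_dir, state_dict_key_name):
    # Reads the dictionary written by _save and returns the stored state_dict.
    with open(os.path.join(checkpoint_dir, state_dict_key_name + '.pkl'), "rb") as f:
        return pickle.load(f)[state_dict_key_name]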
Example 6
def update_torch_model(self):
    if self.ort_model:
        logger.info("Updating weights of torch model from ORT model.")
        ort_state_dict = checkpoint.experimental_state_dict(self.ort_model)
        self.model.load_state_dict(ort_state_dict, strict=False)
    else:
        logger.warning(
            "No ORT model found to update weights from, assuming torch model is up to date."
        )
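
A hedged usage sketch; 'wrapper' and the output filename are hypothetical, and only update_torch_model and the model attribute come from the snippet above.

# 'wrapper' is an instance of the (not shown) class that owns update_torch_model.
wrapper.update_torch_model()  # pull the latest weights out of the ORT model
torch.save(wrapper.model.state_dict(), 'synced_model.pt')  # then save via plain PyTorch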
Example 7
def testToyBertLoadOptimState(optimizer, mixedprecision_enabled):
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optimizer
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device
        },
        'mixed_precision': {
            'enabled': mixedprecision_enabled,
        },
        'distributed': {
            'allreduce_post_accumulation': True
        }
    })

    # Create ORTTrainer and load a dummy optimizer state
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    dummy_init_state = _test_commons.generate_dummy_optim_state(
        model, optimizer)
    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)
    checkpoint._experimental_load_optimizer_state(trainer, dummy_init_state)

    # Expected values
    expected_eval_loss = [10.997552871]
    input_ids = torch.tensor(
        [[26598], [21379], [19922], [5219], [5644], [20559], [23777], [25672],
         [22969], [16824], [16822], [635], [27399], [20647], [18519], [15546]],
        device=device)
    segment_ids = torch.tensor([[0], [1], [0], [1], [0], [0], [1], [0], [0],
                                [1], [1], [0], [0], [1], [1], [1]],
                               device=device)
    input_mask = torch.tensor([[0], [0], [0], [0], [1], [1], [1], [0], [1],
                               [1], [0], [0], [0], [1], [0], [0]],
                              device=device)
    masked_lm_labels = torch.tensor(
        [[25496], [16184], [11005], [16228], [14884], [21660], [8678], [23083],
         [4027], [8397], [11921], [1333], [26482], [1666], [17925], [27978]],
        device=device)
    next_sentence_labels = torch.tensor(
        [0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Actual values
    _ = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels,
                          next_sentence_labels)

    actual_state = checkpoint.experimental_state_dict(trainer)
    actual_optim_state = _test_commons.get_optim_state_from_state_dict(
        actual_state, optimizer)
    _test_helpers.assert_optim_state(dummy_init_state, actual_optim_state)
Example 8
def main():

    args = parse_arguments()

    if args.use_env and 'LOCAL_RANK' in os.environ:
        args.local_rank = int(os.environ['LOCAL_RANK'])

    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)
    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare the model and any checkpoint state to resume from
    model, checkpoint, global_step = prepare_model(args, device)

    if is_main_process(args):
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    raw_train_start = time.time()
    if args.do_train:
        if is_main_process(args):
            dllogger.log(step="PARAMETER", data={"train_start": True})
            dllogger.log(step="PARAMETER",
                         data={"batch_size_per_gpu": args.train_batch_size})
            dllogger.log(step="PARAMETER",
                         data={"learning_rate": args.learning_rate})

        most_recent_ckpts_paths = []
        average_loss = 0.0  # averaged loss every args.log_freq steps
        epoch = 0
        training_steps = 0

        pool = ProcessPoolExecutor(1)

        # Note: We loop infinitely over epochs, termination is handled via iteration count
        while True:
            thread = None
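            # Build a fresh shuffled file list below, unless this is the first
            # epoch of a resumed run, in which case the file order stored in the
            # checkpoint is reused instead.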
            if not args.resume_from_checkpoint or epoch > 0 or (
                    args.phase2 and global_step < 1) or args.init_checkpoint:
                files = [
                    os.path.join(args.input_dir, f)
                    for f in os.listdir(args.input_dir)
                    if os.path.isfile(os.path.join(args.input_dir, f))
                    and 'training' in f
                ]
                files.sort()
                num_files = len(files)
                random.shuffle(files)
                f_start_id = 0
            else:
                f_start_id = checkpoint['files'][0]
                files = checkpoint['files'][1:]
                args.resume_from_checkpoint = False
                num_files = len(files)

            shared_file_list = {}

            if torch.distributed.is_initialized():
                world_size = torch.distributed.get_world_size()
                world_rank = torch.distributed.get_rank()
            elif hasattr(args, 'world_size'):
                world_size = args.world_size
                world_rank = args.world_rank
            else:
                world_size = 1
                world_rank = 0

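            # Pick this rank's starting data file from the shuffled list; ranks are
            # mapped onto the (possibly smaller) set of files modulo num_files.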
            if world_size > num_files:
                remainder = world_size % num_files
                data_file = files[(f_start_id * world_size + world_rank +
                                   remainder * f_start_id) % num_files]
            elif world_size > 1:
                data_file = files[(f_start_id * world_size + world_rank) %
                                  num_files]
            else:
                data_file = files[f_start_id % num_files]

            previous_file = data_file

            train_data = pretraining_dataset(data_file,
                                             args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            # skip the last (partial) batch, since inputs are hard-coded as an optimization
            train_dataloader = DataLoader(train_data,
                                          sampler=train_sampler,
                                          batch_size=args.train_batch_size *
                                          args.n_gpu,
                                          num_workers=4,
                                          worker_init_fn=worker_init,
                                          pin_memory=True,
                                          drop_last=True)

            gpu_batch_size = args.train_batch_size // args.gradient_accumulation_steps

            if len(files) == 1:
                f_start_id = -1
            for f_id in range(f_start_id + 1, len(files)):

                if world_size > num_files:
                    data_file = files[(f_id * world_size + world_rank +
                                       remainder * f_id) % num_files]
                elif world_size > 1:
                    data_file = files[(f_id * world_size + world_rank) %
                                      num_files]
                else:
                    data_file = files[f_id % num_files]

                previous_file = data_file

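                # Kick off loading of the next data file in a background process while
                # training runs on the current dataloader; the result is collected via
                # dataset_future.result() at the end of this loop iteration.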
                dataset_future = pool.submit(create_pretraining_dataset,
                                             data_file,
                                             args.max_predictions_per_seq,
                                             shared_file_list, args,
                                             worker_init)

                train_iter = tqdm(
                    train_dataloader,
                    desc="Iteration",
                    disable=args.disable_progress_bar) if is_main_process(
                        args) else train_dataloader
                prev_step_time = time.time()
                for step, batch in enumerate(train_iter):

                    training_steps += 1
                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                    divisor = args.gradient_accumulation_steps

                    loss, global_step = ort_supplement.run_ort_training_step(
                        args, global_step, training_steps, model, batch)
                    average_loss += loss.item()

                    if global_step >= args.max_steps:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = int(
                            training_steps /
                            args.gradient_accumulation_steps) % args.log_freq
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(
                            average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps *
                                                       divisor)
                        if (torch.distributed.is_initialized()):
                            average_loss /= torch.distributed.get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        final_loss = average_loss.item()
                        if is_main_process(args):
                            dllogger.log(step=(
                                epoch,
                                global_step,
                            ),
                                         data={"final_loss": final_loss})
                    elif training_steps % (
                            args.log_freq *
                            args.gradient_accumulation_steps) == 0:
                        throughput = (args.train_batch_size *
                                      args.gradient_accumulation_steps) / (
                                          time.time() - prev_step_time)
                        print('throughput = ', throughput, 'seq/sec')
                        prev_step_time = time.time()
                        sys.stdout.flush()

                        if is_main_process(args):
                            data = {
                                "average_loss":
                                average_loss / (args.log_freq * divisor),
                                "step_loss":
                                loss.item() *
                                args.gradient_accumulation_steps / divisor
                            }
                            dllogger.log(step=(
                                epoch,
                                global_step,
                            ),
                                         data=data)
                        average_loss = 0

                    if global_step >= args.max_steps or training_steps % (
                            args.num_steps_per_checkpoint *
                            args.gradient_accumulation_steps) == 0:
                        if is_main_process(args) and not args.skip_checkpoint:
                            # Save a trained model
                            dllogger.log(step="PARAMETER",
                                         data={"checkpoint_step": global_step})
                            model_to_save = model.module if hasattr(
                                model, 'module'
                            ) else model  # Only save the model itself
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step +
                                                        args.phase1_end_step))
                            if args.do_train:
                                state = {
                                    'model':
                                    model_to_save.state_dict() if hasattr(
                                        model_to_save, 'state_dict') else
                                    experimental_state_dict(model_to_save),
                                    'files': [f_id] + files
                                }
                                torch.save(state, output_save_file)

                                most_recent_ckpts_paths.append(
                                    output_save_file)
                                if len(most_recent_ckpts_paths) > 3:
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(
                                        0)
                                    os.remove(ckpt_to_be_removed)

                        if global_step >= args.max_steps:
                            if is_main_process(args):
                                print(
                                    '-----------------------save onnx model-----------------------'
                                )
                                if not args.phase2:
                                    model_to_save.save_as_onnx(
                                        '{}/phase1_bert.onnx'.format(
                                            args.output_dir))
                                else:
                                    model_to_save.save_as_onnx(
                                        '{}/final_bert.onnx'.format(
                                            args.output_dir))
                            del train_dataloader
                            # thread.join()
                            return args, final_loss, train_time_raw

                del train_dataloader
                # thread.join()
                # Make sure pool has finished and switch train_dataloader
                # NOTE: Will block until complete
                train_dataloader, data_file = dataset_future.result(
                    timeout=None)

            epoch += 1