def prepare_model(args, device):

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
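    # (keeping vocab_size a multiple of 8 keeps the embedding and decoder GEMM shapes Tensor Core friendly)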
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    model = modeling.BertForPreTraining(config)
    criterion = BertPretrainingCriterion(config.vocab_size,
                                         args.train_batch_size,
                                         args.max_seq_length)

    model.enable_apex(False)
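    # Wrap the model together with its pretraining loss and hand the pair to the ONNX Runtime training backend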
    model = bert_model_with_loss(model, criterion)
    model = ort_supplement.create_ort_trainer(args, device, model)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if f.endswith(".pt")
            ]
            args.resume_step = max([
                int(x.split('.pt')[0].split('_')[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(
                args.output_dir, "ckpt_{}.pt".format(global_step)),
                                    map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)

        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process(args):
            print("resume step from ", args.resume_step)

    return model, checkpoint, global_step
Example #2
def prepare_model_and_optimizer(args, device):

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
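    # swap in the bias-GeLU implementation meant for training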
    model = modeling.BertForPreTraining(config)

    if args.disable_weight_tying:
        print("WARNING: disabling weight tying for this run")
        print("BEFORE:", model.cls.predictions.decoder.weight is model.bert.embeddings.word_embeddings.weight)
        model.cls.predictions.decoder.weight = torch.nn.Parameter(
            model.cls.predictions.decoder.weight.clone().detach())
        print("AFTER:", model.cls.predictions.decoder.weight is model.bert.embeddings.word_embeddings.weight)
        assert model.cls.predictions.decoder.weight is not model.bert.embeddings.word_embeddings.weight

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [f for f in os.listdir(args.output_dir) if f.endswith(".pt")]
            args.resume_step = max([int(x.split('.pt')[0].split('_')[1].strip()) for x in model_names])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)), map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)
        
        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
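    # parameters whose names contain any of these substrings (biases and LayerNorm weights) are excluded from weight decay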
    
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    optimizer = FusedAdam(optimizer_grouped_parameters,
                          lr=args.learning_rate)
    lr_scheduler = PolyWarmUpScheduler(optimizer, 
                                       warmup=args.warmup_proportion, 
                                       total_steps=args.max_steps,
                                       degree=1)
    if args.fp16:

        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic", cast_model_outputs=torch.float16)
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale, cast_model_outputs=torch.float16)
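        # override the loss scaler's starting value with the configured initial loss scale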
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    model.checkpoint_activations(args.checkpoint_activations)

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint['optimizer']['state'].keys())
            #Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_step
            for iter, item in enumerate(checkpoint['optimizer']['param_groups']):
                checkpoint['optimizer']['param_groups'][iter]['step'] = global_step
                checkpoint['optimizer']['param_groups'][iter]['t_total'] = args.max_steps
                checkpoint['optimizer']['param_groups'][iter]['warmup'] = args.warmup_proportion
                checkpoint['optimizer']['param_groups'][iter]['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

        # Restore AMP master parameters          
        if args.fp16:
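            # materialize APEX's FP32 master weights first so the checkpointed optimizer state and master params line up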
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer), checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
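            # APEX DDP with large buckets; gradients are pre-divided by the world size before the all-reduce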
            model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())
        else:
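            # post-accumulation all-reduce path: only broadcast the initial weights from rank 0 here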
            flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)


    if args.disable_weight_tying:
        # Sanity check that the untied decoder weight is registered with the optimizer
        decoder_weight = getattr(model, 'module', model).cls.predictions.decoder.weight
        print("SANITY CHECK OPTIMIZER:", id(decoder_weight) in [id(p) for p in optimizer.param_groups[0]['params']])
        assert id(decoder_weight) in [id(p) for p in optimizer.param_groups[0]['params']]

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
Example #3
def prepare_model_and_optimizer(args, device):

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if f.endswith(".pt")
            ]
            args.resume_step = max([
                int(x.split('.pt')[0].split('_')[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(
                args.output_dir, "ckpt_{}.pt".format(global_step)),
                                    map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)

        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    # BERT modeling uses weight sharing between the word embedding and the prediction decoder,
    # so make sure the storage still points to the same tensor even after the model is moved to the device.
    if args.use_habana:
        model.cls.predictions.decoder.weight = model.bert.embeddings.word_embeddings.weight

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
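    # pick a LAMB implementation: Habana fused kernel, CUDA FusedLAMB, or the pure-PyTorch NVLAMB fallback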

    if args.use_habana:
        if args.use_fused_lamb:
            try:
                from hb_custom import FusedLamb
            except ImportError:
                raise ImportError("Please install hbopt.")
            optimizer = FusedLamb(optimizer_grouped_parameters,
                                  lr=args.learning_rate)
        else:
            optimizer = NVLAMB(optimizer_grouped_parameters,
                               lr=args.learning_rate)
    else:
        if torch.cuda.is_available():
            optimizer = FusedLAMB(optimizer_grouped_parameters,
                                  lr=args.learning_rate)
        else:
            optimizer = NVLAMB(optimizer_grouped_parameters,
                               lr=args.learning_rate)

    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)
    if args.fp16:

        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale="dynamic",
                                              cast_model_outputs=torch.float16)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale=args.loss_scale,
                                              cast_model_outputs=torch.float16)
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    model.checkpoint_activations(args.checkpoint_activations)

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint['optimizer']['state'].keys())
            # Override hyperparameters from the previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_step
            for idx, _ in enumerate(checkpoint['optimizer']['param_groups']):
                checkpoint['optimizer']['param_groups'][idx]['step'] = global_step
                checkpoint['optimizer']['param_groups'][idx]['t_total'] = args.max_steps
                checkpoint['optimizer']['param_groups'][idx]['warmup'] = args.warmup_proportion
                checkpoint['optimizer']['param_groups'][idx]['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            if not args.use_jit_trace:
                if args.use_habana:
                    model = DDP(model)
                else:
                    model = DDP(model,
                                message_size=250000000,
                                gradient_predivide_factor=get_world_size())
        else:
            flat_dist_call([param.data for param in model.parameters()],
                           torch.distributed.broadcast, (0, ))
    elif args.n_pu > 1:
        model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
Example #4
def main():
    global timeout_sent

    args = parse_arguments()

    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)
    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare optimizer
    model, optimizer, lr_scheduler, checkpoint, global_step, criterion = prepare_model_and_optimizer(
        args, device)
    gradient_accumulation_steps = torch.tensor(
        args.gradient_accumulation_steps, dtype=torch.float32).to(device)
    world_size = torch.tensor(get_world_size(), dtype=torch.float32).to(device)

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    raw_train_start = None
    if args.do_train:
        if is_main_process():
            dllogger.log(step="PARAMETER", data={"train_start": True})
            dllogger.log(step="PARAMETER",
                         data={"batch_size_per_pu": args.train_batch_size})
            dllogger.log(step="PARAMETER",
                         data={"learning_rate": args.learning_rate})

        model.train()
        most_recent_ckpts_paths = []
        average_loss = 0.0  # averaged loss every args.log_freq steps
        epoch = 0
        training_steps = 0
        model_traced = False

        if device.type == 'cuda':
            pool = ProcessPoolExecutor(1)

        # Note: We loop infinitely over epochs, termination is handled via iteration count
        while True:
            thread = None
            restored_data_loader = None
            if not args.resume_from_checkpoint or epoch > 0 or (
                    args.phase2 and global_step < 1) or args.init_checkpoint:
                files = [
                    os.path.join(args.input_dir, f)
                    for f in os.listdir(args.input_dir)
                    if os.path.isfile(os.path.join(args.input_dir, f))
                    and 'training' in f
                ]
                files.sort()
                num_files = len(files)
                random.Random(args.seed + epoch).shuffle(files)
                f_start_id = 0
            else:
                f_start_id = checkpoint['files'][0]
                files = checkpoint['files'][1:]
                args.resume_from_checkpoint = False
                num_files = len(files)
                # may not exist in all checkpoints
                epoch = checkpoint.get('epoch', 0)
                restored_data_loader = checkpoint.get('data_loader', None)

            shared_file_list = {}

            if torch.distributed.is_initialized() and get_world_size() > num_files:
                remainder = get_world_size() % num_files
                data_file = files[(f_start_id * get_world_size() + get_rank() +
                                   remainder * f_start_id) % num_files]
            else:
                data_file = files[(f_start_id * get_world_size() + get_rank()) % num_files]

            previous_file = data_file

            if restored_data_loader is None:
                use_pin_memory = False if args.no_cuda or args.use_habana else True
                num_workers = 0 if args.use_habana else 4
                train_data = pretraining_dataset(data_file,
                                                 args.max_predictions_per_seq)
                train_sampler = RandomSampler(train_data)
                train_dataloader = DataLoader(
                    train_data,
                    sampler=train_sampler,
                    batch_size=args.train_batch_size * args.n_pu,
                    num_workers=num_workers,
                    worker_init_fn=worker_init,
                    pin_memory=use_pin_memory,
                    drop_last=True)
                # shared_file_list["0"] = (train_dataloader, data_file)
            else:
                train_dataloader = restored_data_loader
                restored_data_loader = None

            overflow_buf = None
            if args.allreduce_post_accumulation:
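                # scratch buffer used by take_optimizer_step for the fused all-reduce / overflow check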
                overflow_buf = torch.cuda.IntTensor([0])

            for f_id in range(f_start_id + 1, len(files)):

                if get_world_size() > num_files:
                    data_file = files[(f_id * get_world_size() + get_rank() +
                                       remainder * f_id) % num_files]
                else:
                    data_file = files[(f_id * get_world_size() + get_rank()) %
                                      num_files]

                previous_file = data_file

                if device.type == 'cuda':
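                    # prefetch the next data shard in a background process while training on the current one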
                    dataset_future = pool.submit(create_pretraining_dataset,
                                                 data_file,
                                                 args.max_predictions_per_seq,
                                                 shared_file_list, args,
                                                 worker_init)

                train_iter = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=args.disable_progress_bar
                                  ) if is_main_process() else train_dataloader

                if raw_train_start is None:
                    raw_train_start = time.time()
                for step, batch in enumerate(train_iter):

                    training_steps += 1
                    position_ids = compute_position_ids(batch[0])
                    if torch.distributed.is_initialized():
                        torch.distributed.barrier()
                    if args.use_habana:
                        batch = [t.to(dtype=torch.int32) for t in batch]
                        position_ids = position_ids.to(dtype=torch.int32)

                    position_ids = position_ids.to(device)
                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch

                    if args.use_jit_trace:
                        if not model_traced:
                            model = torch.jit.trace(model,
                                                    (input_ids, segment_ids,
                                                     input_mask, position_ids),
                                                    check_trace=False)
                            model_traced = True
                            if args.local_rank != -1 and not args.allreduce_post_accumulation:
                                if args.use_habana:
                                    model = DDP(model)
                                else:
                                    model = DDP(model,
                                                message_size=250000000,
                                                gradient_predivide_factor=get_world_size())
                        if args.local_rank != -1 and not args.allreduce_post_accumulation \
                                and (training_steps % args.gradient_accumulation_steps != 0):
                            with model.no_sync():
                                prediction_scores, seq_relationship_score = model(
                                    input_ids, segment_ids, input_mask,
                                    position_ids)
                        else:
                            prediction_scores, seq_relationship_score = model(
                                input_ids, segment_ids, input_mask,
                                position_ids)
                    else:
                        if args.local_rank != -1 and not args.allreduce_post_accumulation \
                                and (training_steps % args.gradient_accumulation_steps != 0):
                            with model.no_sync():
                                prediction_scores, seq_relationship_score = model(
                                    input_ids=input_ids,
                                    token_type_ids=segment_ids,
                                    attention_mask=input_mask,
                                    position_ids=position_ids)
                        else:
                            prediction_scores, seq_relationship_score = model(
                                input_ids=input_ids,
                                token_type_ids=segment_ids,
                                attention_mask=input_mask,
                                position_ids=position_ids)

                    loss = criterion(prediction_scores, seq_relationship_score,
                                     masked_lm_labels, next_sentence_labels)
                    if args.n_pu > 1:
                        loss = loss.mean()  # mean() to average on multi-pu.

                    divisor = args.gradient_accumulation_steps
                    if args.gradient_accumulation_steps > 1:
                        if not args.allreduce_post_accumulation:
                            # this division was merged into predivision
                            loss = loss / gradient_accumulation_steps
                            divisor = 1.0
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer,
                                            delay_overflow_check=args.allreduce_post_accumulation) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    average_loss += loss.item()

                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(
                            args, optimizer, model, overflow_buf, global_step)

                    if global_step >= args.steps_this_run or timeout_sent:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = int(
                            training_steps /
                            args.gradient_accumulation_steps) % args.log_freq
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = average_loss / (last_num_steps *
                                                       divisor)
                        average_loss = torch.tensor(
                            average_loss, dtype=torch.float32).to(device)
                        if (torch.distributed.is_initialized()):
                            average_loss /= world_size
                            torch.distributed.all_reduce(average_loss)
                        final_loss = average_loss.item()
                        if is_main_process():
                            dllogger.log(step=(epoch, global_step),
                                         data={"final_loss": final_loss})
                    elif training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
                        if is_main_process():
                            dllogger.log(step=(epoch, global_step),
                                         data={"average_loss": average_loss / (args.log_freq * divisor),
                                               "step_loss": loss.item() * args.gradient_accumulation_steps / divisor,
                                               "learning_rate": optimizer.param_groups[0]['lr']})
                        average_loss = 0

                    if global_step >= args.steps_this_run or \
                            training_steps % (args.num_steps_per_checkpoint * args.gradient_accumulation_steps) == 0 or \
                            timeout_sent:
                        if is_main_process() and not args.skip_checkpoint:
                            # Save a trained model
                            dllogger.log(step="PARAMETER",
                                         data={"checkpoint_step": global_step})
                            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step +
                                                        args.phase1_end_step))
                            checkpoint_dict = {}
                            if args.do_train:
                                if args.use_habana:
                                    config = modeling.BertConfig.from_json_file(
                                        args.config_file)

                                    # Padding for divisibility by 8
                                    if config.vocab_size % 8 != 0:
                                        config.vocab_size += 8 - (
                                            config.vocab_size % 8)

                                    model_copy = modeling.BertForPreTraining(config)
                                    model_copy.load_state_dict(model_to_save.state_dict())

                                    # Move the optimizer state to the CPU before building the checkpoint
                                    param_groups_copy = optimizer.state_dict()['param_groups']
                                    state_dict_copy = {}
                                    for st_key, st_val in optimizer.state_dict()['state'].items():
                                        st_val_copy = {}
                                        for k, v in st_val.items():
                                            if isinstance(v, torch.Tensor):
                                                st_val_copy[k] = v.to('cpu')
                                            else:
                                                st_val_copy[k] = v
                                        state_dict_copy[st_key] = st_val_copy
                                    optim_dict = {'state': state_dict_copy,
                                                  'param_groups': param_groups_copy}
                                    checkpoint_dict = {
                                        'model': model_copy.state_dict(),
                                        'optimizer': optim_dict,
                                        'files': [f_id] + files,
                                        'epoch': epoch,
                                        'data_loader': None if global_step >= args.max_steps else train_dataloader,
                                    }
                                elif args.no_cuda:
                                    checkpoint_dict = {
                                        'model': model_to_save.state_dict(),
                                        'optimizer': optimizer.state_dict(),
                                        'files': [f_id] + files,
                                        'epoch': epoch,
                                        'data_loader': None if global_step >= args.max_steps else train_dataloader,
                                    }
                                else:
                                    checkpoint_dict = {
                                        'model': model_to_save.state_dict(),
                                        'optimizer': optimizer.state_dict(),
                                        'master params': list(amp.master_params(optimizer)),
                                        'files': [f_id] + files,
                                        'epoch': epoch,
                                        'data_loader': None if global_step >= args.max_steps else train_dataloader,
                                    }

                                torch.save(checkpoint_dict, output_save_file)
                                most_recent_ckpts_paths.append(
                                    output_save_file)
                                if len(most_recent_ckpts_paths) > 3:
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(
                                        0)
                                    os.remove(ckpt_to_be_removed)

                        # Exiting the training due to hitting max steps, or being sent a
                        # timeout from the cluster scheduler
                        if global_step >= args.steps_this_run or timeout_sent:
                            del train_dataloader
                            # thread.join()
                            return args, final_loss, train_time_raw, global_step

                del train_dataloader
                # thread.join()
                # Make sure pool has finished and switch train_dataloader
                # NOTE: Will block until complete
                if device.type == 'cuda':
                    train_dataloader, data_file = dataset_future.result(
                        timeout=None)
                else:
                    train_dataloader, data_file = create_pretraining_dataset(
                        data_file, args.max_predictions_per_seq,
                        shared_file_list, args, worker_init)

            epoch += 1
Example #5
  MAX_PREDICTIONS_PER_SEQ = 20
  MAX_SEQ_LENGTH = 128
  DO_LOWER_CASE = True


  # LEARNING_RATE = 2e-5
  # NUM_TRAIN_STEPS = 1
  # NUM_WARMUP_STEPS = 10
  # USE_TPU = False
  # BATCH_SIZE = 1


  # load model
  bert_config = modeling.BertConfig(BERT_CONFIG_FILE)
  device = torch.device("cpu")
  model1 = modeling.BertForPreTraining(bert_config)
  # model2 = modeling.BertForPreTraining(bert_config)

  model1.load_state_dict(torch.load(INIT_CHECKPOINT_PT, map_location='cpu'))
  # model1.bert.from_pretrained(INIT_DIRECTORY)
  model1.to(device)
  print('model loaded')


  #resolve features
  with open(INPUT_FILE, 'rb') as f:
    features = pickle.load(f)

  print ("%d total samples" % len(features))

  all_input_ids = torch.tensor([f['input_ids'] for f in features], dtype=torch.long)
Example #6
def prepare_model_and_optimizer(args, device):

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    modeling.ACT2FN["bias_gelu"] = torch.jit.script(
        modeling.ACT2FN["bias_gelu"])
    model = modeling.BertForPreTraining(config)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if f.endswith(".pt")
            ]
            args.resume_step = max([
                int(x.split('.pt')[0].split('_')[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(
                args.output_dir, "ckpt_{}.pt".format(global_step)),
                                    map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)

        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)
    if args.fp16:

        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale="dynamic",
                                              cast_model_outputs=torch.float16)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale=args.loss_scale,
                                              cast_model_outputs=torch.float16)
        amp._amp_state.loss_scalers[0]._loss_scale = 2**20

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint['optimizer']['state'].keys())
            # Override hyperparameters from the previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_step
            for idx, _ in enumerate(checkpoint['optimizer']['param_groups']):
                checkpoint['optimizer']['param_groups'][idx]['step'] = global_step
                checkpoint['optimizer']['param_groups'][idx]['t_total'] = args.max_steps
                checkpoint['optimizer']['param_groups'][idx]['warmup'] = args.warmup_proportion
                checkpoint['optimizer']['param_groups'][idx]['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        model = DDP(
            model,
            message_size=250000000,
            gradient_predivide_factor=torch.distributed.get_world_size())
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
def prepare_model_and_optimizer(args, device):

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    if args.use_sequential > 0:
        config.use_sequential = True
    else:
        config.use_sequential = False

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)
    model.checkpoint_activations(args.checkpoint_activations)
    if args.smp > 0:
        # SMP: Use the DistributedModel container to provide the model
        # to be partitioned across different ranks. For the rest of the script,
        # the returned DistributedModel object should be used in place of
        # the model provided for DistributedModel class instantiation.
        model = smp.DistributedModel(model)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if not args.init_checkpoint:
            if not args.s3_checkpoint_uri:
                raise ValueError(
                    "Need to set s3_checkpoint_uri, if init_checkpoint not set"
                )
            if smp.local_rank() == 0:
                sync_s3_checkpoints_to_local(args.output_dir,
                                             args.s3_checkpoint_uri)
            smp.barrier()
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if ".pt" in f
            ]
            args.resume_step = max([
                int(x.split(".pt")[0].split("_")[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step if not args.init_checkpoint else 0

        # SMP: Load a model that was saved with smp.save
        if not args.init_checkpoint:
            checkpoint = smp.load(
                os.path.join(args.output_dir,
                             "ckpt_{}.pt".format(global_step)),
                partial=args.partial_checkpoint,
            )
        else:
            checkpoint = smp.load(args.init_checkpoint)

        model.load_state_dict(checkpoint["model"], strict=False)

        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "gamma", "beta", "LayerNorm"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0},
    ]

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    if args.smp > 0:
        # SMP: Use Distributed Optimizer which allows the loading of optimizer state for a distributed model
        # Also provides APIs to obtain local optimizer state for the current mp_rank.
        optimizer = smp.DistributedOptimizer(optimizer)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)

    if args.fp16:
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level="O2",
                loss_scale="dynamic",
                cast_model_outputs=torch.float16,
            )
        else:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level="O2",
                loss_scale=args.loss_scale,
                cast_model_outputs=torch.float16,
            )
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint["optimizer"]["state"].keys())
            # Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint["optimizer"]["state"][key]["step"] = global_step
            for idx, _ in enumerate(checkpoint["optimizer"]["param_groups"]):
                checkpoint["optimizer"]["param_groups"][idx]["step"] = global_step
                checkpoint["optimizer"]["param_groups"][idx]["t_total"] = args.max_steps
                checkpoint["optimizer"]["param_groups"][idx]["warmup"] = args.warmup_proportion
                checkpoint["optimizer"]["param_groups"][idx]["lr"] = args.learning_rate
        optimizer.load_state_dict(checkpoint["optimizer"])  # , strict=False)
        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint["optimizer"])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint["master params"]):
                param.data.copy_(saved_param.data)

    # if args.local_rank != -1:
    #    if not args.allreduce_post_accumulation:
    #        model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())
    #    else:
    #        flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )
    # elif args.n_gpu > 1:
    #    model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
def prepare_model_and_optimizer(args, device, sequence_output_is_dense):

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    model = modeling.BertForPreTraining(config, sequence_output_is_dense=sequence_output_is_dense)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [f for f in os.listdir(args.output_dir) if f.endswith(".pt")]
            args.resume_step = max([int(x.split('.pt')[0].split('_')[1].strip()) for x in model_names])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)), map_location=device)
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location=device)

        model.load_state_dict(checkpoint['model'], strict=False)

        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)

    # If allreduce_post_accumulation_fp16 is not set, Native AMP Autocast is
    # used along with FP32 gradient accumulation and all-reduce
    if args.fp16 and args.allreduce_post_accumulation_fp16:
        model.half()

    if not args.disable_jit_fusions:
        model = torch.jit.script(model)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    optimizer = FusedLAMBAMP(optimizer_grouped_parameters,
                             lr=args.learning_rate)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps,
                                       base_lr=args.learning_rate,
                                       device=device)
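    # native AMP gradient scaler; it is a no-op when fp16 is disabled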
    grad_scaler = torch.cuda.amp.GradScaler(init_scale=args.init_loss_scale, enabled=args.fp16)

    model.checkpoint_activations(args.checkpoint_activations)

    if args.resume_from_checkpoint:
        # For phase2, need to reset the learning rate and step count in the checkpoint
        if args.phase2 or args.init_checkpoint:
            for group in checkpoint['optimizer']['param_groups']:
                group['step'].zero_()
                group['lr'].fill_(args.learning_rate)
        else:
            if 'grad_scaler' in checkpoint and not args.phase2:
                grad_scaler.load_state_dict(checkpoint['grad_scaler'])
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

    if args.local_rank != -1:
        # Cuda Graphs requires that DDP is captured on a side stream
        # It is important to synchronize the streams after the DDP initialization
        # so anything after sees properly initialized model weights across GPUs
        side_stream = torch.cuda.Stream()
        with torch.cuda.stream(side_stream):
            model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, bucket_cap_mb=torch.cuda.get_device_properties(device).total_memory, gradient_as_bucket_view=True)
        torch.cuda.current_stream().wait_stream(side_stream)

        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
        def scale_by_grad_accum_steps_wrapper(hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:

            def scale_by_grad_accum_steps_wrapper_hook(
                hook_state, bucket: dist.GradBucket
            ) -> torch.futures.Future[torch.Tensor]:
                bucket.set_buffer(bucket.buffer().div_(args.gradient_accumulation_steps))
                fut = hook(hook_state, bucket)
                return fut

            return scale_by_grad_accum_steps_wrapper_hook

        # With gradient accumulation, the DDP comm hook divides the gradients by the number of
        # gradient accumulation steps
        if args.gradient_accumulation_steps > 1:
            model.register_comm_hook(None, scale_by_grad_accum_steps_wrapper(allreduce_hook))

    optimizer.setup_fp32_params()

    criterion = BertPretrainingCriterion(config.vocab_size, sequence_output_is_dense=sequence_output_is_dense)

    if args.resume_from_checkpoint and args.init_checkpoint:
        start_epoch = checkpoint['epoch']
    else:
        start_epoch = 0

    return model, optimizer, grad_scaler, lr_scheduler, checkpoint, global_step, criterion, start_epoch
def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    config = FLAGS.config

    input_files = sum([glob.glob(pattern) for pattern in config.input_files],
                      [])
    assert input_files, "No input files!"
    print(f"Training with {len(input_files)} input files, including:")
    print(f" - {input_files[0]}")

    model = modeling.BertForPreTraining(config=config.model)
    initial_params = get_initial_params(model,
                                        init_checkpoint=config.init_checkpoint)
    optimizer = create_optimizer(config, initial_params)
    del initial_params  # the optimizer takes ownership of all params

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    # Restore from a local checkpoint, if one exists.
    optimizer = checkpoints.restore_checkpoint(output_dir, optimizer)
    if isinstance(optimizer.state, (list, tuple)):
        start_step = int(optimizer.state[0].step)
    else:
        start_step = int(optimizer.state.step)

    optimizer = optimizer.replicate()
    optimizer = training.harmonize_across_hosts(optimizer)
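    # the optimizer (and its parameters) is now replicated across local devices and synchronized across hosts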

    data_pipeline = data.PretrainingDataPipeline(
        input_files,
        config.tokenizer,
        max_seq_length=config.max_seq_length,
        max_predictions_per_seq=config.max_predictions_per_seq,
    )

    learning_rate_fn = training.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        steps_per_cycle=config.num_train_steps - config.num_warmup_steps,
    )

    train_history = training.TrainStateHistory(learning_rate_fn)
    train_state = train_history.initial_state()

    if config.do_train:
        train_batch_size = config.train_batch_size
        if jax.host_count() > 1:
            assert (train_batch_size % jax.host_count() == 0
                    ), "train_batch_size must be divisible by number of hosts"
            train_batch_size = train_batch_size // jax.host_count()
        train_iter = data_pipeline.get_inputs(batch_size=train_batch_size,
                                              training=True)
        train_step_fn = training.create_train_step(
            model,
            compute_pretraining_loss_and_metrics,
            max_grad_norm=config.max_grad_norm,
        )

        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            optimizer, train_state = train_step_fn(optimizer, batch,
                                                   train_state)
            if jax.host_id() == 0 and (step % config.save_checkpoints_steps
                                       == 0
                                       or step == config.num_train_steps - 1):
                checkpoints.save_checkpoint(output_dir,
                                            optimizer.unreplicate(), step)
                config_path = os.path.join(output_dir, "config.json")
                if not os.path.exists(config_path):
                    with open(config_path, "w") as f:
                        json.dump({"model_type": "bert", **config.model}, f)
                tokenizer_path = os.path.join(output_dir,
                                              "sentencepiece.model")
                if not os.path.exists(tokenizer_path):
                    shutil.copy(config.tokenizer, tokenizer_path)

        # With the current Rust data pipeline code, running more than one pipeline
        # at a time will lead to a hang. A simple workaround is to fully delete the
        # training pipeline before potentially starting another for evaluation.
        del train_iter

    if config.do_eval:
        eval_iter = data_pipeline.get_inputs(batch_size=config.eval_batch_size)
        eval_iter = itertools.islice(eval_iter, config.max_eval_steps)
        eval_fn = training.create_eval_fn(model,
                                          compute_pretraining_stats,
                                          sample_feature_name="input_ids")
        eval_stats = eval_fn(optimizer, eval_iter)

        eval_metrics = {
            "loss": jnp.mean(eval_stats["loss"]),
            "masked_lm_loss": jnp.mean(eval_stats["masked_lm_loss"]),
            "next_sentence_loss": jnp.mean(eval_stats["next_sentence_loss"]),
            "masked_lm_accuracy": (jnp.sum(eval_stats["masked_lm_correct"]) /
                                   jnp.sum(eval_stats["masked_lm_total"])),
            "next_sentence_accuracy": (jnp.sum(eval_stats["next_sentence_correct"]) /
                                       jnp.sum(eval_stats["next_sentence_total"])),
        }

        eval_results = []
        for name, val in sorted(eval_metrics.items()):
            line = f"{name} = {val:.06f}"
            print(line, flush=True)
            eval_results.append(line)

        eval_results_path = os.path.join(output_dir, "eval_results.txt")
        with gfile.GFile(eval_results_path, "w") as f:
            for line in eval_results:
                f.write(line + "\n")