def train(args):
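    """Train a BERT-based sentence classifier.

    args is an argument dict with docopt-style keys: MODEL selects one of the
    architectures set up below (default / nonlinear / lstm / cnn), BERT_CONFIG
    names the pretrained BERT variant, and --lr / --lr-bert set the learning
    rates for the classification head and the BERT encoder respectively.
    """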

    label_name = ['Premise', 'Claim', 'None', 'MajorClaim']

    device = torch.device("cuda:0" if args['--cuda'] else "cpu")

    prefix = args['MODEL'] + '_' + args['BERT_CONFIG']

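    # e.g. 'bert-base-uncased' -> 'base'; passed on to batch_iter/validation below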
    bert_size = args['BERT_CONFIG'].split('-')[1]

    start_time = time.time()
    print('Importing data...', file=sys.stderr)
    df_train = pd.read_csv(args['--train'], index_col=0)
    df_val = pd.read_csv(args['--dev'], index_col=0)
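    # Inverse-frequency class weights: the most frequent label gets weight 1.0, rarer labels proportionally more.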
    train_label = dict(df_train.InformationType_label.value_counts())
    label_max = float(max(train_label.values()))
    train_label_weight = torch.tensor(
        [label_max / train_label[i] for i in range(len(train_label))],
        device=device)
    print('Done! time elapsed %.2f sec' % (time.time() - start_time),
          file=sys.stderr)
    print('-' * 80, file=sys.stderr)

    start_time = time.time()
    print('Set up model...', file=sys.stderr)

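    # Each architecture uses two learning rates: --lr-bert for the BERT encoder and --lr for the task-specific layers on top.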
    if args['MODEL'] == 'default':
        model = DefaultModel(args['BERT_CONFIG'], device, len(label_name))
        parameters = [{
            'params': model.bert.bert.parameters()
        }, {
            'params': model.bert.classifier.parameters(),
            'lr': float(args['--lr'])
        }]
        optimizer = AdamW(parameters,
                          lr=float(args['--lr-bert']),
                          correct_bias=False)
        scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,
                                                    num_warmup_steps=100,
                                                    num_training_steps=1000)
    elif args['MODEL'] == 'nonlinear':
        model = NonlinearModel(args['BERT_CONFIG'], device, len(label_name),
                               float(args['--dropout']))
        parameters = [{
            'params': model.bert.parameters()
        }, {
            'params': model.linear1.parameters(),
            'lr': float(args['--lr'])
        }, {
            'params': model.linear2.parameters(),
            'lr': float(args['--lr'])
        }, {
            'params': model.linear3.parameters(),
            'lr': float(args['--lr'])
        }]
        optimizer = AdamW(parameters,
                          lr=float(args['--lr-bert']),
                          correct_bias=False)
        scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,
                                                    num_warmup_steps=100,
                                                    num_training_steps=1000)
    elif args['MODEL'] == 'lstm':
        model = CustomBertLSTMModel(args['BERT_CONFIG'],
                                    device,
                                    float(args['--dropout']),
                                    len(label_name),
                                    lstm_hidden_size=int(
                                        args['--hidden-size']))
        parameters = [{
            'params': model.bert.parameters()
        }, {
            'params': model.lstm.parameters(),
            'lr': float(args['--lr'])
        }, {
            'params': model.hidden_to_softmax.parameters(),
            'lr': float(args['--lr'])
        }]
        optimizer = AdamW(parameters,
                          lr=float(args['--lr-bert']),
                          correct_bias=False)
        scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,
                                                    num_warmup_steps=100,
                                                    num_training_steps=1000)
    elif args['MODEL'] == 'cnn':
        model = CustomBertConvModel(args['BERT_CONFIG'],
                                    device,
                                    float(args['--dropout']),
                                    len(label_name),
                                    out_channel=int(args['--out-channel']))
        parameters = [{
            'params': model.bert.parameters()
        }, {
            'params': model.conv.parameters(),
            'lr': float(args['--lr'])
        }, {
            'params': model.hidden_to_softmax.parameters(),
            'lr': float(args['--lr'])
        }]
        optimizer = AdamW(parameters,
                          lr=float(args['--lr-bert']),
                          correct_bias=False)
        scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,
                                                    num_warmup_steps=100,
                                                    num_training_steps=1000)
    else:
        print('please input a valid model (default | nonlinear | lstm | cnn)',
              file=sys.stderr)
        exit(1)

    model = model.to(device)
    print('Use device: %s' % device, file=sys.stderr)
    print('Done! time elapsed %.2f sec' % (time.time() - start_time),
          file=sys.stderr)
    print('-' * 80, file=sys.stderr)

    model.train()

    cn_loss = torch.nn.CrossEntropyLoss(weight=train_label_weight,
                                        reduction='mean')
    torch.save(cn_loss, 'loss_func')  # for later testing

    train_batch_size = int(args['--batch-size'])
    valid_niter = int(args['--valid-niter'])
    log_every = int(args['--log-every'])
    model_save_path = prefix + '_model.bin'

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = 0
    cum_examples = report_examples = epoch = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()
    print('Begin Maximum Likelihood training...')

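    # Train until early stopping (--patience / --max-num-trial) or --max-epoch is reached.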
    while True:
        epoch += 1

        for sents, targets in batch_iter(df_train,
                                         batch_size=train_batch_size,
                                         shuffle=True,
                                         bert=bert_size):  # for each epoch
            train_iter += 1

            optimizer.zero_grad()

            batch_size = len(sents)

            pre_softmax = model(sents)

            loss = cn_loss(
                pre_softmax,
                torch.tensor(targets, dtype=torch.long, device=device))

            loss.backward()

            optimizer.step()
            scheduler.step()

            batch_losses_val = loss.item() * batch_size
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % log_every == 0:
                print('epoch %d, iter %d, avg. loss %.2f, '
                      'cum. examples %d, speed %.2f examples/sec, '
                      'time elapsed %.2f sec' %
                      (epoch, train_iter, report_loss / report_examples,
                       cum_examples, report_examples /
                       (time.time() - train_time), time.time() - begin_time),
                      file=sys.stderr)

                train_time = time.time()
                report_loss = report_examples = 0.

            # perform validation
            if train_iter % valid_niter == 0:
                print(
                    'epoch %d, iter %d, cum. loss %.2f, cum. examples %d' %
                    (epoch, train_iter, cum_loss / cum_examples, cum_examples),
                    file=sys.stderr)

                cum_loss = cum_examples = 0.

                print('begin validation ...', file=sys.stderr)

                validation_loss = validation(
                    model, df_val, bert_size, cn_loss,
                    device)  # dev batch size can be a bit larger

                print('validation: iter %d, loss %f' %
                      (train_iter, validation_loss),
                      file=sys.stderr)

                is_better = len(
                    hist_valid_scores
                ) == 0 or validation_loss < min(hist_valid_scores)
                hist_valid_scores.append(validation_loss)

                if is_better:
                    patience = 0
                    print('save currently the best model to [%s]' %
                          model_save_path,
                          file=sys.stderr)

                    model.save(model_save_path)

                    # also save the optimizers' state
                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')
                elif patience < int(args['--patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['--patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['--max-num-trial']):
                            print('early stop!', file=sys.stderr)
                            exit(0)

                        # decay lr, and restore from previously best checkpoint
                        print(
                            'load previously best model and decay learning rate to %f%%'
                            % (float(args['--lr-decay']) * 100),
                            file=sys.stderr)

                        # load model
                        params = torch.load(
                            model_save_path,
                            map_location=lambda storage, loc: storage)
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizers',
                              file=sys.stderr)
                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim'))

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] *= float(args['--lr-decay'])

                        # reset patience
                        patience = 0

                if epoch == int(args['--max-epoch']):
                    print('reached maximum number of epochs!', file=sys.stderr)
                    exit(0)
Example 2
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, train_dataset_second, DP_classifier) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
    

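    # train_dataloader and correct_dataloader are iterated in lockstep (zipped below), so both use a
    # SequentialSampler to keep their examples aligned; wrong_dataloader is shuffled to supply mismatched pairs.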
    train_sampler = SequentialSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )
    
    correct_sampler = SequentialSampler(train_dataset_second) if args.local_rank == -1 else DistributedSampler(train_dataset_second)
    correct_dataloader = DataLoader(
        train_dataset_second, sampler=correct_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )

    wrong_sampler = RandomSampler(train_dataset_second) if args.local_rank == -1 else DistributedSampler(train_dataset_second)
    wrong_dataloader = DataLoader(
        train_dataset_second, sampler=wrong_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )


    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)] + [p for n, p in DP_classifier.named_parameters()],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    #scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (
        args.model_name_or_path
        and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    DP_classifier.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    
    zipped_data = zip(train_dataloader, correct_dataloader, wrong_dataloader)
    
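    # Binary targets for the sentence-pair classifier: 1.0 for the true next sentence, 0.0 for a randomly sampled one.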
    correct_mc_tensor = torch.ones(args.train_batch_size, dtype=torch.float)
    correct_mc_tensor = correct_mc_tensor.to(args.device)
    wrong_mc_tensor = torch.zeros(args.train_batch_size, dtype=torch.float)
    wrong_mc_tensor = wrong_mc_tensor.to(args.device)
    
    print(correct_mc_tensor)
    print(wrong_mc_tensor)
    
    accumulated_lm_loss = 0.0
    accumulated_mc_loss = 0.0
    
    for _ in train_iterator:
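        # Rebuild the dataloaders (and re-zip them) at the start of every epoch,
        # since the zipped iterator is exhausted after a single pass.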
        train_sampler = SequentialSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
        train_dataloader = DataLoader(
            train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
        )
        
        correct_sampler = SequentialSampler(train_dataset_second) if args.local_rank == -1 else DistributedSampler(train_dataset_second)
        correct_dataloader = DataLoader(
            train_dataset_second, sampler=correct_sampler, batch_size=args.train_batch_size, collate_fn=collate
        )
    
        wrong_sampler = RandomSampler(train_dataset_second) if args.local_rank == -1 else DistributedSampler(train_dataset_second)
        wrong_dataloader = DataLoader(
            train_dataset_second, sampler=wrong_sampler, batch_size=args.train_batch_size, collate_fn=collate
        )
        zipped_data = zip(train_dataloader, correct_dataloader, wrong_dataloader)
        epoch_iterator = tqdm(zipped_data, desc="Iteration", disable=args.local_rank not in [-1, 0], total=len(train_dataloader))
        for step, zipped_batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            DP_classifier.train()
            

            # unpack zipped_batch
            batch, correct_batch, wrong_batch = zipped_batch
            
                        
            # First: original sentence
            inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
            labels = inputs.clone()
            
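            # Find the [CLS] token in each sequence: exclude it from the LM loss (-100)
            # and record its position for pooling the sentence embedding below.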
            cls_pos = []
            for curr in labels:
                for idx, tk in enumerate(curr):
                    if tk == tokenizer.cls_token_id:
                        curr[idx] = -100
                        cls_pos.append(idx)
                        break
            
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            
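            # The model is expected to return (lm_loss, ..., hidden_states), with per-token hidden states at index 3.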
            outputs = model(inputs, lm_labels=labels)
            loss_lm_1 = outputs[0]
            hidden_1 = outputs[3]
            
            sentence_embed_1_pieces = [hh[cls_pos[idx]].unsqueeze(0) for idx, hh in enumerate(hidden_1)]
            sentence_embed_1 = torch.cat(sentence_embed_1_pieces)
            
            # Second: correct next sentence
            correct_input = correct_batch
            correct_labels = correct_input.clone()
            
            cls_pos = []
            for curr in correct_labels:
                for idx, tk in enumerate(curr):
                    if tk == tokenizer.cls_token_id:
                        curr[idx] = -100
                        cls_pos.append(idx)
                        break
            
            correct_input = correct_input.to(args.device)
            correct_labels = correct_labels.to(args.device)            
            
            outputs = model(correct_input, lm_labels=correct_labels)

            loss_lm_2 = outputs[0]
            hidden_2 = outputs[3]
            sentence_embed_2_pieces = [hh[cls_pos[idx]].unsqueeze(0) for idx, hh in enumerate(hidden_2)]
            sentence_embed_2 = torch.cat(sentence_embed_2_pieces)

            # Loss for the correct pair: feed the two sentence embeddings to the pair classifier in random order.
            if random.randint(0, 1) == 1:
                outputs = DP_classifier(sentence_embed_1, sentence_embed_2, correct_mc_tensor)
            else:
                outputs = DP_classifier(sentence_embed_2, sentence_embed_1, correct_mc_tensor)
            loss_mc = outputs[0]
            
            # MC_LOSS SCALING
            SCALING = 0.05
            loss_lm = loss_lm_1 + loss_lm_2
            
            
            #loss = loss_lm
            loss_first = loss_lm + SCALING * loss_mc
            #print("loss_mc: ", loss_mc.item())
            #print("loss_lm: ", loss_lm.item())
            
            accumulated_lm_loss += loss_lm.item() / 2.0
            accumulated_mc_loss += SCALING * loss_mc.item()
            
            # Second loss: wrong next sentence randomly sampled from training set
            wrong_input = wrong_batch
            wrong_labels = wrong_input.clone()
            
            cls_pos = []
            for curr in wrong_labels:
                for idx, tk in enumerate(curr):
                    if tk == tokenizer.cls_token_id:
                        curr[idx] = -100
                        cls_pos.append(idx)
                        break

            
            wrong_input = wrong_input.to(args.device)
            wrong_labels = wrong_labels.to(args.device)
            
            outputs = model(wrong_input, lm_labels=wrong_labels)

            loss_lm_3 = outputs[0]
            hidden_3 = outputs[3]
            sentence_embed_3_pieces = [hh[cls_pos[idx]].unsqueeze(0) for idx, hh in enumerate(hidden_3)]
            sentence_embed_3 = torch.cat(sentence_embed_3_pieces)

            
            if random.randint(0, 1) == 1:
                outputs = DP_classifier(sentence_embed_1, sentence_embed_3, wrong_mc_tensor)
            else:
                outputs = DP_classifier(sentence_embed_3, sentence_embed_1, wrong_mc_tensor)
            loss_mc = outputs[0]
            
            #loss = loss_lm
            loss_second = loss_lm_3 + SCALING * loss_mc
            #print("loss_mc: ", loss_mc.item())
            #print("loss_lm: ", loss_lm.item())
            accumulated_mc_loss += SCALING * loss_mc.item()
            
            # Total loss
            loss = loss_first + loss_second


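            # Every SKIP_STEP iterations, log the averaged LM / pair-classification losses to stderr and TensorBoard.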
            SKIP_STEP = 50
            if (step % SKIP_STEP == 0):
                print(' iter %d, avg. lm_loss %.2f, avg. mc_loss %.2f, avg. ppl %.2f ' % (step,
                                                                    accumulated_lm_loss / SKIP_STEP,
                                                                    accumulated_mc_loss / SKIP_STEP,
                                                                    math.exp(loss_lm.item() /2),
                                                                    ), file=sys.stderr)
                tb_writer.add_scalar("training_lm_loss", accumulated_lm_loss / SKIP_STEP, global_step)
                tb_writer.add_scalar("training_mc_loss", accumulated_mc_loss / SKIP_STEP, global_step)
                accumulated_lm_loss = 0.0
                accumulated_mc_loss = 0.0
                

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(DP_classifier.parameters(), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(DP_classifier.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                DP_classifier.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer, DP_classifier)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(DP_classifier, os.path.join(output_dir, "DP_classifier.bin"))
                    
                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example 3
    def train(
        self, train_dataset, output_dir, show_running_loss=True, eval_data=None, verbose=True, **kwargs,
    ):
        """
        Trains the model on train_dataset.

        Utility function to be used by the train_model() method. Not intended to be used directly.
        """

        model = self.model
        args = self.args
        device = self.device

        tb_writer = SummaryWriter(logdir=args.tensorboard_dir)
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(
            train_dataset,
            sampler=train_sampler,
            batch_size=args.train_batch_size,
            num_workers=self.args.dataloader_num_workers,
        )

        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

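        # Build optimizer parameter groups: user-defined custom groups and per-layer groups first,
        # then the remaining parameters split into weight-decayed and non-decayed (bias / LayerNorm) sets.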
        no_decay = ["bias", "LayerNorm.weight"]

        optimizer_grouped_parameters = []
        custom_parameter_names = set()
        for group in self.args.custom_parameter_groups:
            params = group.pop("params")
            custom_parameter_names.update(params)
            param_group = {**group}
            param_group["params"] = [p for n, p in model.named_parameters() if n in params]
            optimizer_grouped_parameters.append(param_group)

        for group in self.args.custom_layer_parameters:
            layer_number = group.pop("layer")
            layer = f"layer.{layer_number}."
            group_d = {**group}
            group_nd = {**group}
            group_nd["weight_decay"] = 0.0
            params_d = []
            params_nd = []
            for n, p in model.named_parameters():
                if n not in custom_parameter_names and layer in n:
                    if any(nd in n for nd in no_decay):
                        params_nd.append(p)
                    else:
                        params_d.append(p)
                    custom_parameter_names.add(n)
            group_d["params"] = params_d
            group_nd["params"] = params_nd

            optimizer_grouped_parameters.append(group_d)
            optimizer_grouped_parameters.append(group_nd)

        if not self.args.train_custom_parameters_only:
            optimizer_grouped_parameters.extend(
                [
                    {
                        "params": [
                            p
                            for n, p in model.named_parameters()
                            if n not in custom_parameter_names and not any(nd in n for nd in no_decay)
                        ],
                        "weight_decay": args.weight_decay,
                    },
                    {
                        "params": [
                            p
                            for n, p in model.named_parameters()
                            if n not in custom_parameter_names and any(nd in n for nd in no_decay)
                        ],
                        "weight_decay": 0.0,
                    },
                ]
            )

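        # Convert warmup_ratio into an absolute warmup step count, used unless warmup_steps was set explicitly.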
        warmup_steps = math.ceil(t_total * args.warmup_ratio)
        args.warmup_steps = warmup_steps if args.warmup_steps == 0 else args.warmup_steps

        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )

        if (
            args.model_name
            and os.path.isfile(os.path.join(args.model_name, "optimizer.pt"))
            and os.path.isfile(os.path.join(args.model_name, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(torch.load(os.path.join(args.model_name, "optimizer.pt")))
            scheduler.load_state_dict(torch.load(os.path.join(args.model_name, "scheduler.pt")))

        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        logger.info(" Training started")

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        model.zero_grad()
        train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.silent, mininterval=0)
        epoch_number = 0
        best_eval_metric = None
        early_stopping_counter = 0
        steps_trained_in_current_epoch = 0
        epochs_trained = 0

        if args.model_name and os.path.exists(args.model_name):
            try:
                # set global_step to global_step of last saved checkpoint from model path
                checkpoint_suffix = args.model_name.split("/")[-1].split("-")
                if len(checkpoint_suffix) > 2:
                    checkpoint_suffix = checkpoint_suffix[1]
                else:
                    checkpoint_suffix = checkpoint_suffix[-1]
                global_step = int(checkpoint_suffix)
                epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = global_step % (
                    len(train_dataloader) // args.gradient_accumulation_steps
                )

                logger.info("   Continuing training from checkpoint, will skip to saved global_step")
                logger.info("   Continuing training from epoch %d", epochs_trained)
                logger.info("   Continuing training from global step %d", global_step)
                logger.info("   Will skip the first %d steps in the current epoch", steps_trained_in_current_epoch)
            except ValueError:
                logger.info("   Starting fine-tuning.")

        if args.evaluate_during_training:
            training_progress_scores = self._create_training_progress_scores(**kwargs)

        if args.wandb_project:
            wandb.init(project=args.wandb_project, config={**asdict(args)}, **args.wandb_kwargs)
            wandb.watch(self.model)

        if args.fp16:
            from torch.cuda import amp

            scaler = amp.GradScaler()

        model.train()
        for current_epoch in train_iterator:
            if epochs_trained > 0:
                epochs_trained -= 1
                continue
            train_iterator.set_description(f"Epoch {epoch_number + 1} of {args.num_train_epochs}")
            batch_iterator = tqdm(
                train_dataloader,
                desc=f"Running Epoch {epoch_number} of {args.num_train_epochs}",
                disable=args.silent,
                mininterval=0,
            )
            for step, batch in enumerate(batch_iterator):
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue
                batch = tuple(t.to(device) for t in batch)

                inputs = self._get_inputs_dict(batch)
                if args.fp16:
                    with amp.autocast():
                        outputs = model(**inputs)
                        # model outputs are always tuple in pytorch-transformers (see doc)
                        loss = outputs[0]
                else:
                    outputs = model(**inputs)
                    # model outputs are always tuple in pytorch-transformers (see doc)
                    loss = outputs[0]

                if args.n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu parallel training

                current_loss = loss.item()

                if show_running_loss:
                    batch_iterator.set_description(
                        f"Epochs {epoch_number}/{args.num_train_epochs}. Running Loss: {current_loss:9.4f}"
                    )

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                    if args.fp16:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        # Log metrics
                        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                        tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                        logging_loss = tr_loss
                        if args.wandb_project:
                            wandb.log(
                                {
                                    "Training loss": current_loss,
                                    "lr": scheduler.get_lr()[0],
                                    "global_step": global_step,
                                }
                            )

                    if args.save_steps > 0 and global_step % args.save_steps == 0:
                        # Save model checkpoint
                        output_dir_current = os.path.join(output_dir, "checkpoint-{}".format(global_step))

                        self.save_model(output_dir_current, optimizer, scheduler, model=model)

                    if args.evaluate_during_training and (
                        args.evaluate_during_training_steps > 0
                        and global_step % args.evaluate_during_training_steps == 0
                    ):
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = self.eval_model(
                            eval_data,
                            verbose=verbose and args.evaluate_during_training_verbose,
                            silent=args.evaluate_during_training_silent,
                            **kwargs,
                        )
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)

                        output_dir_current = os.path.join(output_dir, "checkpoint-{}".format(global_step))

                        if args.save_eval_checkpoints:
                            self.save_model(output_dir_current, optimizer, scheduler, model=model, results=results)

                        training_progress_scores["global_step"].append(global_step)
                        training_progress_scores["train_loss"].append(current_loss)
                        for key in results:
                            training_progress_scores[key].append(results[key])
                        report = pd.DataFrame(training_progress_scores)
                        report.to_csv(
                            os.path.join(args.output_dir, "training_progress_scores.csv"), index=False,
                        )

                        if args.wandb_project:
                            wandb.log(self._get_last_metrics(training_progress_scores))

                        if not best_eval_metric:
                            best_eval_metric = results[args.early_stopping_metric]
                            self.save_model(args.best_model_dir, optimizer, scheduler, model=model, results=results)
                        if best_eval_metric and args.early_stopping_metric_minimize:
                            if results[args.early_stopping_metric] - best_eval_metric < args.early_stopping_delta:
                                best_eval_metric = results[args.early_stopping_metric]
                                self.save_model(
                                    args.best_model_dir, optimizer, scheduler, model=model, results=results
                                )
                                early_stopping_counter = 0
                            else:
                                if args.use_early_stopping:
                                    if early_stopping_counter < args.early_stopping_patience:
                                        early_stopping_counter += 1
                                        if verbose:
                                            logger.info(f" No improvement in {args.early_stopping_metric}")
                                            logger.info(f" Current step: {early_stopping_counter}")
                                            logger.info(f" Early stopping patience: {args.early_stopping_patience}")
                                    else:
                                        if verbose:
                                            logger.info(f" Patience of {args.early_stopping_patience} steps reached")
                                            logger.info(" Training terminated.")
                                            train_iterator.close()
                                        return global_step, tr_loss / global_step
                        else:
                            if results[args.early_stopping_metric] - best_eval_metric > args.early_stopping_delta:
                                best_eval_metric = results[args.early_stopping_metric]
                                self.save_model(
                                    args.best_model_dir, optimizer, scheduler, model=model, results=results
                                )
                                early_stopping_counter = 0
                            else:
                                if args.use_early_stopping:
                                    if early_stopping_counter < args.early_stopping_patience:
                                        early_stopping_counter += 1
                                        if verbose:
                                            logger.info(f" No improvement in {args.early_stopping_metric}")
                                            logger.info(f" Current step: {early_stopping_counter}")
                                            logger.info(f" Early stopping patience: {args.early_stopping_patience}")
                                    else:
                                        if verbose:
                                            logger.info(f" Patience of {args.early_stopping_patience} steps reached")
                                            logger.info(" Training terminated.")
                                            train_iterator.close()
                                        return global_step, tr_loss / global_step

            epoch_number += 1
            output_dir_current = os.path.join(output_dir, "checkpoint-{}-epoch-{}".format(global_step, epoch_number))

            if args.save_model_every_epoch or args.evaluate_during_training:
                os.makedirs(output_dir_current, exist_ok=True)

            if args.save_model_every_epoch:
                self.save_model(output_dir_current, optimizer, scheduler, model=model)

            if args.evaluate_during_training:
                results = self.eval_model(
                    eval_data,
                    verbose=verbose and args.evaluate_during_training_verbose,
                    silent=args.evaluate_during_training_silent,
                    **kwargs,
                )

                self.save_model(output_dir_current, optimizer, scheduler, results=results)

                training_progress_scores["global_step"].append(global_step)
                training_progress_scores["train_loss"].append(current_loss)
                for key in results:
                    training_progress_scores[key].append(results[key])
                report = pd.DataFrame(training_progress_scores)
                report.to_csv(os.path.join(args.output_dir, "training_progress_scores.csv"), index=False)

                if args.wandb_project:
                    wandb.log(self._get_last_metrics(training_progress_scores))

                if not best_eval_metric:
                    best_eval_metric = results[args.early_stopping_metric]
                    self.save_model(args.best_model_dir, optimizer, scheduler, model=model, results=results)
                if best_eval_metric and args.early_stopping_metric_minimize:
                    if results[args.early_stopping_metric] - best_eval_metric < args.early_stopping_delta:
                        best_eval_metric = results[args.early_stopping_metric]
                        self.save_model(args.best_model_dir, optimizer, scheduler, model=model, results=results)
                        early_stopping_counter = 0
                    else:
                        if args.use_early_stopping and args.early_stopping_consider_epochs:
                            if early_stopping_counter < args.early_stopping_patience:
                                early_stopping_counter += 1
                                if verbose:
                                    logger.info(f" No improvement in {args.early_stopping_metric}")
                                    logger.info(f" Current step: {early_stopping_counter}")
                                    logger.info(f" Early stopping patience: {args.early_stopping_patience}")
                            else:
                                if verbose:
                                    logger.info(f" Patience of {args.early_stopping_patience} steps reached")
                                    logger.info(" Training terminated.")
                                    train_iterator.close()
                                return global_step, tr_loss / global_step
                else:
                    if results[args.early_stopping_metric] - best_eval_metric > args.early_stopping_delta:
                        best_eval_metric = results[args.early_stopping_metric]
                        self.save_model(args.best_model_dir, optimizer, scheduler, model=model, results=results)
                        early_stopping_counter = 0
                    else:
                        if args.use_early_stopping and args.early_stopping_consider_epochs:
                            if early_stopping_counter < args.early_stopping_patience:
                                early_stopping_counter += 1
                                if verbose:
                                    logger.info(f" No improvement in {args.early_stopping_metric}")
                                    logger.info(f" Current step: {early_stopping_counter}")
                                    logger.info(f" Early stopping patience: {args.early_stopping_patience}")
                            else:
                                if verbose:
                                    logger.info(f" Patience of {args.early_stopping_patience} steps reached")
                                    logger.info(" Training terminated.")
                                    train_iterator.close()
                                return global_step, tr_loss / global_step

        return global_step, tr_loss / global_step
def train(train_dataset, model, tokenizer, hyperparams):

    verbose = hyperparams["verbose"]
    disable = False if verbose else True

    local_rank = hyperparams["local_rank"]
    per_gpu_train_batch_size = hyperparams["per_gpu_train_batch_size"]
    n_gpu = hyperparams["n_gpu"]
    max_steps = hyperparams["max_steps"]
    num_train_epochs = hyperparams["num_train_epochs"]
    gradient_accumulation_steps = hyperparams["gradient_accumulation_steps"]
    weight_decay = hyperparams["weight_decay"]
    learning_rate = hyperparams["learning_rate"]
    adam_epsilon = hyperparams["adam_epsilon"]
    warmup_steps = hyperparams["warmup_steps"]
    seed = hyperparams["random_state"]
    device = hyperparams["device"]
    model_type = hyperparams["model_type"]
    max_grad_norm = hyperparams["max_grad_norm"]

    save_steps = hyperparams['save_steps']

    output_dir = hyperparams["output_dir"]
    log_path = os.path.join(output_dir, "log.csv")
    fp16_opt_level = hyperparams["fp16_opt_level"]
    fp16 = hyperparams["fp16"]

    model_name_or_path = hyperparams["model_name_or_path"]
    opt_path = os.path.join(model_name_or_path, "optimizer.pt")
    sche_path = os.path.join(model_name_or_path, "scheduler.pt")

    training_logs = {"loss": [], "learning_rate": []}
    train_batch_size = per_gpu_train_batch_size * max(1, n_gpu)

    if local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)

    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=train_batch_size)

    if max_steps > 0:
        t_total = max_steps
        num_train_epochs = max_steps // (len(train_dataloader) //
                                         gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader) // gradient_accumulation_steps * num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]

    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=learning_rate,
                      eps=adam_epsilon)

    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(opt_path) and os.path.isfile(sche_path):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(opt_path))
        scheduler.load_state_dict(torch.load(sche_path))

    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True)

    # Train!
    logging.info("***** Running training *****")
    logging.info("  Num examples = %d", len(train_dataset))
    logging.info("  Num Epochs = %d", num_train_epochs)
    logging.info("  Instantaneous batch size per GPU = %d",
                 per_gpu_train_batch_size)
    logging.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size * gradient_accumulation_steps *
        (torch.distributed.get_world_size() if local_rank != -1 else 1))
    logging.info("  Gradient Accumulation steps = %d",
                 gradient_accumulation_steps)
    logging.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0

    # Check if continuing training from a checkpoint
    if os.path.exists(
            model_name_or_path) and model_name_or_path.find("checkpoints") > 0:
        # set global_step to global_step of last saved checkpoint from model
        # path
        global_step = int(model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) //
                                         gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // gradient_accumulation_steps)

        logging.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logging.info("  Continuing training from epoch %d", epochs_trained)
        logging.info("  Continuing training from global step %d", global_step)
        logging.info("  Will skip the first %d steps in the first epoch",
                     steps_trained_in_current_epoch)

    tr_loss = 0.0
    model.zero_grad()
    set_seed(seed, n_gpu=n_gpu)  # Added here for reproducibility

    train_iterator = trange(epochs_trained,
                            int(num_train_epochs),
                            desc="Epoch",
                            disable=disable)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=disable)

        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            model.train()
            batch = tuple(t.to(device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            if model_type != "distilbert":
                inputs["token_type_ids"] = (batch[2] if model_type in [
                    "bert", "xlnet", "albert"
                ] else None)
            outputs = model(**inputs)
            loss = outputs[0]
            if n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            if fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            training_logs["loss"].append(loss.item())
            training_logs["learning_rate"].append(scheduler.get_last_lr()[0])
            if (step + 1) % gradient_accumulation_steps == 0:
                if fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

            if local_rank in [
                    -1, 0
            ] and save_steps > 0 and global_step % save_steps == 0:
                # Save model checkpoint (use a separate path so output_dir is not overwritten)
                checkpoint_dir = os.path.join(
                    output_dir, "checkpoint-{}".format(global_step))
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                # Take care of distributed/parallel training
                model_to_save = (model.module
                                 if hasattr(model, "module") else model)
                model_to_save.save_pretrained(checkpoint_dir)
                tokenizer.save_pretrained(checkpoint_dir)
                torch.save(
                    hyperparams,
                    os.path.join(checkpoint_dir, "training_hyperparams.bin"))
                logging.info("Saving model checkpoint to %s", checkpoint_dir)
                torch.save(optimizer.state_dict(),
                           os.path.join(checkpoint_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(checkpoint_dir, "scheduler.pt"))
                logging.info("Saving optimizer and scheduler states to %s",
                             checkpoint_dir)

            if max_steps > 0 and global_step > max_steps:
                epoch_iterator.close()
                break
        if max_steps > 0 and global_step > max_steps:
            train_iterator.close()
            break

    training_logs = pd.DataFrame(training_logs)
    training_logs.to_csv(log_path, index=False)
    return global_step, tr_loss / global_step
def train(args, model, train_dataset, dev_dataset=None, test_dataset=None):
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    weight_decay_change = 'sentiment_embedding.weight'
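    # Three parameter groups: standard weight decay, no decay for bias / LayerNorm, and a stronger decay (0.3) for the sentiment embedding.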
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n
                       for nd in no_decay) and n not in weight_decay_change
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay) and n not in weight_decay_change
        ],
        'weight_decay':
        0.0
    }, {
        'params':
        [p for n, p in model.named_parameters() if n in weight_decay_change],
        'weight_decay':
        0.3
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Total train batch size = %d", args.train_batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    logger.info("  Logging steps = %d", args.logging_steps)
    logger.info("  Save steps = %d", args.save_steps)

    global_step = 0
    tr_loss = 0.0

    model.zero_grad()
    mb = master_bar(range(int(args.num_train_epochs)))
    best_acc = 0
    acc = 0
    for epoch in mb:
        epoch_iterator = progress_bar(train_dataloader, parent=mb)
        ep_loss = []
        for step, (batch, txt) in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "labels": batch[3]
            }
            if "KOSAC" in args.model_mode:
                inputs["polarity_ids"] = batch[4]
                inputs["intensity_ids"] = batch[5]
            if "KNU" in args.model_mode:
                inputs["polarity_ids"] = batch[4]
            if "CHAR" in args.model_mode:
                inputs["char_token_data"] = txt[1]
                inputs["word_token_data"] = txt[2]
                txt = txt[0]
            outputs = model(**inputs)
            # print(outputs)
            loss = outputs[0]
            # print(loss)
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if isinstance(loss, tuple):
                # print(list(map(lambda x:x.item(),loss)))
                ep_loss.append(list(map(lambda x: x.item(), loss)))
                loss = sum(loss)
            else:
                ep_loss.append([loss.item()])

            loss.backward()
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    len(train_dataloader) <= args.gradient_accumulation_steps
                    and (step + 1) == len(train_dataloader)):
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    if args.evaluate_test_during_training:
                        results = evaluate(args, model, test_dataset, "test",
                                           global_step)
                    else:
                        results = evaluate(args, model, dev_dataset, "dev",
                                           global_step)
                    acc = str(results['acc'])

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir,
                                              "checkpoint-best")

                    if float(best_acc) <= float(acc):
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        torch.save(
                            model.state_dict(),
                            os.path.join(output_dir, "training_model.bin"))
                        torch.save(
                            args, os.path.join(output_dir,
                                               "training_args.bin"))
                        with open(os.path.join(output_dir, "model_code.txt"),
                                  "w") as fp:
                            fp.writelines(
                                inspect.getsource(MODEL_LIST[args.model_mode]))

                        logger.info(
                            "Saving model checkpoint to {}".format(output_dir))

                        if args.save_optimizer:
                            torch.save(
                                optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"))
                            torch.save(
                                scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))
                            logger.info(
                                "Saving optimizer and scheduler states to {}".
                                format(output_dir))

                        best_acc = acc

            if args.max_steps > 0 and global_step > args.max_steps:
                break

        mb.write("Epoch {} done".format(epoch + 1))
        mb.write("Epoch loss = {} ".format(np.mean(np.array(ep_loss), axis=0)))

        if args.max_steps > 0 and global_step > args.max_steps:
            break

    return global_step, tr_loss / global_step
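

# Minimal sketch of the per-group weight-decay setup used above (illustrative;
# the decay value and model are assumptions). Biases and LayerNorm weights are
# excluded from weight decay, everything else shares one decay value.
def sketch_grouped_adamw(model, lr=5e-5, weight_decay=0.01):
    from transformers import AdamW

    no_decay = ["bias", "LayerNorm.weight"]
    grouped_parameters = [{
        "params": [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": weight_decay,
    }, {
        "params": [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    }]
    return AdamW(grouped_parameters, lr=lr, eps=1e-8)
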
Example n. 6
def train(model, tokenizer, checkpoint):
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
    else:
        amp = None
    # Prepare the training data
    train_data = SampleTrainData(data_file=args.train_file,
                                 max_length=args.max_length,
                                 tokenizer=tokenizer)
    train_dataLoader = DataLoader(dataset=train_data,
                                  batch_size=args.batch_size,
                                  shuffle=args.shuffle)

    # Initialize the optimizer and scheduler
    t_total = len(train_dataLoader) * args.epochs
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
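    # args.warmup_steps is interpreted below as a fraction of t_total, so the
    # warmup length scales with the size of the training run.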
    num_warmup_steps = args.warmup_steps * t_total
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=t_total)
    # apex
    if args.fp16:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fptype)

    # Restore optimizer and scheduler states from the checkpoint, if any
    checkpoint_dir = args.save_dir + "/checkpoint-" + str(checkpoint)
    if os.path.isfile(os.path.join(checkpoint_dir, "optimizer.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(checkpoint_dir, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))
        if args.fp16:
            amp.load_state_dict(
                torch.load(os.path.join(checkpoint_dir, "amp.pt")))

    # Start training
    logger.debug("***** Running training *****")
    logger.debug("  Num Steps = %d", len(train_dataLoader))
    logger.debug("  Num Epochs = %d", args.epochs)
    logger.debug("  Set_Batch size = %d", args.batch_size)
    logger.debug("  Real_Batch_size = %d", args.batch_size * args.accumulate)
    #logger.debug("  Loss_rate_ = " + str(args.loss_rate))
    logger.debug("  Shuffle = " + str(args.shuffle))
    logger.debug("  warmup_steps = " + str(num_warmup_steps))

    # If there is no previous checkpoint, start from epoch 0
    if checkpoint < 0:
        checkpoint = 0
    else:
        checkpoint += 1
    logger.debug("  Start Batch = %d", checkpoint)
    for epoch in range(checkpoint, args.epochs):
        model.train()
        epoch_loss = []
        step = 0
        for batch in tqdm(train_dataLoader, desc="Iteration"):
            model.zero_grad()
            # Move the batch tensors to the target device
            batch = tuple(t.to(args.device) for t in batch)
            input_ids, token_type_ids, attention_mask, labels = batch

            outputs = model(input_ids=input_ids.long(),
                            token_type_ids=token_type_ids.long(),
                            attention_mask=attention_mask,
                            labels=labels)

            loss = outputs[0]

            epoch_loss.append(loss.item())

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()
            scheduler.step()
            step += 1
            if step % 500 == 0:
                logger.debug("loss:" + str(np.array(epoch_loss).mean()))
                logger.debug(
                    'learning_rate:' +
                    str(optimizer.state_dict()['param_groups'][0]['lr']))
        # Save the model checkpoint for this epoch
        output_dir = args.save_dir + "/checkpoint-" + str(epoch)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        model_to_save = (model.module if hasattr(model, "module") else model)
        model_to_save.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        torch.save(args, os.path.join(output_dir, "training_args.bin"))
        logger.debug("Saving model checkpoint to %s", output_dir)
        if args.fp16:
            torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt"))
        torch.save(optimizer.state_dict(),
                   os.path.join(output_dir, "optimizer.pt"))
        torch.save(scheduler.state_dict(),
                   os.path.join(output_dir, "scheduler.pt"))
        logger.debug("Saving optimizer and scheduler states to %s", output_dir)

        # eval dev
        eval_loss, eval_map, eval_mrr = evaluate(model,
                                                 tokenizer,
                                                 eval_file=args.dev_file,
                                                 checkpoint=epoch,
                                                 output_dir=output_dir)
        # eval test
        test_eval_loss, test_eval_map, test_eval_mrr = evaluate(
            model,
            tokenizer,
            eval_file=args.test_file,
            checkpoint=epoch,
            output_dir=output_dir)

        # Log DEV and TEST results for this epoch
        logger.info(
            '【DEV 】Train Epoch %d: train_loss=%.4f, map=%.4f, mrr=%.4f' %
            (epoch, np.array(epoch_loss).mean(), eval_map, eval_mrr))
        logger.info(
            '【TEST】Train Epoch %d: train_loss=%.4f, map=%.4f, mrr=%.4f' %
            (epoch, np.array(epoch_loss).mean(), test_eval_map, test_eval_mrr))
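

# Minimal sketch of the apex mixed-precision pattern used above (assumes NVIDIA
# apex is installed; the model/optimizer/dataloader names are illustrative).
# amp.initialize wraps the model and optimizer once before the loop, and every
# backward pass goes through amp.scale_loss so fp16 gradients do not underflow.
def sketch_apex_fp16(model, optimizer, dataloader, opt_level="O1"):
    from apex import amp

    # Wrap once, before the training loop.
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    for batch in dataloader:
        loss = model(**batch)[0]  # assumes a (loss, ...) output
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return model, optimizer
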
Example n. 7
def train(args,
          model,
          train_dataset,
          dev_dataset=None,
          test_dataset=None):
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
            os.path.join(args.model_name_or_path, "scheduler.pt")
    ):
        # Load optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Total train batch size = %d", args.train_batch_size)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    logger.info("  Logging steps = %d", args.logging_steps)
    logger.info("  Save steps = %d", args.save_steps)

    global_step = 0
    tr_loss = 0.0

    model.zero_grad()
    mb = master_bar(range(int(args.num_train_epochs)))
    for epoch in mb:
        epoch_iterator = progress_bar(train_dataloader, parent=mb)
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type not in ["distilkobert"]:
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["kobert", "hanbert", "electra-base", "electra-small"] else None
                )  # XLM-RoBERTa doesn't use segment_ids
            outputs = model(**inputs)

            loss = outputs[0]

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    len(train_dataloader) <= args.gradient_accumulation_steps
                    and (step + 1) == len(train_dataloader)
            ):
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    if args.evaluate_test_during_training:
                        evaluate(args, model, test_dataset, "test", global_step)
                    else:
                        evaluate(args, model, dev_dataset, "dev", global_step)

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )
                    model_to_save.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to {}".format(output_dir))

                    if args.save_optimizer:
                        torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                        torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        logger.info("Saving optimizer and scheduler states to {}".format(output_dir))

            if args.max_steps > 0 and global_step > args.max_steps:
                break

        mb.write("Epoch {} done".format(epoch + 1))

        if args.max_steps > 0 and global_step > args.max_steps:
            break

    return global_step, tr_loss / global_step
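

# Minimal sketch of the checkpoint save/resume convention used in the examples
# above (file names follow that convention; the paths and objects passed in are
# assumptions): the optimizer and scheduler state dicts are written next to the
# model so training can be resumed later with load_state_dict.
def sketch_save_and_resume(output_dir, model, optimizer, scheduler):
    import os
    import torch

    os.makedirs(output_dir, exist_ok=True)
    model_to_save = model.module if hasattr(model, "module") else model
    torch.save(model_to_save.state_dict(),
               os.path.join(output_dir, "training_model.bin"))
    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

    # Resuming later simply reloads the two state dicts.
    optimizer.load_state_dict(
        torch.load(os.path.join(output_dir, "optimizer.pt")))
    scheduler.load_state_dict(
        torch.load(os.path.join(output_dir, "scheduler.pt")))
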
Example n. 8
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.20,
        type=float,
        help=
        "The number of predictions is sometimes less than max_pred when the sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "The number of predictions is sometimes less than max_pred when the sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")

    args = parser.parse_args()

    if not (args.model_recover_path
            and Path(args.model_recover_path).exists()):
        args.model_recover_path = None

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case)
    data_tokenizer = (WhitespaceTokenizer()
                      if args.tokenized_input else tokenizer)
    if args.local_rank == 0:
        dist.barrier()

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [
            utils_seq2seq.Preprocess4LeftLM(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                tokenizer=data_tokenizer),
            utils_seq2seq.Preprocess4RightLM(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                tokenizer=data_tokenizer)
        ]

        file = os.path.join(args.data_dir,
                            args.src_file if args.src_file else 'train.tgt')
        train_dataset = utils_seq2seq.LRDataset(
            file,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=utils_seq2seq.batch_list_to_batch_tensors,
            pin_memory=False)

    # note: args.train_batch_size has already been divided by args.gradient_accumulation_steps above
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size))
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    global_step = 0
    if (recover_step is None) and (args.model_recover_path is None):
        model_recover = None
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
    model = model_class.from_pretrained(args.model_name_or_path,
                                        state_dict=model_recover,
                                        config=config)
    if args.local_rank == 0:
        dist.barrier()

    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(args.warmup_proportion * t_total),
        num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)

        logger.info("***** Recover amp: %d *****", recover_step)
        amp_recover = torch.load(os.path.join(
            args.output_dir, "amp.{0}.bin".format(recover_step)),
                                 map_location='cpu')
        amp.load_state_dict(amp_recover)

        logger.info("***** Recover scheduler: %d *****", recover_step)
        scheduler_recover = torch.load(os.path.join(
            args.output_dir, "sched.{0}.bin".format(recover_step)),
                                       map_location='cpu')
        scheduler.load_state_dict(scheduler_recover)

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, _ = batch
                masked_lm_loss = model(input_ids,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights)
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    # loss = loss.mean()
                    masked_lm_loss = masked_lm_loss.mean()
                loss = masked_lm_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description(
                    'Iter (loss={:.2f}) (lr={:0.2e})'.format(
                        loss.item(),
                        scheduler.get_last_lr()[0]))

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(
                    args.output_dir, "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)
                if args.fp16:
                    output_amp_file = os.path.join(
                        args.output_dir, "amp.{0}.bin".format(i_epoch))
                    torch.save(amp.state_dict(), output_amp_file)
                output_sched_file = os.path.join(
                    args.output_dir, "sched.{0}.bin".format(i_epoch))
                torch.save(scheduler.state_dict(), output_sched_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
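

# Minimal sketch of deriving the warmup schedule from a proportion, as done
# above with warmup_proportion (the argument names are illustrative): the total
# number of optimizer updates is computed first, and a fixed fraction of it is
# spent on linear warmup before the linear decay to zero.
def sketch_warmup_from_proportion(optimizer, num_batches, num_epochs,
                                  gradient_accumulation_steps=1,
                                  warmup_proportion=0.1):
    from transformers import get_linear_schedule_with_warmup

    t_total = int(num_batches * num_epochs / gradient_accumulation_steps)
    return get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(warmup_proportion * t_total),
        num_training_steps=t_total)
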
Example n. 9
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    logger_log.write(
        'total num. parameters to be trained: {}'.format(pytorch_total_params))

    # load trained model
    if args.input_weights is not None:
        model_dict = model.state_dict()
        trained_model = torch.load(os.path.join(output_dir,
                                                args.input_weights))
        trained_dict = {
            k: v
            for k, v in trained_model.items() if k in model_dict
        }
        model_dict.update(trained_dict)
        model.load_state_dict(model_dict)
        logger_log.write('trained model [{}] is loaded...'.format(
            args.input_weights))

        optim = AdamW(filter(lambda p: p.requires_grad, model.parameters()))
        optim.load_state_dict(
            trained_model.get('optimizer_state', trained_model))
        epoch = trained_model['epoch'] + 1
    else:
        optim = None
        epoch = 0

    train(model, trainval_loader, test_loader, args.epochs, output_dir,
          logger_log, optim, epoch, args.batch_size, device, n_gpu,
          args.lr_init)
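

# Minimal sketch of the partial-weight loading done in the fragment above
# (illustrative names): only checkpoint keys that also exist in the current
# model are copied, so a slightly different architecture can still reuse the
# trained weights.
def sketch_load_partial_weights(model, checkpoint_path):
    import torch

    model_dict = model.state_dict()
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    compatible = {k: v for k, v in checkpoint.items() if k in model_dict}
    model_dict.update(compatible)
    model.load_state_dict(model_dict)
    return model
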
Example n. 10
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model",
                        type=str,
                        help="Model type, one of: %s" %
                        ', '.join(MODELS.keys()))
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="",
                        help="Path, url or short name of a pretrained model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--adv_coef",
                        type=float,
                        default=1.0,
                        help="Adversarial dataset prediction loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    #parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    #parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=-1,
        help="If set, use this to manually restrict the sequence length. "
        "This might be helpful to save resources (memory). "
        "If not set, this is looked up from the model config (n_ctx value).")
    parser.add_argument(
        "--adversarial_dataset_prediction",
        action='store_true',
        help="Set to train with adversarial dataset prediction")
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help='set random seed')
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    if args.seed is not None:
        torch.manual_seed(args.seed)

    args.distributed = (args.local_rank != -1)

    logger.info("Prepare tokenizer and data")

    if not args.model_checkpoint:
        args.model_checkpoint = args.model
    model_super_type = None
    if args.model in GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP:
        model_super_type = 'gpt2'
    assert model_super_type is not None, f'unknown model class: "{args.model}". use one of: {", ".join(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())}'

    if model_super_type not in MODELS:
        raise NotImplementedError(
            f'model type "{model_super_type}" not implemented. use one of: {", ".join(MODELS.keys())}'
        )
    config_class, tokenizer_class, model_class, _ = MODELS[model_super_type]

    model_config = config_class.from_pretrained(args.model_checkpoint)
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    additional_special_tokens = [TYPE_BACKGROUND, TYPE_BOT, TYPE_USER]
    # for adversarial training (dataset prediction)
    dataset_labels = None
    if args.adversarial_dataset_prediction:
        dataset_labels = [
            get_dataset_label(dataset_path)
            for dataset_path in args.dataset_path.split(',')
        ]
        #additional_special_tokens.extend(dataset_labels)
        #if model_class not in ADV_MODELS.values():
        assert model_class in ADV_MODELS, f'no adversarial model implemented for model class: {model_class.__name__}'
        model_class = ADV_MODELS[model_class]
        if not hasattr(model_config, 'cls'):
            model_config.cls = {}
        if 'dataset_labels' in model_config.cls:
            assert all([dl in model_config.cls['dataset_labels']['labels'] for dl in dataset_labels]), \
                f'loaded dataset_labels [{model_config.cls["dataset_labels"]["labels"]}] do not contain all ' \
                f'current dataset_labels [{dataset_labels}]'
            dataset_labels = model_config.cls['dataset_labels']['labels']
        else:
            model_config.cls['dataset_labels'] = {
                'labels': dataset_labels,
                'is_adversarial': True
            }
        model_input_names = [
            "input_ids", "mc_token_ids", "lm_labels", "mc_labels",
            "dataset_labels", "token_type_ids"
        ]
        # not yet used
        model_output_names = [
            "lm_loss", "mc_loss", "cl_loss_0", "lm_logits", "mc_logits",
            "cl_logits_0", "presents"
        ]
    else:
        model_input_names = [
            "input_ids", "mc_token_ids", "lm_labels", "mc_labels",
            "token_type_ids"
        ]
        # not yet used
        model_output_names = [
            "lm_loss", "mc_loss", "lm_logits", "mc_logits", "presents"
        ]

    tokenizer.add_special_tokens({
        'bos_token':
        TYPE_BOS,
        'eos_token':
        TYPE_EOS,
        'pad_token':
        TYPE_PAD,
        'additional_special_tokens':
        additional_special_tokens
    })

    logger.info("Prepare datasets")
    max_sequence_length = model_config.n_ctx if args.max_sequence_length <= 0 else args.max_sequence_length
    assert max_sequence_length <= model_config.n_ctx, 'max_sequence_length [%i] was set to a value higher than ' \
                                                      'supported by the model (config.n_ctx [%i]). Please use a lower ' \
                                                      'value or do not set it [-1] to use the highest supported one.' \
                                                      % (max_sequence_length, model_config.n_ctx)
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args=args,
        tokenizer=tokenizer,
        model_input_names=model_input_names,
        max_sequence_length=max_sequence_length,
        dataset_labels=dataset_labels)

    logger.info(
        "Prepare pretrained model and optimizer - add special tokens for fine-tuning"
    )

    # Initialize distributed training if needed
    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Barrier to make sure only the first process in distributed training downloads model & vocab

    #model = model_class.from_pretrained(args.model_checkpoint, num_cl_labels=len(dataset_ids))    # for GPT2DoubleHeadsModelwithAdversarial
    model = model_class.from_pretrained(args.model_checkpoint,
                                        config=model_config)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # End of barrier to make sure only the first process in distributed training downloads model & vocab

    ####################################################################################################################

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    #optimizer = OpenAIAdam(model.parameters(), lr=args.lr)
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr)
    # scheduler is set below (see ignite)
    #scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
    #                                            num_training_steps=len(train_loader) // args.train_batch_size + 1)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_checkpoint, 'optimizer.pt')) and os.path.isfile(
                os.path.join(args.model_checkpoint, 'scheduler.pt')):
        # Load in optimizer and scheduler states
        # TODO: this needs to be dumped somewhere
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_checkpoint, 'optimizer.pt')))
        #scheduler.load_state_dict(torch.load(os.path.join(args.model_checkpoint, 'scheduler.pt')))

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = {
            model_input_names[i]: input_tensor.to(args.device)
            for i, input_tensor in enumerate(batch)
        }
        model_output = model(**batch)
        losses = (model_output[:3]
                  if args.adversarial_dataset_prediction else model_output[:2])
        if args.n_gpu > 1:  # mean() to average on multi-gpu.
            losses = list(losses)
            for i in range(len(losses)):
                losses[i] = losses[i].mean()
        lm_loss, mc_loss = losses[0], losses[1]
        loss = (lm_loss * args.lm_coef +
                mc_loss * args.mc_coef) / args.gradient_accumulation_steps

        # handle adversarial loss
        loss_wo_adv = loss.clone()
        if args.adversarial_dataset_prediction:
            adv_loss = model_output[2]
            loss += (adv_loss *
                     args.adv_coef) / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            #scheduler.step()  # Update learning rate schedule # already DONE below!
            optimizer.zero_grad()
        return loss_wo_adv.item(), loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            if args.adversarial_dataset_prediction:
                input_ids, mc_token_ids, lm_labels, mc_labels, dataset_labels, token_type_ids = batch
            else:
                input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch

            logger.debug(
                tokenizer.decode(input_ids[0, -1, :].tolist()).replace(
                    TYPE_PAD, ''))
            model_outputs = model(input_ids=input_ids,
                                  mc_token_ids=mc_token_ids,
                                  token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[
                1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero (scheduler)
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
    if args.adversarial_dataset_prediction:
        RunningAverage(output_transform=lambda x: x[1]).attach(
            trainer, "loss_w/_adv")
        RunningAverage(output_transform=lambda x: x[1] - x[0]).attach(
            trainer, "loss_only_adv")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        if args.adversarial_dataset_prediction:
            tb_logger.attach(trainer,
                             log_handler=OutputHandler(
                                 tag="training", metric_names=["loss_w/_adv"]),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(trainer,
                             log_handler=OutputHandler(
                                 tag="training",
                                 metric_names=["loss_only_adv"]),
                             event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        logger.info('save checkpoints to: %s' % tb_logger.writer.log_dir)
        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(tb_logger.writer.log_dir)

        #logger.debug("Saving optimizer and scheduler states to %s", tb_logger.writer.log_dir)
        #torch.save(optimizer.state_dict(), os.path.join(tb_logger.writer.log_dir, 'optimizer.pt'))
        #torch.save(scheduler.state_dict(), os.path.join(tb_logger.writer.log_dir, 'scheduler.pt'))

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
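
The update() function above folds gradient accumulation into the loss itself: every micro-batch loss is divided by gradient_accumulation_steps and optimizer.step() only fires every that many iterations, so the accumulated gradient matches one large-batch step. A minimal, self-contained sketch of that pattern in plain PyTorch (the toy model and data below are made up purely for illustration):

import torch

model = torch.nn.Linear(4, 2)                      # hypothetical toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps, max_norm = 4, 1.0

for iteration in range(1, 17):
    x = torch.randn(8, 4)                          # fake micro-batch
    y = torch.randint(0, 2, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    (loss / accumulation_steps).backward()         # scale so accumulated gradients average over micro-batches
    if iteration % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()
        optimizer.zero_grad()
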
Example n. 11
                  valid_loader=valid_loader,
                  test_loader=test_loader,
                  save_path=args.save_path,
                  tokenizer=tokenizer)
trainer.train()

del model, optimizer
# Init model and optimizer
model = BERT_LSTM_Classification(class_num=config.class_num,
                                 bidirectional=config.bidirectional)
optimizer = AdamW(model.parameters(), lr=config.learning_rate, eps=1e-8)

# load checkpoint
checkpoint = torch.load('{}/best_checkpoint.pt'.format(args.save_path))
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

# get initial lexicon
filter_words = Lexicon.filter_words(sampled_text, sampled_label)
lexicon_instance = Lexicon(epoch=epoch,
                           fname=args.save_path,
                           class_label=config.class_num,
                           filter_words=filter_words)

# Semi-supervised learning
ssl_trainer = SSL_Trainer(expt_dir=args.save_path,
                          criterion=torch.nn.CrossEntropyLoss(),
                          lexicon_instance=lexicon_instance,
                          config=config)
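
This snippet reloads best_checkpoint.pt and expects the keys 'model_state_dict', 'optimizer_state_dict' and 'epoch'. A minimal sketch of the saving side that would produce such a file, assuming those key names and a save_checkpoint helper that is not part of the original code:

import torch

def save_checkpoint(model, optimizer, epoch, save_path):
    # Key names mirror what the loading code above reads back
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
    }, '{}/best_checkpoint.pt'.format(save_path))
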
Example n. 12
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    # Originally multiplied by max(1, n_gpu), but removed since only a single GPU is used
    train_batch_size = per_gpu_train_batch_size
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_batch_size)

    # If a maximum number of steps is set, use it directly; otherwise compute t_total from the dataloader length
    if max_steps > 0:
        t_total = max_steps
        num_train_epochs = max_steps // (len(train_dataloader) // gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // gradient_accumulation_steps * num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    
    # Set up optimizer and scheduler
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(model_name_or_path, "optimizer.pt")) and os.path.isfile(
        os.path.join(model_name_or_path, "scheduler.pt")
    ):
        # Load in optimizer and scheduler states if they already exist
        optimizer.load_state_dict(torch.load(os.path.join(model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(model_name_or_path, "scheduler.pt")))

    # fp16 branch removed
    # multi-gpu training (should be after apex fp16 initialization) and
    # distributed training setup: both removed here


    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size
        * gradient_accumulation_steps,
    )
    logger.info("  Gradient Accumulation steps = %d", gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    
    # Check if continuing training from a checkpoint
    if os.path.exists(model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from model path
            checkpoint_suffix = model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // gradient_accumulation_steps)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            # Presumably this handles the case where no checkpoint step exists?
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(num_train_epochs), desc="Epoch", disable=local_rank not in [-1, 0]
    )
    # Added here for reproducibility
    set_seed(args)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if model_type in ["xlm", "roberta", "distilbert"]:
                del inputs["token_type_ids"]

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            # n_gpu > 1 branch removed
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps

            # fp16 branch removed
            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if local_rank in [-1, 0] and logging_steps > 0 and global_step % logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if local_rank == -1 and evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if local_rank in [-1, 0] and save_steps > 0 and global_step % save_steps == 0:
                    checkpoint_dir = os.path.join(output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(checkpoint_dir):
                        os.makedirs(checkpoint_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(checkpoint_dir)
                    tokenizer.save_pretrained(checkpoint_dir)

                    torch.save(args, os.path.join(checkpoint_dir, "training_bin"))
                    logger.info("Saving model checkpoint to %s", checkpoint_dir)

                    torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", checkpoint_dir)

            if max_steps > 0 and global_step > max_steps:
                epoch_iterator.close()
                break
        if max_steps > 0 and global_step > max_steps:
            train_iterator.close()
            break
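
The resume logic above derives where to restart purely from the checkpoint's global_step. A small worked sketch of that arithmetic with made-up numbers:

len_dataloader = 1000                  # hypothetical number of batches per epoch
gradient_accumulation_steps = 2
global_step = 1300                     # e.g. parsed from a "checkpoint-1300" directory name

steps_per_epoch = len_dataloader // gradient_accumulation_steps   # 500 optimizer steps per epoch
epochs_trained = global_step // steps_per_epoch                   # 2 full epochs already done
steps_trained_in_current_epoch = global_step % steps_per_epoch    # skip the first 300 steps

assert (epochs_trained, steps_trained_in_current_epoch) == (2, 300)
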
Example n. 13
class PrototypicalNetwork(Learner):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.inner_lr = config.learner.inner_lr
        self.meta_lr = config.learner.meta_lr
        self.mini_batch_size = config.training.batch_size

        self.pn = TransformerClsModel(model_name=config.learner.model_name,
                                      n_classes=config.data.n_classes,
                                      max_length=config.data.max_length,
                                      device=self.device)
        if config.wandb:
            wandb.watch(self.pn, log='all')

        self.memory = ClassMemoryStore(key_dim=TRANSFORMER_HDIM, device=self.device,
                                       class_discount=config.learner.class_discount, n_classes=config.data.n_classes,
                                       discount_method=config.learner.class_discount_method)
        self.loss_fn = nn.CrossEntropyLoss()

        meta_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.meta_optimizer = AdamW(meta_params, lr=self.meta_lr)

        inner_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.inner_optimizer = optim.SGD(inner_params, lr=self.inner_lr)
        #TODO: remove below line
        self.episode_samples_seen = 0 # have to keep track of per-task samples seen as we might use replay as well

    def training(self, datasets, **kwargs):
        representations_log = []
        replay_freq, replay_steps = self.replay_parameters()
        self.logger.info("Replay frequency: {}".format(replay_freq))
        self.logger.info("Replay steps: {}".format(replay_steps))

        datas, order, n_samples, eval_train_dataset, eval_eval_dataset, eval_dataset = self.prepare_data(datasets)
        for i, (data, dataset_name, n_sample) in enumerate(zip(datas, order, n_samples)):
            self.logger.info(f"Observing dataset {dataset_name} for {n_sample} samples. "
                             f"Evaluation={dataset_name=='evaluation'}")
            if dataset_name == "evaluation" and self.config.testing.few_shot:
                self.few_shot_testing(train_dataset=eval_train_dataset, eval_dataset=eval_eval_dataset, split="test",
                                      increment_counters=False)
            else:
                train_dataloader = iter(DataLoader(data, batch_size=self.mini_batch_size, shuffle=False))
                self.episode_samples_seen = 0 # have to keep track of per-task samples seen as we might use replay as well
                # iterate over episodes
                while True:
                    self.set_train()
                    support_set, task = self.get_support_set(train_dataloader, n_sample)
                    # TODO: return flag that indicates whether the query set is from the memory. Don't include this in the online accuracy calc
                    query_set = self.get_query_set(train_dataloader, replay_freq, replay_steps, n_sample, write_memory=False)
                    if support_set is None or query_set is None:
                        break
                    self.training_step(support_set, query_set, task=task)

                    if self.current_iter % 5 == 0:
                        class_representations = self.memory.class_representations
                        extra_text, extra_labels, datasets = next(self.extra_dataloader)
                        with torch.no_grad():
                            extra_representations = self.forward(extra_text, extra_labels, no_grad=True)["representation"]
                            query_text, query_labels = query_set[0]
                            query_representations = self.forward(query_text, query_labels, no_grad=True)["representation"]
                            extra_dist, extra_dist_normalized, extra_unique_labels = model_utils.class_dists(extra_representations, extra_labels, class_representations)
                            query_dist, query_dist_normalized, query_unique_labels = model_utils.class_dists(query_representations, query_labels, class_representations)
                            class_representation_distances = model_utils.euclidean_dist(class_representations, class_representations)
                            representations_log.append(
                                {
                                    "query_dist": query_dist.tolist(),
                                    "query_dist_normalized": query_dist_normalized.tolist(),
                                    "query_labels": query_labels.tolist(),
                                    "query_unique_labels": query_unique_labels.tolist(),
                                    "extra_dist": extra_dist.tolist(),
                                    "extra_dist_normalized": extra_dist_normalized.tolist(),
                                    "extra_labels": extra_labels.tolist(),
                                    "extra_unique_labels": extra_unique_labels.tolist(),
                                    "class_representation_distances": class_representation_distances.tolist(),
                                    "class_tsne": TSNE().fit_transform(class_representations.cpu()).tolist(),
                                    "current_iter": self.current_iter,
                                    "examples_seen": self.examples_seen()
                                }
                                )
                            if self.current_iter % 100 == 0:
                                with open(self.representations_dir / f"classDists_{self.current_iter}.json", "w") as f:
                                    json.dump(representations_log, f)
                                representations_log = []

                    self.meta_training_log()
                    self.write_metrics()
                    self.current_iter += 1
                    if self.episode_samples_seen >= n_sample:
                        break
            if i == 0:
                self.metrics["eval_task_first_encounter_evaluation"] = \
                    self.evaluate(DataLoader(eval_dataset, batch_size=self.mini_batch_size))["accuracy"]
                # self.save_checkpoint("first_task_learned.pt", save_optimizer_state=True)
            if dataset_name == self.config.testing.eval_dataset:
                self.eval_task_first_encounter = False
                
    def training_step(self, support_set, query_set=None, task=None):
        self.inner_optimizer.zero_grad()

        self.logger.debug("-------------------- TRAINING STEP  -------------------")
        # with higher.innerloop_ctx(self.pn, self.inner_optimizer,
        #                           copy_initial_weights=False,
        #                           track_higher_grads=False) as (fpn, diffopt):
        do_memory_update = self.config.learner.prototype_update_freq > 0 and \
                        (self.current_iter % self.config.learner.prototype_update_freq) == 0
        ### GET SUPPORT SET REPRESENTATIONS ###
        self.logger.debug("----------------- SUPPORT SET ----------------- ")
        representations, all_labels = self.get_representations(support_set[:1])
        representations_merged = torch.cat(representations)
        class_means, unique_labels = model_utils.get_class_means(representations_merged, all_labels)
        self._examples_seen += len(representations_merged)
        self.logger.debug(f"Examples seen increased by {len(representations_merged)}")

        ### UPDATE MEMORY ###
        if do_memory_update:
            memory_update = self.memory.update(class_means, unique_labels, logger=self.logger)
            updated_memory_representations = memory_update["new_class_representations"]
            self.log_discounts(memory_update["class_discount"], unique_labels)
        ### DETERMINE WHAT'S SEEN AS PROTOTYPE ###
        if self.config.learner.prototypes == "class_means":
            prototypes = expand_class_representations(self.memory.class_representations, class_means, unique_labels)
        elif self.config.learner.prototypes == "memory":
            prototypes = updated_memory_representations
        else:
            raise AssertionError("Prototype type not in {'class_means', 'memory'}, fix config file.")

        ### INITIALIZE LINEAR LAYER WITH PROTOYPICAL-EQUIVALENT WEIGHTS ###
        # self.init_prototypical_classifier(prototypes, linear_module=fpn.linear)
        weight = 2 * prototypes  # prototypical-equivalent linear weights (no hidden-dim scaling here, unlike init_prototypical_classifier)
        bias = - (prototypes ** 2).sum(dim=1)

        self.logger.debug("----------------- QUERY SET  ----------------- ")
        ### EVALUATE ON QUERY SET AND UPDATE ENCODER ###
        # Outer loop
        if query_set is not None:
            for text, labels in query_set:
                labels = torch.tensor(labels).to(self.device)
                query_representations = self.forward(text, labels)["representation"]

                # distance query representations to prototypes (BATCH X N_PROTOTYPES)
                # distances = euclidean_dist(query_representations, prototypes)
                # logits = - distances
                logits = query_representations @ weight.T + bias
                loss = self.loss_fn(logits, labels)
                # log_probability = F.log_softmax(-distances, dim=1)
                # loss is negation of the log probability, index using the labels for each observation
                # loss = (- log_probability[torch.arange(len(log_probability)), labels]).mean()
                self.meta_optimizer.zero_grad()
                loss.backward()
                self.meta_optimizer.step()

                predictions = model_utils.make_prediction(logits.detach())
                # predictions = torch.tensor([inv_label_map[p.item()] for p in predictions])
                # to_print = pprint.pformat(list(map(lambda x: (x[0].item(), x[1].item(),
                #                         [round(z, 3) for z in x[2].tolist()]),
                #                         list(zip(labels, predictions, distances)))))
                self.logger.debug(
                    f"Unique Labels: {unique_labels.tolist()}\n"
                    # f"Labels, Indices, Predictions, Distances:\n{to_print}\n"
                    f"Loss:\n{loss.item()}\n"
                    f"Predictions:\n{predictions}\n"
                )
                self.update_query_tracker(loss, predictions, labels)
                metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
                online_metrics = {
                    "accuracy": metrics["accuracy"],
                    "examples_seen": self.examples_seen(),
                    "task": task if task is not None else "none"
                }
                self.metrics["online"].append(online_metrics)
                if task is not None and task == self.config.testing.eval_dataset and \
                    self.eval_task_first_encounter:
                    self.metrics["eval_task_first_encounter"].append(online_metrics)
                self._examples_seen += len(text)
                self.logger.debug(f"Examples seen increased by {len(text)}")

            # Meta optimizer step
            # self.meta_optimizer.step()
            # self.meta_optimizer.zero_grad()
        self.logger.debug("-------------------- TRAINING STEP END  -------------------")

    def get_representations(self, support_set, prediction_network=None):
        """
        Parameters
        ---
        support_set: List[Tuple[batch text, batch labels]]
        prediction network: pytorch module 

        Returns
        ---
        Tuple[List[Tensor], List[Int]] where the first result is the hidden representation and the second
        the labels.
        """
        representations = []
        all_labels = []
        for text, labels in support_set:
            labels = torch.tensor(labels).to(self.device)
            all_labels.extend(labels.tolist())
            # labels = labels.to(self.device)
            output = self.forward(text, labels, prediction_network=prediction_network)
            representations.append(output["representation"])
        return representations, all_labels

    def forward(self, text, labels, prediction_network=None, no_grad=False):
        if prediction_network is None:
            prediction_network = self.pn
        input_dict = self.pn.encode_text(text)
        context_manager = torch.no_grad() if no_grad else nullcontext()
        with context_manager:
            representation = prediction_network(input_dict, out_from="transformers")
            logits = prediction_network(representation, out_from="linear")
        return {"representation": representation, "logits": logits}
    
    def update_memory(self, class_means, unique_labels):
        to_update = unique_labels
        # selection of old class representations here
        old_class_representations = self.memory.class_representations[to_update]
        # if old class representations haven't been given values yet, don't bias towards 0 by exponential update
        if (old_class_representations == 0).bool().all():
            new_class_representations = class_means
        else:
            # memory update rule here
            new_class_representations = (1 - self.config.learner.class_discount) * old_class_representations + self.config.learner.class_discount * class_means
        self.logger.debug(f"Updating class representations for classes {unique_labels}.\n"
                         f"Distance old class representations and class means: {[round(z, 2) for z in (old_class_representations - class_means).norm(dim=1).tolist()]}\n"
                         f"Distance old and new class representations: {[round(z, 2) for z in (new_class_representations - old_class_representations).norm(dim=1).tolist()]}"
                         )
        # for returning new class representations while keeping gradients intact
        result = torch.clone(self.memory.class_representations)
        result[to_update] = new_class_representations
        # update memory
        self.memory.class_representations[to_update] = new_class_representations.detach()

        return result

    def init_prototypical_classifier(self, prototypes, linear_module=None):
        if linear_module is None:
            linear_module = self.pn.linear
        weight = 2 * prototypes / TRANSFORMER_HDIM # divide by number of dimensions, otherwise blows up
        bias = - (prototypes ** 2).sum(dim=1) / TRANSFORMER_HDIM
        # Without this, classes observed in the support set always get a smaller (more negative)
        # bias than unobserved ones, which favors the unobserved classes, even though labels in the
        # support set are expected to be more likely to appear in the query set.
        bias_unchanged = bias == 0
        bias[bias_unchanged] = bias.min()
        self.logger.info(f"Prototype is zero vector for classes {bias_unchanged.nonzero(as_tuple=True)[0].tolist()}. "
                         f"Setting their bias entries to the minimum of the uninitialized bias vector.")
        # prototypical-equivalent network initialization
        linear_module.weight.data = weight
        linear_module.bias.data = bias
        self.logger.info(f"Classifier bias initialized to {bias}.")
        
        # a = mmaml.classifier.weight
        # # https://stackoverflow.com/questions/61279403/gradient-flow-through-torch-nn-parameter
        # # a = torch.nn.Parameter(torch.ones((10,)), requires_grad=True) 
        # b = a[:] # silly hack to convert in a raw tensor including the computation graph
        # # b.retain_grad() # Otherwise backward pass will not store the gradient since it is not a leaf 
        # it is necessary to do it this way to retain the gradient information on the classifier parameters
        # https://discuss.pytorch.org/t/non-leaf-variables-as-a-modules-parameters/65775
        # del self.classifier.weight
        # self.classifier.weight = 2 * prototypes
        # del self.classifier.bias
        # self.classifier.bias = bias
        # weight_copy = self.classifier.weight[:]
        # bias_copy = self.classifier.bias[:]

    def update_meta_gradients(self, loss, fpn):
        # PN meta gradients
        pn_params = [p for p in fpn.parameters() if p.requires_grad]
        meta_pn_grads = torch.autograd.grad(loss, pn_params, allow_unused=True)
        pn_params = [p for p in self.pn.parameters() if p.requires_grad]
        for param, meta_grad in zip(pn_params, meta_pn_grads):
            if meta_grad is not None:
                if param.grad is not None:
                    param.grad += meta_grad.detach()
                else:
                    param.grad = meta_grad.detach()

    def update_support_tracker(self, loss, pred, labels):
        self.tracker["support_loss"].append(loss.item())
        self.tracker["support_predictions"].extend(pred.tolist())
        self.tracker["support_labels"].extend(labels.tolist())

    def update_query_tracker(self, loss, pred, labels):
        self.tracker["query_loss"].append(loss.item())
        self.tracker["query_predictions"].extend(pred.tolist())
        self.tracker["query_labels"].extend(labels.tolist())

    def reset_tracker(self):
        self.tracker = {
            "support_loss": [],
            "support_predictions": [],
            "support_labels": [],
            "query_loss": [],
            "query_predictions": [],
            "query_labels": []
        }

    def evaluate(self, dataloader, prediction_network=None):
        # if self.config.learner.evaluation_support_set:
        #     support_set = []
        #     for _ in range(self.config.learner.updates):
        #         text, labels = self.memory.read_batch(batch_size=self.mini_batch_size)
        #         support_set.append((text, labels))

        # with higher.innerloop_ctx(self.pn, self.inner_optimizer,
        #                         copy_initial_weights=False,
        #                         track_higher_grads=False) as (fpn, diffopt):
        #     if self.config.learner.evaluation_support_set:
        #         self.set_train()
        #         support_prediction_network = fpn
        #         # Inner loop
        #         task_predictions, task_labels = [], []
        #         support_loss = []
        #         for text, labels in support_set:
        #             labels = torch.tensor(labels).to(self.device)
        #             # labels = labels.to(self.device)
        #             output = self.forward(text, labels, fpn)
        #             loss = self.loss_fn(output["logits"], labels)
        #             diffopt.step(loss)

        #             pred = model_utils.make_prediction(output["logits"].detach())
        #             support_loss.append(loss.item())
        #             task_predictions.extend(pred.tolist())
        #             task_labels.extend(labels.tolist())
        #         results = model_utils.calculate_metrics(task_predictions, task_labels)
        #         self.logger.info("Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
        #                     "F1 score = {:.4f}".format(np.mean(support_loss), results["accuracy"],
        #                     results["precision"], results["recall"], results["f1"]))
        #         self.set_eval()
        #     else:
        #         support_prediction_network = self.pn
        #     if prediction_network is None:
        #         prediction_network = support_prediction_network

        self.set_eval()
        prototypes = self.memory.class_representations
        weight = 2 * prototypes
        bias = - (prototypes ** 2).sum(dim=1)
        all_losses, all_predictions, all_labels = [], [], []
        for i, (text, labels, _) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            representations = self.forward(text, labels)["representation"]
            logits = representations @ weight.T + bias
            # labels = labels.to(self.device)
            loss = self.loss_fn(logits, labels)
            loss = loss.item()
            pred = model_utils.make_prediction(logits.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())

        results = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.debug("Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                    "F1 score = {:.4f}".format(np.mean(all_losses), results["accuracy"],
                    results["precision"], results["recall"], results["f1"]))
        return results

    def model_state(self):
        return {"pn": self.pn.state_dict()}

    def optimizer_state(self):
        return self.meta_optimizer.state_dict()

    def load_model_state(self, checkpoint):
        self.pn.load_state_dict(checkpoint["model_state"]["pn"])

    def load_optimizer_state(self, checkpoint):
        self.meta_optimizer.load_state_dict(checkpoint["optimizer"])

    def save_other_state_information(self, state):
        """Any learner specific state information is added here"""
        state["memory"] = self.memory
        return state

    def load_other_state_information(self, checkpoint):
        self.memory = checkpoint["memory"]

    def set_eval(self):
        self.pn.eval()

    def set_train(self):
        self.pn.train()

    def few_shot_testing(self, train_dataset, eval_dataset, increment_counters=False, split="test"):
        """
        Allow the model to train on a small amount of datapoints at a time. After every training step,
        evaluate on many samples that haven't been seen yet.

        Results are saved in learner's `metrics` attribute.

        Parameters
        ---
        train_dataset: Dataset
            Contains examples on which the model is trained before being evaluated
        eval_dataset: Dataset
            Contains examples on which the model is evaluated
        increment_counters: bool
            If True, update online metrics and current iteration counters.
        """
        self.logger.info(f"few shot testing on dataset {self.config.testing.eval_dataset} "
                         f"with {len(train_dataset)} samples")
        train_dataloader, eval_dataloader = self.few_shot_preparation(train_dataset, eval_dataset, split=split)
        all_predictions, all_labels = [], []
        def add_none(iterator):
            yield None
            for x in iterator:
                yield x
        shifted_dataloader = add_none(train_dataloader)
        # prototypes = self.memory.class_representations
        for i, (support_set, (query_text, query_labels, datasets)) in enumerate(zip(shifted_dataloader, train_dataloader)):
            query_labels = torch.tensor(query_labels).to(self.device)
            # happens on the first one
            # prototypes = self.memory.class_representations
            if support_set is None:
                prototypes = self.memory.class_representations
            else:
                support_text, support_labels, _ = support_set
                support_labels = torch.tensor(support_labels).to(self.device)
                support_representations = self.forward(support_text, support_labels)["representation"]
                support_class_means, unique_labels = model_utils.get_class_means(support_representations, support_labels)
                memory_update = self.memory.update(support_class_means, unique_labels, logger=self.logger)
                updated_memory_representations = memory_update["new_class_representations"]
                self.log_discounts(memory_update["class_discount"], unique_labels,
                                    few_shot_examples_seen=(i+1) * self.config.testing.few_shot_batch_size)
                prototypes = updated_memory_representations
                if self.config.learner.few_shot_detach_prototypes:
                    prototypes = prototypes.detach()
            weight = 2 * prototypes
            bias = - (prototypes ** 2).sum(dim=1)
            query_representations = self.forward(query_text, query_labels)["representation"]
            logits = query_representations @ weight.T + bias

            # new part
            # q_norm = query_representations.norm(dim=1).unsqueeze(-1)
            # c_norm = prototypes.norm(dim=1).unsqueeze(-1)
            # c_norm[c_norm == 0] = 1
            # q_norm[q_norm == 0] = 1

            # query_representations_normalized = query_representations / q_norm
            # class_representations_normalized = prototypes / c_norm
            # dists = model_utils.euclidean_dist(query_representations_normalized, class_representations_normalized)
            # dists = model_utils.euclidean_dist(query_representations, prototypes)
            # memory loss
            # memory_loss = dists[torch.arange(len(query_representations)), query_labels].mean()

            # loss = (1 - memory_loss_weight) * self.loss_fn(logits, query_labels) + memory_loss_weight * 
            # cross_entropy_loss = self.loss_fn(logits, query_labels)
            # loss = cross_entropy_loss + memory_loss
            loss = self.loss_fn(logits, query_labels)
            # self.logger.debug(f"Memory loss: {memory_loss} -- Cross Entropy Loss: {cross_entropy_loss}")

            self.meta_optimizer.zero_grad()
            loss.backward()
            self.meta_optimizer.step()

            predictions = model_utils.make_prediction(logits.detach())
            all_predictions.extend(predictions.tolist())
            all_labels.extend(query_labels.tolist())
            dataset_results = self.evaluate(dataloader=eval_dataloader)
            self.log_few_shot(all_predictions, all_labels, datasets, dataset_results,
                                increment_counters, query_text, i, split=split)
            if (i * self.config.testing.few_shot_batch_size) % self.mini_batch_size == 0 and i > 0:
                all_predictions, all_labels = [], []
        self.few_shot_end()

    def log_discounts(self, class_discount, unique_labels, few_shot_examples_seen=None):
        prefix = f"few_shot_{self.few_shot_counter}_" if few_shot_examples_seen is not None else ""
        discounts = {prefix + "class_discount": {}}
        if not isinstance(class_discount, float) and not isinstance(class_discount, int):
            for l, discount in zip(unique_labels, class_discount):
                discounts[prefix + "class_discount"][f"Class {l}"] = discount.item()
        else:
            for l in unique_labels:
                discounts[prefix + "class_discount"][f"Class {l}"] = float(class_discount)
        for l in range(self.config.data.n_classes):
            if l not in unique_labels:
                discounts[prefix + "class_discount"][f"Class {l}"] = 0
        discounts["examples_seen"] = few_shot_examples_seen if few_shot_examples_seen is not None else self.examples_seen()
        if "class_discount" not in self.metrics:
            self.metrics["class_discount"] = []
        self.metrics["class_discount"].append(discounts)
        if few_shot_examples_seen is not None:
            self.logger.debug("Logging class discounts")
            self.logger.debug(f"Examples seen: {discounts['examples_seen']}")
        if self.config.wandb:
            wandb.log(discounts)
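
The classifier built from weight = 2 * prototypes and bias = -(prototypes ** 2).sum(dim=1) produces logits that differ from the negative squared Euclidean distance to each prototype only by a per-query constant ||x||^2, so picking the highest logit is the same as picking the nearest prototype. A quick sketch checking that equivalence on random tensors (the shapes are arbitrary):

import torch

queries = torch.randn(5, 16)        # hypothetical query representations
prototypes = torch.randn(3, 16)     # hypothetical class prototypes

weight = 2 * prototypes
bias = -(prototypes ** 2).sum(dim=1)
logits = queries @ weight.T + bias                       # 2 x.p - ||p||^2

neg_sq_dists = -torch.cdist(queries, prototypes) ** 2    # -(||x||^2 - 2 x.p + ||p||^2)
assert torch.allclose(logits - (queries ** 2).sum(dim=1, keepdim=True),
                      neg_sq_dists, atol=1e-4)
assert torch.equal(logits.argmax(dim=1), neg_sq_dists.argmax(dim=1))
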
Example n. 14
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

    if args.resume_training:
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    if args.resume_training:
        global_step = torch.load(os.path.join(args.model_name_or_path, "steps.pt"))["global_step"]

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            if args.model_type not in ['distilbert', "roberta"]:
                inputs = {'input_ids':      batch[0],
                          'attention_mask': batch[1],
                          'token_type_ids': batch[2],
                          'labels':         batch[3]}
            else:
                inputs = {'input_ids':      batch[0],
                          'attention_mask': batch[1],
                          'token_type_ids': None, 
                          'labels':         batch[2]}
            # inputs = {'input_ids':      batch[0],
            #           'attention_mask': batch[1],
            #           'labels':         batch[3]}
            # if args.model_type != 'distilbert':
            #     inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    
                    # save training states
                    torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
                    torch.save({
                        "global_step": global_step,
                    }, os.path.join(output_dir, 'steps.pt'))

                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.tpu:
                args.xla_model.optimizer_step(optimizer, barrier=True)
                model.zero_grad()
                global_step += 1

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
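
The optimizer_grouped_parameters construction above excludes biases and LayerNorm weights from weight decay by matching on parameter names. A minimal sketch of the same split on a toy module whose attribute is deliberately named LayerNorm so the name filter matches, as in BERT-style models (the module and hyperparameters are placeholders):

import torch
from torch.optim import AdamW   # torch's AdamW; the examples above use the transformers implementation

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(8, 8)
        self.LayerNorm = torch.nn.LayerNorm(8)

model = Toy()
no_decay = ['bias', 'LayerNorm.weight']
grouped = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},   # only dense.weight ends up here
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},    # dense.bias, LayerNorm.weight, LayerNorm.bias
]
optimizer = AdamW(grouped, lr=5e-5, eps=1e-8)
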
Example n. 15
def train(args, model, tokenizer, train_dataloader, eval_during_training=False):
    """ Fine-tune the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if (
            os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True
        )

    # Train
    logger.info("***** Running training *****")
    logger.info("  Num samples = %d", len(train_dataloader.dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.per_gpu_train_batch_size
        * max(1, args.n_gpu)
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from model path
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, args.num_train_epochs, desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            loss = forward_gloss_selection(args, model, batch)[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.local_rank == -1 and eval_during_training:
                        # Only evaluate when single GPU otherwise metrics may not average well
                        logs["eval_loss"] = evaluate(args, model, tokenizer, global_step)

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)

                    with open(os.path.join(args.output_dir, "train_log.txt"), 'a+') as f:
                        print(json.dumps({**logs, **{"step": global_step}}), file=f)

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if 0 < args.max_steps < global_step:
                epoch_iterator.close()
                break

        if 0 < args.max_steps < global_step:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
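A minimal sketch of the resume arithmetic used above, assuming checkpoint directories follow the checkpoint-<global_step> naming produced by the saving code (the numbers in the usage line are hypothetical):

def resume_state(checkpoint_dir, batches_per_epoch, gradient_accumulation_steps):
    # ".../checkpoint-500" -> 500
    global_step = int(checkpoint_dir.rstrip("/").split("-")[-1])
    updates_per_epoch = batches_per_epoch // gradient_accumulation_steps
    epochs_trained = global_step // updates_per_epoch
    steps_to_skip_in_current_epoch = global_step % updates_per_epoch
    return global_step, epochs_trained, steps_to_skip_in_current_epoch

print(resume_state("output/checkpoint-500", 180, 2))  # (500, 5, 50)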
Example 16
def prepare_for_training(args, model, checkpoint_state_dict, amp):
    no_decay = ['bias', 'LayerNorm.weight']
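    # Parameters whose names contain these substrings (biases and LayerNorm weights) are excluded
    # from weight decay, following the standard BERT fine-tuning recipe.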
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)

    if checkpoint_state_dict:
        optimizer.load_state_dict(checkpoint_state_dict['optimizer'])
        model.load_state_dict(checkpoint_state_dict['model'])

        # then remove optimizer state to make amp happy
        # https://github.com/NVIDIA/apex/issues/480#issuecomment-587154020
        if amp:
            optimizer.state = {}

    if amp:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
        if checkpoint_state_dict:
            amp.load_state_dict(checkpoint_state_dict['amp'])

            # Black Tech from https://github.com/NVIDIA/apex/issues/480#issuecomment-587154020
            # forward, backward, optimizer step, zero_grad
            random_input = {
                'source_ids':
                torch.ones(size=(2, 2), device=args.device, dtype=torch.long),
                'target_ids':
                torch.ones(size=(2, 2), device=args.device, dtype=torch.long),
                'label_ids':
                torch.ones(size=(2, 2), device=args.device, dtype=torch.long),
                'pseudo_ids':
                torch.ones(size=(2, 2), device=args.device, dtype=torch.long),
                'num_source_tokens':
                torch.zeros(size=(2, ), device=args.device, dtype=torch.long),
                'num_target_tokens':
                torch.zeros(size=(2, ), device=args.device, dtype=torch.long)
            }
            loss = model(**random_input)
            print("Loss = %f" % loss.cpu().item())
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            model.zero_grad()

            # then load optimizer state_dict again (this time without removing optimizer.state)
            optimizer.load_state_dict(checkpoint_state_dict['optimizer'])

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    return model, optimizer
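The bias/LayerNorm weight-decay split in prepare_for_training() recurs throughout these examples; a small helper along these lines keeps it in one place (a sketch, not part of the original code; torch.optim.AdamW stands in for the transformers AdamW used above):

from torch.optim import AdamW


def grouped_parameters(model, weight_decay, no_decay=('bias', 'LayerNorm.weight')):
    # Parameters whose names match no_decay get a weight decay of 0.0
    decayed = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
    undecayed = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
    return [
        {'params': decayed, 'weight_decay': weight_decay},
        {'params': undecayed, 'weight_decay': 0.0},
    ]

# optimizer = AdamW(grouped_parameters(model, 0.01), lr=5e-5, eps=1e-8)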
Example 17
def train(args, train_dataset, model, tokenizer, input_file_name=None, multiple_shards=False, init_global_step=0, init_epochs_trained=0):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(logdir=args.tensorboard_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    if args.sort_by_length:
        train_sampler = LengthSortedSampler(train_dataset, batch_size=args.train_batch_size*args.gradient_accumulation_steps, shuffle=True)
    else:
        train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=train_dataset.collate_fn)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

    if args.scheduler == 'linear':
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    elif args.scheduler == 'transformer':
        if args.model_type == 'bert':
            dimension = model.config.hidden_size
        elif args.model_type == 'gpt2':
            dimension = model.config.n_embd
        else:
            logger.error('Cannot detect hidden size dimensions in this model type. Config: %s', model.config)
        scheduler = get_transformer_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total, dimension=dimension)
    else:
        logger.error('Unknown scheduler type.')

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Input file name = %s", input_file_name)
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = init_global_step
    epochs_trained = init_epochs_trained
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
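        # The inner training loop counts mini-batches, so convert the remaining optimizer steps back into mini-batch steps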
        steps_trained_in_current_epoch = (global_step % (len(train_dataloader) // args.gradient_accumulation_steps)) * args.gradient_accumulation_steps

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0, 0

    model_to_resize = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))
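    # Keeps the embedding matrix in sync with any tokens that were added to the tokenizer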

    model.zero_grad()
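    # When the data is split into multiple shards, each call to train() makes a single pass over the current shard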
    if multiple_shards:
        train_iterator = prange(epochs_trained, 1, desc="Epoch", disable=args.local_rank not in [-1, 0])
    else:
        train_iterator = prange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    best_eval_perplexity = float('Inf')
    for _ in train_iterator:
        if args.max_steps > 0 and not multiple_shards:
            total_steps = args.max_steps*args.gradient_accumulation_steps
        else:
            total_steps = len(train_dataloader)
        epoch_iterator = progress_bar(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0], total=total_steps)
        for step, batch in enumerate(epoch_iterator):
            
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, attention_mask, labels, position_ids, segment_ids = batch
            
            if args.mlm:
                inputs, labels = mask_tokens(inputs, labels, tokenizer, args.mlm_probability, args.mlm_ignore_index)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            attention_mask = attention_mask.to(args.device)
            position_ids = position_ids.to(args.device)
            segment_ids = segment_ids.to(args.device)
            model.train()
            
            model_inputs = {'input_ids': inputs, 'use_cache': False}
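            # use_cache is disabled above: cached past key/values only help incremental decoding, not teacher-forced training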
            
            # prepare inputs for mbart, and marian
            if args.model_type in ['mbart', 'marian']:
                model_inputs['attention_mask'] = attention_mask
                decoder_input_ids = shift_tokens_right(labels, args.mlm_ignore_index)
                decoder_input_ids[decoder_input_ids == args.mlm_ignore_index] = tokenizer.pad_token_id
                model_inputs['decoder_input_ids'] = decoder_input_ids
            elif args.model_type == 'bart':
                # TODO according to huggingface bart should also use shift_tokens_right
                # check if that change affects results
                model_inputs['attention_mask'] = attention_mask
                decoder_input_ids = labels.contiguous()
                decoder_input_ids[decoder_input_ids == args.mlm_ignore_index] = tokenizer.pad_token_id
                model_inputs['decoder_input_ids'] = decoder_input_ids
            else:
                model_inputs.update({'position_ids': position_ids, 'token_type_ids': segment_ids})

            outputs = model(**model_inputs)
            lm_logits = outputs.logits.contiguous()
            assert lm_logits.shape[-1] == model.config.vocab_size
            
            # CrossEntropyLoss ignore_index defaults to -100
            # If a different mlm_ignore_index is provided we make sure it is ignored when calculating the loss
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=args.mlm_ignore_index)
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

            if args.local_rank in [-1, 0] and ((args.logging_steps > 0 and global_step % args.logging_steps == 0 and global_step != 0) or step == total_steps-1):
                # Log metrics
                if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                    results = evaluate(args, model, tokenizer)
                    if args.aux_eval_data_file is not None:
                        aux_results = evaluate(args, model, tokenizer, aux=True)
                        for key, value in aux_results.items():
                            tb_writer.add_scalar('auxiliary_eval_{}'.format(key), value, global_step)
                    if best_eval_perplexity > results['perplexity']:
                        best_eval_perplexity = results['perplexity']
                        if not os.path.exists(args.output_dir):
                            os.makedirs(args.output_dir)
                        logger.info("Saving new best model to %s", args.output_dir)
                        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
                        # They can then be reloaded using `from_pretrained()`
                        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(args.output_dir)
                        tokenizer.save_pretrained(args.output_dir)

                        # Good practice: save your training arguments together with the trained model
                        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

                    for key, value in results.items():
                        tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                # TODO add generated text to tensorboard
                # tb_writer.add_text('eval/generated_text', gen_text, global_step)
                tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
                logging_loss = tr_loss

            if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0 and global_step != 0 and args.save_total_limit > 0:
                checkpoint_prefix = 'checkpoint'
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                logger.info("Saving model checkpoint to %s", output_dir)

                _rotate_checkpoints(args, checkpoint_prefix)

                torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
                torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
                logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss
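The two branches at the top of this example derive the optimization budget from the dataloader length; the same bookkeeping as a standalone sketch (hypothetical numbers in the usage line):

def optimization_budget(num_batches, accumulation_steps, num_epochs, max_steps=0):
    updates_per_epoch = num_batches // accumulation_steps
    if max_steps > 0:
        # A fixed step budget implies the epoch count
        return max_steps, max_steps // updates_per_epoch + 1
    return updates_per_epoch * num_epochs, num_epochs

print(optimization_budget(1000, 4, 3))  # (750, 3) -> 750 optimizer updates over 3 epochs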
Example 18
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    record_result = []

    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2]
                    if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= args.gradient_accumulation_steps and
                (step + 1) == len(epoch_iterator)):
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        record_result.append(results)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(model, os.path.join(output_dir, "model.pt"))

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    results = evaluate(args, model, tokenizer)
    record_result.append(results)
    torch.save(record_result, os.path.join(args.output_dir, "result.pt"))

    return global_step, tr_loss / global_step
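Checkpoints written with save_pretrained() above can be reloaded with from_pretrained(); a minimal sketch with a hypothetical path (the Auto classes are illustrative, since the concrete model class depends on the task):

from transformers import AutoModelForSequenceClassification, AutoTokenizer

checkpoint_dir = "output/checkpoint-500"  # hypothetical checkpoint directory
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)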
Example 19
def train(args, train_dataset, model, tokenizer, lang2id=None):
    """Train the model."""
    if args.local_rank in [-1, 0]:
        tb_train = SummaryWriter(
            log_dir=os.path.join(args.output_dir, "train"))
        if args.save_only_best_checkpoint:
            tb_valid = SummaryWriter(
                log_dir=os.path.join(args.output_dir, "valid"))

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    best_score = 0
    best_checkpoint = None
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert"] else None
                )  # XLM doesn't use segment_ids
            if args.model_type == "xlm":
                inputs["langs"] = batch[4]
            outputs = model(**inputs)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    tb_train.add_scalar("lr",
                                        scheduler.get_lr()[0], global_step)
                    tb_train.add_scalar("loss", (tr_loss - logging_loss) /
                                        args.logging_steps, global_step)
                    logging_loss = tr_loss

                    # Only evaluate on single GPU otherwise metrics may not average well
                    if (args.local_rank == -1
                            and args.evaluate_during_training):
                        results = evaluate(args,
                                           model,
                                           tokenizer,
                                           split=args.train_split,
                                           language=args.train_language,
                                           lang2id=lang2id)
                        for key, value in results.items():
                            tb_train.add_scalar("eval_{}".format(key), value,
                                                global_step)

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    if args.eval_test_set:
                        output_predict_file = os.path.join(
                            args.output_dir, 'eval_test_results')
                        total = total_correct = 0.0
                        with open(output_predict_file, 'a') as writer:
                            writer.write(
                                '\n======= Predict using the model from checkpoint-{}:\n'
                                .format(global_step))
                            for language in args.predict_languages.split(','):
                                result = evaluate(args,
                                                  model,
                                                  tokenizer,
                                                  split=args.test_split,
                                                  language=language,
                                                  lang2id=lang2id,
                                                  prefix='checkpoint-' +
                                                  str(global_step))
                                writer.write('{}={}\n'.format(
                                    language, result['acc']))
                                total += result['num']
                                total_correct += result['correct']
                            writer.write('total={}\n'.format(total_correct /
                                                             total))

                    if args.save_only_best_checkpoint:
                        result = evaluate(args,
                                          model,
                                          tokenizer,
                                          split='dev',
                                          language=args.train_language,
                                          lang2id=lang2id,
                                          prefix=str(global_step))
                        for key, value in result.items():
                            tb_valid.add_scalar("eval_{}".format(key), value,
                                                global_step)
                        logger.info(" Dev accuracy {} = {}".format(
                            args.train_language, result['acc']))
                        if result['acc'] > best_score:
                            logger.info(
                                " result['acc']={} > best_score={}".format(
                                    result['acc'], best_score))
                            output_dir = os.path.join(args.output_dir,
                                                      "checkpoint-best")
                            best_checkpoint = output_dir
                            best_score = result['acc']
                            # Save model checkpoint
                            if not os.path.exists(output_dir):
                                os.makedirs(output_dir)
                            model_to_save = (
                                model.module
                                if hasattr(model, "module") else model
                            )  # Take care of distributed/parallel training
                            model_to_save.save_pretrained(output_dir)
                            tokenizer.save_pretrained(output_dir)

                            torch.save(
                                args,
                                os.path.join(output_dir, "training_args.bin"))
                            logger.info("Saving model checkpoint to %s",
                                        output_dir)

                            torch.save(
                                optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"))
                            torch.save(
                                scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))
                            logger.info(
                                "Saving optimizer and scheduler states to %s",
                                output_dir)
                    else:
                        # Save model checkpoint
                        output_dir = os.path.join(
                            args.output_dir,
                            "checkpoint-{}".format(global_step))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        model_to_save = (
                            model.module if hasattr(model, "module") else model
                        )  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)

                        torch.save(
                            args, os.path.join(output_dir,
                                               "training_args.bin"))
                        logger.info("Saving model checkpoint to %s",
                                    output_dir)

                        torch.save(optimizer.state_dict(),
                                   os.path.join(output_dir, "optimizer.pt"))
                        torch.save(scheduler.state_dict(),
                                   os.path.join(output_dir, "scheduler.pt"))
                        logger.info(
                            "Saving optimizer and scheduler states to %s",
                            output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_train.close()
        if args.save_only_best_checkpoint:
            tb_valid.close()

    # Save final model checkpoint at end of training
    output_dir = os.path.join(args.output_dir, "checkpoint-training-end")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    model_to_save = (model.module if hasattr(model, "module") else model
                     )  # Take care of distributed/parallel training
    model_to_save.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    torch.save(args, os.path.join(output_dir, "training_args.bin"))
    logger.info("Saving model checkpoint to %s", output_dir)
    torch.save(optimizer.state_dict(), os.path.join(output_dir,
                                                    "optimizer.pt"))
    torch.save(scheduler.state_dict(), os.path.join(output_dir,
                                                    "scheduler.pt"))
    logger.info("Saving optimizer and scheduler states to %s", output_dir)

    return global_step, tr_loss / global_step, best_score, best_checkpoint
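Stripped of logging and TensorBoard calls, the --save_only_best_checkpoint branch above reduces to the following pattern (a sketch, assuming evaluate() returns a dict with an 'acc' key as in this example):

import os


def maybe_save_best(result, best_score, model, tokenizer, output_dir):
    # Overwrite checkpoint-best whenever dev accuracy improves
    if result['acc'] <= best_score:
        return best_score
    best_dir = os.path.join(output_dir, "checkpoint-best")
    os.makedirs(best_dir, exist_ok=True)
    to_save = model.module if hasattr(model, "module") else model  # unwrap DataParallel/DDP
    to_save.save_pretrained(best_dir)
    tokenizer.save_pretrained(best_dir)
    return result['acc']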
def main():
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run eval on the test set.")
    parser.add_argument("--do_tagging",
                        action='store_true',
                        help="Whether to run eval on the tagging set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--sel_prob",
                        default=0.5,
                        type=float,
                        help="The select prob for each word when pretraining.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--from_tf",
                        action='store_true',
                        help="from_tf")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--train_type',
                        type=str,
                        default="pretrain",
                        help="type of train")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list",
    )
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    # $$$ modified here
    parser.add_argument('--train_continue',
                        action='store_true',
                        help="continue train from the saved model")
    parser.add_argument('--saved_model_path',
                        type=str,
                        default=None,
                        help="The path of saved model trained before.")
    parser.add_argument('--use_new_model',
                        action='store_true',
                        help="read_feature_from_cache.")
    parser.add_argument('--feature_path',
                        type=str,
                        default=None,
                        help="The path of feature saved.")
    parser.add_argument("--pretrain_model_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The  directory where the pretrain model comes")
    parser.add_argument("--pretrain_model_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The  directory where the pretrain model comes")
    parser.add_argument("--thre",
                        default=0.5,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    # $$$
    args = parser.parse_args()

    # $$$ changed the task name
    processors = {
        "disfluency": DisfluencyProcessor,
    }

    # $$$ changed the task name
    num_labels_task = {
        "disfluency": 2,
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
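    # Each forward pass now sees a smaller micro-batch; the effective batch per optimizer update
    # remains the requested --train_batch_size once gradients are accumulated.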

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()
    label_disf_list = processor.get_labels_disf()
    label_sing_list = processor.get_single_disf()
    num_labels_tagging = len(label_disf_list)
    pretrained = args.model_name_or_path

    tokenizer = BertTokenizer.from_pretrained(
        pretrained,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None
    )
    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    state = None
    train_continued = False
    if args.train_continue:
        train_continued = True
    # if train_continued:
    #     state = torch.load(args.saved_model_path)
    #     model = BertForSequenceDisfluency.from_pretrained(args.bert_model,
    #                                                       state_dict=state['state_dict'],
    #                                                       num_labels=num_labels,
    #                                                       num_labels_tagging=num_labels_tagging)
    #     last_epoch = state['epoch']
    # else:
    #     config = ElectraConfig.from_pretrained(
    #         pretrained,
    #         num_labels=num_labels,
    #         finetuning_task=args.task_name,
    #         cache_dir=args.cache_dir if args.cache_dir else None,
    #     )
    #     model = ElectraForSequenceDisfluency_sing.from_pretrained(
    #         pretrained,
    #         config=config,
    #         cache_dir=args.cache_dir if args.cache_dir else None,
    #         num_labels=num_labels, num_labels_tagging=num_labels_tagging
    #     )
    if args.use_new_model:
        logger.info("对了")
        # new_model_file = os.path.join(args.pretrain_model_dir, "pytorch_model.bin")
        new_model_file = os.path.join(args.pretrain_model_dir, args.pretrain_model_name)
        logger.info("use pretrain model {}".format(new_model_file))
        state = torch.load(new_model_file)
        config = ElectraConfig.from_pretrained(
            pretrained,
            num_labels=num_labels,
            finetuning_task=args.task_name,
            cache_dir=args.cache_dir if args.cache_dir else None,
        )
        model = ElectraForSequenceDisfluency_sing.from_pretrained(
            pretrained,
            config=config,
            cache_dir=args.cache_dir if args.cache_dir else None,
            state_dict=state['state_dict'],
            num_labels=num_labels, num_labels_tagging=num_labels_tagging
        )
    else:
        logger.info("Not using a saved model; initializing from the base pretrained checkpoint")
        config = ElectraConfig.from_pretrained(
            pretrained,
            num_labels=num_labels,
            finetuning_task=args.task_name,
            cache_dir=args.cache_dir if args.cache_dir else None,
        )
        model = ElectraForSequenceDisfluency_sing.from_pretrained(
            pretrained,
            config=config,
            cache_dir=args.cache_dir if args.cache_dir else None,
            num_labels=num_labels, num_labels_tagging=num_labels_tagging
        )
    # if args.fp16:
    #     model.half()
    model.to(device)


    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    t_total = num_train_steps
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        from apex import amp
        # optimizer = BertAdam(optimizer_grouped_parameters,
        #                      lr=args.learning_rate,
        #                      warmup=args.warmup_proportion,
        #                      t_total=t_total)
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate)
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

        logger.info("amp handle init done!")
        # except ImportError:
        #     raise ImportError(
        #         "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
    else:
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        if args.train_continue:
            # NOTE: `state` is only defined in the args.use_new_model branch above,
            # so resuming the optimizer assumes a pretrained checkpoint was loaded.
            optimizer.load_state_dict(state['optimizer'])

    # if args.local_rank != -1:
    #     try:
    #         from apex.parallel import DistributedDataParallel as DDP
    #     except ImportError:
    #         raise ImportError(
    #             "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
    #
    #     model = DDP(model)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0

    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_pseudo_examples(args.data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, label_disf_list, label_sing_list, args.max_seq_length, tokenizer, args.sel_prob,
            "dev")
        logger.info("***** Running evaluation on dev of epoch *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        all_label_disf_ids = torch.tensor([f.label_disf_id for f in eval_features], dtype=torch.long)
        all_label_sing_ids = torch.tensor([f.label_sing_id for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,
                                  all_label_disf_ids, all_label_sing_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.train_batch_size)

        model.eval()

        logits_total = list()
        input_ids_total = list()

        for input_ids, input_mask, segment_ids, label_ids, label_disf_ids, label_sing_ids in tqdm(eval_dataloader,
                                                                                                  desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)
            label_disf_ids = label_disf_ids.to(device)
            label_sing_ids = label_sing_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids=input_ids,
                                      token_type_ids=segment_ids,
                                      attention_mask=input_mask,
                                      labels_tagging=label_disf_ids)
                # &&&& 增加了sing的logits
                logits_pair, logits_tagging, logits_sing = model(input_ids=input_ids, token_type_ids=segment_ids
                                                                 , attention_mask=input_mask)

            # &&&&
            logits_sing = F.softmax(logits_sing, dim=-1)
            logits_sing = logits_sing.detach().cpu().numpy()
            logits_total.extend(logits_sing)
            input_ids_total.extend(input_ids.detach().cpu().numpy())

        # $$$
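        # Write each pseudo example's class-1 probability to single_logits.tsv, then
        # keep only the examples whose class-1 probability exceeds args.thre as the
        # new train.tsv (simple pseudo-label filtering).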
        output_eval_file = os.path.join(args.data_dir,
                                        "single_logits.tsv")
        with open(output_eval_file, "w") as f:
            for i in range(len(logits_total)):
                f.write(str(logits_total[i][0])+"\t"+str(logits_total[i][1])+"\n")

        with open(os.path.join(args.data_dir, "pseudo.tsv"), 'r', encoding='utf8') as f:
            t = f.readlines()

        with open(os.path.join(args.data_dir, "train.tsv"), 'w', encoding='utf8') as fw:
            for i in range(len(logits_total)):
                if logits_total[i][1] > args.thre:
                    fw.write(t[i].strip()+"\n")

        with open(os.path.join(args.data_dir, "single_input.tsv"), 'w', encoding='utf8') as fw:
            for e in eval_examples:
                fw.write(e.text_a+"\n")


        print("done!")
Example 21
class AdmissibleCommandTrainer:
    def __init__(self, params):

        self.template_coeff = params.template_coeff
        self.object1_coeff = params.object1_coeff
        self.object2_coeff = params.object2_coeff
        self.max_grad_norm = params.max_grad_norm
        self.verbose = params.verbose
        self.print_freq = params.print_freq
        self.checkpoint_dir = params.checkpoint_dir
        self.experiment_name = params.experiment_name
        self.num_train_epochs = params.end_epoch - params.start_epoch
        self.start_epoch = params.start_epoch
        self.end_epoch = params.end_epoch
        self.batch_size = params.batch_size

        # The template/object files are evaluated as python dicts mapping name -> id.
        self.template2id = eval(open(params.template_file, 'r').read())
        self.id2template = {t_id: template for template, t_id in self.template2id.items()}

        self.object2id = eval(open(params.object_file, 'r').read())
        self.id2object = {o_id: obj for obj, o_id in self.object2id.items()}

        wandb.init(project=params.project_name,
                   name=params.experiment_name,
                   config=params,
                   resume=True if params.load_checkpoint else False)

        self.dataset = AdmissibleCommandsClassificationDataset(
            data_file=params.data_file,
            template2id=self.template2id,
            object2id=self.object2id,
            max_seq_length=params.max_seq_len,
            bert_model_type=params.bert_model_type)

        self.command_generator = AdmissibleCommandGenerator(
            num_templates=len(self.template2id),
            num_objects=len(self.object2id),
            embedding_size=params.embedding_size,
            state_hidden_size=params.bert_hidden_size,
            state_transformed_size=params.bert_transformation_size).to(device)

        if params.load_checkpoint:
            checkpoint_path = '{}/{}/Epoch{}/'.format(params.checkpoint_dir,
                                                      params.experiment_name,
                                                      params.start_epoch - 1)

            self.bert = BertModel.from_pretrained(checkpoint_path).to(device)
            self.dataset.tokenizer = BertTokenizer.from_pretrained(
                checkpoint_path)

            with open(checkpoint_path + 'checkpoint.pth', "rb") as f:
                model_dict = torch.load(
                    f,
                    map_location=torch.device(
                        'cuda:0' if torch.cuda.is_available() else 'cpu'))
                self.command_generator.load_state_dict(model_dict['ac_model'])

            train_indices = np.loadtxt(params.log_dir + '/' +
                                       params.experiment_name +
                                       "/train_idx.txt").astype(int)
            val_indices = np.loadtxt(params.log_dir + '/' +
                                     params.experiment_name +
                                     "/val_idx.txt").astype(int)
            test_indices = np.loadtxt(params.log_dir + '/' +
                                      params.experiment_name +
                                      "/test_idx.txt").astype(int)

        else:
            self.bert = BertModel.from_pretrained(
                params.bert_model_type).to(device)

            train_indices, val_indices = self.get_train_valid_test_split(
                len(self.dataset), 0.2)
            test_indices = np.random.choice(train_indices,
                                            int(len(train_indices) * 0.2),
                                            replace=False)
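            # Note: the "test" indices are sampled from the training indices
            # (20% of them, without replacement), so the test split overlaps the
            # training data rather than being held out separately.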

            save_dir = params.log_dir + '/' + params.experiment_name
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            np.savetxt(save_dir + "/train_idx.txt", np.array(train_indices))
            np.savetxt(save_dir + "/val_idx.txt", np.array(val_indices))
            np.savetxt(save_dir + "/test_idx.txt", np.array(test_indices))

        #for debugging
        # train_indices = train_indices[:1*self.batch_size]
        # val_indices = val_indices[:1*self.batch_size]
        # test_indices = test_indices[:1]

        print("Number of Datapoints Train: {}, Val: {}, Test: {}".format(
            len(train_indices), len(val_indices), len(test_indices)))

        train_sampler = SubsetRandomSampler(train_indices)
        valid_sampler = SubsetRandomSampler(val_indices)
        test_sampler = SubsetRandomSampler(test_indices)

        self.train_dataloader = DataLoader(self.dataset,
                                           batch_size=params.batch_size,
                                           sampler=train_sampler,
                                           drop_last=True)
        self.valid_dataloader = DataLoader(self.dataset,
                                           batch_size=params.batch_size,
                                           sampler=valid_sampler,
                                           drop_last=True)
        self.test_dataloader = DataLoader(self.dataset,
                                          batch_size=params.batch_size,
                                          sampler=test_sampler,
                                          drop_last=True)

        t_total = len(
            self.train_dataloader
        ) // params.gradient_accumulation_steps * self.num_train_epochs
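        # t_total counts optimizer updates: batches per epoch // gradient accumulation
        # steps, times the number of training epochs.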
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [{
            "params": [
                p for n, p in self.bert.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            params.weight_decay,
        }, {
            "params": [
                p for n, p in self.bert.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        }, {
            "params":
            self.command_generator.parameters()
        }]
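        # In the groups above, the BERT parameters get the usual decay/no-decay split;
        # the command generator's parameters are left with the optimizer's default
        # weight decay.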

        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.lr,
                               eps=params.adam_epsilon)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=params.warmup_steps,
            num_training_steps=t_total)

        if params.load_checkpoint:
            self.optimizer.load_state_dict(model_dict['optimizer'])
            self.scheduler.load_state_dict(model_dict['scheduler'])

        self.BCE = nn.BCEWithLogitsLoss()
        self.o1_WeightedBCE = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1.))
        self.o2_WeightedBCE = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor(10.))
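        # pos_weight scales the loss contribution of positive targets; the larger 10x
        # weight on the o2 head presumably compensates for sparser positives there.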

        self.configure_logger(log_dir=params.log_dir,
                              experiment_name=params.experiment_name)

    def calc_f_score(self, trans):

        logits_t, logits_o1, logits_o2, template_targets, o1_targets, o2_targets = trans

        t_probs = torch.ge(torch.sigmoid(logits_t),
                           0.5).float().clone().detach().cpu().numpy()
        o1_probs = torch.ge(torch.sigmoid(logits_o1),
                            0.5).float().clone().detach().cpu().numpy()
        o2_probs = torch.ge(torch.sigmoid(logits_o2),
                            0.5).float().clone().detach().cpu().numpy()
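        # The thresholding above turns sigmoid probabilities into hard 0/1 predictions
        # at 0.5 before the weighted precision/recall/F1 computations below.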
        template_targets = template_targets.cpu().numpy()
        o1_targets = o1_targets.cpu().numpy()
        o2_targets = o2_targets.cpu().numpy()

        t_fscore, o1_fscore, o2_fscore = 0, 0, 0
        t_precision, o1_precision, o2_precision = 0, 0, 0
        t_recall, o1_recall, o2_recall = 0, 0, 0

        t_precision, t_recall, t_fscore, _ = precision_recall_fscore_support(
            template_targets, t_probs, average='weighted')

        N = self.batch_size
        for b in range(N):
            o1_p, o1_r, o1_f, _ = precision_recall_fscore_support(
                o1_targets[b],
                o1_probs[b],
                average='weighted',
                zero_division=1)
            o1_fscore += o1_f
            o1_precision += o1_p
            o1_recall += o1_r

            #average o2 f score over templates and batches
            o2_p, o2_r, o2_f = 0, 0, 0
            for t in range(len(self.template2id)):
                temp_p, temp_r, temp_f, _ = precision_recall_fscore_support(
                    o2_targets[b][t],
                    o2_probs[b][t],
                    average='weighted',
                    zero_division=1)
                o2_p += temp_p
                o2_r += temp_r
                o2_f += temp_f

            #divide by number of templates
            o2_fscore += (o2_f / len(self.template2id))
            o2_precision += (o2_p / len(self.template2id))
            o2_recall += (o2_r / len(self.template2id))

        #divide by batch size
        o1_fscore, o1_precision, o1_recall = o1_fscore / N, o1_precision / N, o1_recall / N
        o2_fscore, o2_precision, o2_recall = o2_fscore / N, o2_precision / N, o2_recall / N
        return {
            'template': [t_fscore, t_precision, t_recall],
            'o1': [o1_fscore, o1_precision, o1_recall],
            'o2': [o2_fscore, o2_precision, o2_recall]
        }

    def train(self):

        self.optimizer.zero_grad()
        for epoch in range(self.start_epoch, self.end_epoch):
            self.bert.train()
            t_fscore, o1_fscore, o2_fscore = 0, 0, 0
            t_precision, o1_precision, o2_precision = 0, 0, 0
            t_recall, o1_recall, o2_recall = 0, 0, 0
            t_loss, obj1_loss, obj2_loss = 0, 0, 0
            for step, batch in enumerate(
                    tqdm(self.train_dataloader,
                         desc="Training Epoch {}".format(epoch))):
                batch = tuple(t.to(device) for t in batch)
                state, template_targets, o1_template_targets, o2_o1_template_targets = batch

                input_ids = state[:, 0, :]
                input_mask = state[:, 1, :]
                bert_outputs = self.bert(input_ids,
                                         token_type_ids=None,
                                         attention_mask=input_mask)
                #pooled output of bert
                encoded_state = bert_outputs[1]  #(batch, 768)
                template_logits, o1_template_logits, o2_o1_logits = self.command_generator(
                    encoded_state)

                template_loss = self.template_coeff * self.BCE(
                    template_logits, template_targets.to(device))
                o1_loss = self.object1_coeff * self.o1_WeightedBCE(
                    o1_template_logits, o1_template_targets.to(device))
                o2_loss = self.object2_coeff * self.o2_WeightedBCE(
                    o2_o1_logits, o2_o1_template_targets.to(device))
                loss = (template_loss + o1_loss + o2_loss)

                trans = template_logits, o1_template_logits, o2_o1_logits, template_targets, o1_template_targets, o2_o1_template_targets
                metrics = self.calc_f_score(trans)

                t_loss += template_loss.item()
                obj1_loss += o1_loss.item()
                obj2_loss += o2_loss.item()

                t_fscore += metrics['template'][0]
                t_precision += metrics['template'][1]
                t_recall += metrics['template'][2]

                o1_fscore += metrics['o1'][0]
                o1_precision += metrics['o1'][1]
                o1_recall += metrics['o1'][2]

                o2_fscore += metrics['o2'][0]
                o2_precision += metrics['o2'][1]
                o2_recall += metrics['o2'][2]

                loss.backward()

                torch.nn.utils.clip_grad_norm_(
                    self.command_generator.parameters(), self.max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.bert.parameters(),
                                               self.max_grad_norm)

                self.optimizer.step()
                self.optimizer.zero_grad()
                self.scheduler.step()

                if self.verbose and step % self.print_freq == 0:
                    sys.stdout.flush()
                    train_stats = "\n\tTrain Stats:\n\tStep: {} Template Loss: {} O1 Loss: {}, O2 Loss: {}\n\t\tTEMPLATE FPR: {}\n\t\tO1 FPR: {}\n\t\tO2 FPR: {}\n".format(\
                        step, template_loss/(step + 1),obj1_loss/(step + 1),obj2_loss/(step + 1),\
                        [t_fscore/(step + 1),t_precision/(step + 1), t_recall/(step + 1) ],\
                        [o1_fscore/(step + 1),o1_precision/(step + 1),o1_recall/(step + 1)],\
                        [o2_fscore/(step + 1), o2_precision/(step + 1), o2_recall/(step + 1)])
                    print(train_stats)
                    self.log(train_stats)

            eval_metrics = self.eval()

            sys.stdout.flush()
            print("EPOCH: ", epoch)
            train_stats = "\n\tTrain Stats:\n\tStep: {} Train Loss: {} O1 Loss: {}, O2 Loss: {}\n\t\tTEMPLATE FPR: {}\n\t\tO1 FPR: {}\n\t\tO2 FPR: {}\n".format(\
                            step, template_loss/(step + 1),obj1_loss/(step + 1),obj2_loss/(step + 1),\
                            [t_fscore/(step + 1),t_precision/(step + 1), t_recall/(step + 1) ],\
                            [o1_fscore/(step + 1),o1_precision/(step + 1),o1_recall/(step + 1)],\
                            [o2_fscore/(step + 1), o2_precision/(step + 1), o2_recall/(step + 1)])
            eval_stats = "\n\tVal Stats:\n\t\tTEMPLATE FPR: {}\n\t\tO1 FPR: {}\n\t\tO2 FPR: {}\n".format(
                eval_metrics['template'], eval_metrics['o1'],
                eval_metrics['o2'])
            print(train_stats)
            print(eval_stats)

            self.log(train_stats)
            self.log(eval_stats)

            wandb.log({
                'Template Loss': t_loss / (step + 1),
                'O1 Loss': obj1_loss / (step + 1),
                'O2 Loss': obj2_loss / (step + 1),
                'Train Template F Score': t_fscore / (step + 1),
                'Train O1 F Score': o1_fscore / (step + 1),
                'Train O2 F Score': o2_fscore / (step + 1),
                'Val Template F Score': eval_metrics['template'][0],
                'Val O1 F Score': eval_metrics['o1'][0],
                'Val O2 F Score': eval_metrics['o2'][0],
            })

            checkpoint = {
                'epoch': epoch,
                'ac_model': self.command_generator.state_dict(),
                'state_encoder': self.bert.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
            }

            save_dir = '{}/{}/Epoch{}/'.format(self.checkpoint_dir,
                                               self.experiment_name, epoch)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            torch.save(checkpoint, save_dir + "checkpoint.pth")
            self.bert.save_pretrained(save_dir)
            self.dataset.tokenizer.save_pretrained(save_dir)

    def eval(self):
        self.bert.eval()
        t_fscore, o1_fscore, o2_fscore = 0, 0, 0
        t_precision, o1_precision, o2_precision = 0, 0, 0
        t_recall, o1_recall, o2_recall = 0, 0, 0
        for step, batch in enumerate(
                tqdm(self.valid_dataloader, desc="Evaluation")):
            batch = tuple(t.to(device) for t in batch)
            state, template_targets, o1_template_targets, o2_o1_template_targets = batch
            with torch.no_grad():
                input_ids = state[:, 0, :]
                input_mask = state[:, 1, :]
                bert_outputs = self.bert(input_ids,
                                         token_type_ids=None,
                                         attention_mask=input_mask)
                #pooled output of bert
                encoded_state = bert_outputs[1]  #(batch, 768)
                template_logits, o1_template_logits, o2_o1_logits = self.command_generator(
                    encoded_state)

            trans = template_logits, o1_template_logits, o2_o1_logits, template_targets, o1_template_targets, o2_o1_template_targets

            metrics = self.calc_f_score(trans)

            t_fscore += metrics['template'][0]
            t_precision += metrics['template'][1]
            t_recall += metrics['template'][2]

            o1_fscore += metrics['o1'][0]
            o1_precision += metrics['o1'][1]
            o1_recall += metrics['o1'][2]

            o2_fscore += metrics['o2'][0]
            o2_precision += metrics['o2'][1]
            o2_recall += metrics['o2'][2]

        eval_metrics = {
            'template': [
                t_fscore / (step + 1), t_precision / (step + 1),
                t_recall / (step + 1)
            ],
            'o1': [
                o1_fscore / (step + 1), o1_precision / (step + 1),
                o1_recall / (step + 1)
            ],
            'o2': [
                o2_fscore / (step + 1), o2_precision / (step + 1),
                o2_recall / (step + 1)
            ]
        }
        return eval_metrics

    def test(self):
        pass

    def configure_logger(self, log_dir, experiment_name):
        logger.configure(log_dir + "/" + experiment_name, format_strs=['log'])
        self.log = logger.log

    def get_train_valid_test_split(self, dataset_size, split):
        indices = list(range(dataset_size))
        split = int(np.floor(split * dataset_size))
        np.random.shuffle(indices)
        train_indices, val_indices = indices[split:], indices[:split]
        return train_indices, val_indices
Example 22
def train(args, train_dataset, model, tokenizer, pool):
    """ Train the model """

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)

    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    args.max_steps = args.epoch * len(train_dataloader)
    args.save_steps = len(train_dataloader)
    args.warmup_steps = len(train_dataloader)
    args.logging_steps = len(train_dataloader)
    args.num_train_epochs = args.epoch
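    # One epoch's worth of batches is used both as the warmup length and as the
    # save/logging interval.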
    model.to(args.device)
    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    checkpoint_last = os.path.join(args.output_dir, 'checkpoint-last')
    scheduler_last = os.path.join(checkpoint_last, 'scheduler.pt')
    optimizer_last = os.path.join(checkpoint_last, 'optimizer.pt')
    if os.path.exists(scheduler_last):
        scheduler.load_state_dict(torch.load(scheduler_last))
    if os.path.exists(optimizer_last):
        optimizer.load_state_dict(torch.load(optimizer_last))
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", args.max_steps)

    global_step = args.start_step
    tr_loss, logging_loss, avg_loss, tr_nb, tr_num, train_loss = 0.0, 0.0, 0.0, 0, 0, 0
    best_mrr = 0.0
    best_f1 = 0
    # model.resize_token_embeddings(len(tokenizer))
    model.zero_grad()
    set_seed(args.seed)  # Added here for reproducibility (even between python 2 and 3)

    for idx in range(args.start_epoch, int(args.num_train_epochs)):
        bar = tqdm(train_dataloader, total=len(train_dataloader))
        tr_num = 0
        train_loss = 0
        for step, batch in enumerate(bar):
            inputs = batch[0].to(args.device)
            labels = batch[1].to(args.device)
            model.train()
            loss, logits = model(inputs, labels)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            tr_loss += loss.item()
            tr_num += 1
            train_loss += loss.item()
            avg_loss = round(train_loss / tr_num, 5)
            bar.set_description("epoch {} loss {}".format(idx, avg_loss))

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1
                output_flag = True
                avg_loss = round(
                    np.exp((tr_loss - logging_loss) / (global_step - tr_nb)),
                    4)
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logging_loss = tr_loss
                    tr_nb = global_step

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:

                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args,
                                           model,
                                           tokenizer,
                                           pool=pool,
                                           eval_when_training=True)
                        # Save model checkpoint

                    if args.local_rank == -1 and args.evaluate_during_training and results['eval_f1'] > best_f1:
                        best_f1 = results['eval_f1']
                        logger.info("  " + "*" * 20)
                        logger.info("  Best f1:%s", round(best_f1, 4))
                        logger.info("  " + "*" * 20)

                        checkpoint_prefix = 'checkpoint-best-f1'
                        output_dir = os.path.join(
                            args.output_dir, '{}'.format(checkpoint_prefix))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        model_to_save = model.module if hasattr(
                            model, 'module') else model
                        output_dir = os.path.join(output_dir,
                                                  '{}'.format('model.bin'))
                        torch.save(model_to_save.state_dict(), output_dir)
                        logger.info("Saving model checkpoint to %s",
                                    output_dir)

        if args.max_steps > 0 and global_step > args.max_steps:
            bar.close()
            break
    return global_step, tr_loss / global_step
Example 23
def train(args, train_dataset, model, tokenizer, fh, pool):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        args.tensorboard_dir = os.path.join(args.output_dir, 'tensorboard')
        if not os.path.exists(args.tensorboard_dir):
            os.makedirs(args.tensorboard_dir)
        tb_writer = SummaryWriter(args.tensorboard_dir)

    args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)

    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  drop_last=True)
    total_examples = len(train_dataset) * (torch.distributed.get_world_size()
                                           if args.local_rank != -1 else 1)
    batch_size = args.batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    # if args.max_steps > 0:
    #     t_total = args.max_steps
    #     args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    if args.num_train_epochs > 0:
        t_total = total_examples // batch_size * args.num_train_epochs
    args.max_steps = t_total
    model.to(args.device)
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
    checkpoint_last = os.path.join(args.output_dir, 'checkpoint-last')
    scheduler_last = os.path.join(checkpoint_last, 'scheduler.pt')
    optimizer_last = os.path.join(checkpoint_last, 'optimizer.pt')
    if os.path.exists(scheduler_last):
        scheduler.load_state_dict(
            torch.load(scheduler_last, map_location="cpu"))
    if os.path.exists(optimizer_last):
        optimizer.load_state_dict(
            torch.load(optimizer_last, map_location="cpu"))
    if args.local_rank == 0:
        torch.distributed.barrier()
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank % args.gpu_per_node],
            output_device=args.local_rank % args.gpu_per_node,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", total_examples)
    logger.info("  Num epoch = %d", t_total * batch_size // total_examples)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = args.start_step
    tr_loss, logging_loss, avg_loss, tr_nb = 0.0, 0.0, 0.0, 0
    # model.resize_token_embeddings(len(tokenizer))
    model.zero_grad()
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)

    best_bleu = 0.0

    for idx in range(args.start_epoch, int(args.num_train_epochs)):
        for step, (batch, token_labels) in enumerate(train_dataloader):
            inputs = batch.to(args.device)
            attn_mask = (token_labels != 0).to(dtype=torch.uint8,
                                               device=args.device)
            loss_mask = (token_labels == 2).to(dtype=torch.uint8,
                                               device=args.device)
            model.train()
            # outputs = model(inputs, attention_mask=attn_mask, labels=inputs, loss_mask=loss_mask)
            # loss = outputs[0]
            outputs = model(inputs, attention_mask=attn_mask)
            logits = outputs[0]
            labels = inputs
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            flatten_shift_loss_mask = loss_mask[..., :-1].contiguous().view(-1)
            ids = torch.nonzero(flatten_shift_loss_mask).view(-1)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1))[ids],
                shift_labels.view(-1)[ids])
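            # The loss above is the standard causal-LM objective: logits at position t
            # predict the token at t+1, restricted to positions where token_labels == 2
            # (selected via loss_mask).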

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1
                output_flag = True
                avg_loss = round(
                    np.exp((tr_loss - logging_loss) / (global_step - tr_nb)),
                    4)
                if global_step % args.logging_steps == 0:
                    logger.info("  steps: %s  ppl: %s", global_step,
                                round(avg_loss, 5))
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    tb_writer.add_scalar('lr',
                                         scheduler.get_last_lr()[0],
                                         global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss
                    tr_nb = global_step

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    if args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        # results = evaluate(args, model, tokenizer, eval_when_training=True)
                        # for key, value in results.items():
                        #     tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                        #     logger.info("  %s = %s", key, round(value,4))
                        # output_dir = os.path.join(args.output_dir, '{}-{}-{}'.format(checkpoint_prefix, global_step, round(results['perplexity'],4)))
                        dev_bleu, dev_EM = eval_bleu(args,
                                                     model,
                                                     tokenizer,
                                                     file_type='dev',
                                                     num=100)
                        logger.info(f"dev bleu: {dev_bleu}, dev EM: {dev_EM}")
                        output_dir = os.path.join(
                            args.output_dir,
                            '{}-{}-{}'.format(checkpoint_prefix, global_step,
                                              round(dev_bleu, 2)))
                        if dev_bleu > best_bleu:
                            best_bleu = dev_bleu
                            logger.info(
                                f"best bleu updated. saved in {output_dir}")
                            logger.info(f"best bleu: {best_bleu}")
                    else:
                        output_dir = os.path.join(
                            args.output_dir,
                            "{}-{}".format(checkpoint_prefix, global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    # _rotate_checkpoints(args, checkpoint_prefix)
                    last_output_dir = os.path.join(args.output_dir,
                                                   'checkpoint-last')
                    if not os.path.exists(last_output_dir):
                        os.makedirs(last_output_dir)
                    model_to_save.save_pretrained(last_output_dir)
                    tokenizer.save_pretrained(last_output_dir)
                    idx_file = os.path.join(last_output_dir, 'idx_file.txt')
                    with open(idx_file, 'w', encoding='utf-8') as idxf:
                        idxf.write(str(0) + '\n')

                    torch.save(optimizer.state_dict(),
                               os.path.join(last_output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(last_output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                last_output_dir)

                    step_file = os.path.join(last_output_dir, 'step_file.txt')
                    with open(step_file, 'w', encoding='utf-8') as stepf:
                        stepf.write(str(global_step) + '\n')

                    # torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    # torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    # logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example 24
    def train(
        self,
        train_dataset,
        output_dir,
        show_running_loss=True,
        eval_data=None,
        verbose=True,
        **kwargs,
    ):
        """
        Trains the model on train_dataset.

        Utility function to be used by the train_model() method. Not intended to be used directly.
        """

        model = self.model
        args = self.args
        device = self.device

        tb_writer = SummaryWriter(logdir=args["tensorboard_dir"])
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args["train_batch_size"])

        if args["max_steps"] > 0:
            t_total = args["max_steps"]
            args["num_train_epochs"] = (
                args["max_steps"] //
                (len(train_dataloader) // args["gradient_accumulation_steps"])
                + 1)
        else:
            t_total = len(train_dataloader) // args[
                "gradient_accumulation_steps"] * args["num_train_epochs"]

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                args["weight_decay"],
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ]
            },
        ]

        warmup_steps = math.ceil(t_total * args["warmup_ratio"])
        args["warmup_steps"] = warmup_steps if args[
            "warmup_steps"] == 0 else args["warmup_steps"]

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args["learning_rate"],
                          eps=args["adam_epsilon"])
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args["warmup_steps"],
            num_training_steps=t_total)

        if (args["model_name"] and os.path.isfile(
                os.path.join(args["model_name"], "optimizer.pt"))
                and os.path.isfile(
                    os.path.join(args["model_name"], "scheduler.pt"))):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(args["model_name"], "optimizer.pt")))
            scheduler.load_state_dict(
                torch.load(os.path.join(args["model_name"], "scheduler.pt")))

        if args["fp16"]:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )

            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args["fp16_opt_level"])

        if args["n_gpu"] > 1:
            model = torch.nn.DataParallel(model)

        logger.info(" Training started")

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        model.zero_grad()
        train_iterator = trange(int(args["num_train_epochs"]),
                                desc="Epoch",
                                disable=args["silent"],
                                mininterval=0)
        epoch_number = 0
        best_eval_metric = None
        early_stopping_counter = 0
        steps_trained_in_current_epoch = 0
        epochs_trained = 0

        if args["model_name"] and os.path.exists(args["model_name"]):
            try:
                # set global_step to gobal_step of last saved checkpoint from model path
                checkpoint_suffix = args["model_name"].split("/")[-1].split(
                    "-")
                if len(checkpoint_suffix) > 2:
                    checkpoint_suffix = checkpoint_suffix[1]
                else:
                    checkpoint_suffix = checkpoint_suffix[-1]
                global_step = int(checkpoint_suffix)
                epochs_trained = global_step // (
                    len(train_dataloader) //
                    args["gradient_accumulation_steps"])
                steps_trained_in_current_epoch = global_step % (
                    len(train_dataloader) //
                    args["gradient_accumulation_steps"])

                logger.info(
                    "   Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("   Continuing training from epoch %d",
                            epochs_trained)
                logger.info("   Continuing training from global step %d",
                            global_step)
                logger.info(
                    "   Will skip the first %d steps in the current epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                logger.info("   Starting fine-tuning.")

        if args["evaluate_during_training"]:
            training_progress_scores = self._create_training_progress_scores(
                **kwargs)

        if args["wandb_project"]:
            wandb.init(project=args["wandb_project"],
                       config={**args},
                       **args["wandb_kwargs"])
            wandb.watch(self.model)

        model.train()
        for current_epoch in train_iterator:
            if epochs_trained > 0:
                epochs_trained -= 1
                continue
            # epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="Current iteration",
                         disable=args["silent"])):
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue
                batch = tuple(t.to(device) for t in batch)

                inputs = self._get_inputs_dict(batch)
                outputs = model(**inputs)
                # model outputs are always tuple in pytorch-transformers (see doc)
                loss = outputs[0]

                if args["n_gpu"] > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training

                current_loss = loss.item()

                if show_running_loss:
                    print("\rRunning loss: %f" % loss, end="")

                if args["gradient_accumulation_steps"] > 1:
                    loss = loss / args["gradient_accumulation_steps"]

                if args["fp16"]:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    # torch.nn.utils.clip_grad_norm_(
                    #     amp.master_params(optimizer), args["max_grad_norm"]
                    # )
                else:
                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(
                    #     model.parameters(), args["max_grad_norm"]
                    # )

                tr_loss += loss.item()
                if (step + 1) % args["gradient_accumulation_steps"] == 0:
                    if args["fp16"]:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            args["max_grad_norm"])
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args["max_grad_norm"])

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    if args["logging_steps"] > 0 and global_step % args[
                            "logging_steps"] == 0:
                        # Log metrics
                        tb_writer.add_scalar("lr",
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                             args["logging_steps"],
                                             global_step)
                        logging_loss = tr_loss
                        if args["wandb_project"]:
                            wandb.log({
                                "Training loss": current_loss,
                                "lr": scheduler.get_lr()[0],
                                "global_step": global_step,
                            })

                    if args["save_steps"] > 0 and global_step % args[
                            "save_steps"] == 0:
                        # Save model checkpoint
                        output_dir_current = os.path.join(
                            output_dir, "checkpoint-{}".format(global_step))

                        self._save_model(output_dir_current,
                                         optimizer,
                                         scheduler,
                                         model=model)

                    if args["evaluate_during_training"] and (
                            args["evaluate_during_training_steps"] > 0
                            and global_step %
                            args["evaluate_during_training_steps"] == 0):
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = self.eval_model(
                            eval_data,
                            verbose=verbose
                            and args["evaluate_during_training_verbose"],
                            silent=True,
                            **kwargs,
                        )
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)

                        output_dir_current = os.path.join(
                            output_dir, "checkpoint-{}".format(global_step))

                        if args["save_eval_checkpoints"]:
                            self._save_model(output_dir_current,
                                             optimizer,
                                             scheduler,
                                             model=model,
                                             results=results)

                        training_progress_scores["global_step"].append(
                            global_step)
                        training_progress_scores["train_loss"].append(
                            current_loss)
                        for key in results:
                            training_progress_scores[key].append(results[key])
                        report = pd.DataFrame(training_progress_scores)
                        report.to_csv(
                            os.path.join(args["output_dir"],
                                         "training_progress_scores.csv"),
                            index=False,
                        )

                        if args["wandb_project"]:
                            wandb.log(
                                self._get_last_metrics(
                                    training_progress_scores))

                        if not best_eval_metric:
                            best_eval_metric = results[
                                args["early_stopping_metric"]]
                            self._save_model(args["best_model_dir"],
                                             optimizer,
                                             scheduler,
                                             model=model,
                                             results=results)
                        if best_eval_metric and args[
                                "early_stopping_metric_minimize"]:
                            if (results[args["early_stopping_metric"]] -
                                    best_eval_metric <
                                    args["early_stopping_delta"]):
                                best_eval_metric = results[
                                    args["early_stopping_metric"]]
                                self._save_model(args["best_model_dir"],
                                                 optimizer,
                                                 scheduler,
                                                 model=model,
                                                 results=results)
                                early_stopping_counter = 0
                            else:
                                if args["use_early_stopping"]:
                                    if early_stopping_counter < args[
                                            "early_stopping_patience"]:
                                        early_stopping_counter += 1
                                        if verbose:
                                            logger.info(
                                                f" No improvement in {args['early_stopping_metric']}"
                                            )
                                            logger.info(
                                                f" Current step: {early_stopping_counter}"
                                            )
                                            logger.info(
                                                f" Early stopping patience: {args['early_stopping_patience']}"
                                            )
                                    else:
                                        if verbose:
                                            logger.info(
                                                f" Patience of {args['early_stopping_patience']} steps reached"
                                            )
                                            logger.info(
                                                " Training terminated.")
                                            train_iterator.close()
                                        return global_step, tr_loss / global_step
                        else:
                            if (results[args["early_stopping_metric"]] -
                                    best_eval_metric >
                                    args["early_stopping_delta"]):
                                best_eval_metric = results[
                                    args["early_stopping_metric"]]
                                self._save_model(args["best_model_dir"],
                                                 optimizer,
                                                 scheduler,
                                                 model=model,
                                                 results=results)
                                early_stopping_counter = 0
                            else:
                                if args["use_early_stopping"]:
                                    if early_stopping_counter < args[
                                            "early_stopping_patience"]:
                                        early_stopping_counter += 1
                                        if verbose:
                                            logger.info(
                                                f" No improvement in {args['early_stopping_metric']}"
                                            )
                                            logger.info(
                                                f" Current step: {early_stopping_counter}"
                                            )
                                            logger.info(
                                                f" Early stopping patience: {args['early_stopping_patience']}"
                                            )
                                    else:
                                        if verbose:
                                            logger.info(
                                                f" Patience of {args['early_stopping_patience']} steps reached"
                                            )
                                            logger.info(
                                                " Training terminated.")
                                            train_iterator.close()
                                        return global_step, tr_loss / global_step

            epoch_number += 1
            output_dir_current = os.path.join(
                output_dir,
                "checkpoint-{}-epoch-{}".format(global_step, epoch_number))

            if args["save_model_every_epoch"] or args[
                    "evaluate_during_training"]:
                os.makedirs(output_dir_current, exist_ok=True)

            if args["save_model_every_epoch"]:
                self._save_model(output_dir_current,
                                 optimizer,
                                 scheduler,
                                 model=model)

            if args["evaluate_during_training"]:
                results = self.eval_model(
                    eval_data,
                    verbose=verbose
                    and args["evaluate_during_training_verbose"],
                    silent=True,
                    **kwargs)

                self._save_model(output_dir_current,
                                 optimizer,
                                 scheduler,
                                 results=results)

                training_progress_scores["global_step"].append(global_step)
                training_progress_scores["train_loss"].append(current_loss)
                for key in results:
                    training_progress_scores[key].append(results[key])
                report = pd.DataFrame(training_progress_scores)
                report.to_csv(os.path.join(args["output_dir"],
                                           "training_progress_scores.csv"),
                              index=False)

                if args["wandb_project"]:
                    wandb.log(self._get_last_metrics(training_progress_scores))

                if not best_eval_metric:
                    best_eval_metric = results[args["early_stopping_metric"]]
                    self._save_model(args["best_model_dir"],
                                     optimizer,
                                     scheduler,
                                     model=model,
                                     results=results)
                if best_eval_metric and args["early_stopping_metric_minimize"]:
                    if results[args[
                            "early_stopping_metric"]] - best_eval_metric < args[
                                "early_stopping_delta"]:
                        best_eval_metric = results[
                            args["early_stopping_metric"]]
                        self._save_model(args["best_model_dir"],
                                         optimizer,
                                         scheduler,
                                         model=model,
                                         results=results)
                        early_stopping_counter = 0
                    else:
                        if args["use_early_stopping"] and args[
                                "early_stopping_consider_epochs"]:
                            if early_stopping_counter < args[
                                    "early_stopping_patience"]:
                                early_stopping_counter += 1
                                if verbose:
                                    logger.info(
                                        f" No improvement in {args['early_stopping_metric']}"
                                    )
                                    logger.info(
                                        f" Current step: {early_stopping_counter}"
                                    )
                                    logger.info(
                                        f" Early stopping patience: {args['early_stopping_patience']}"
                                    )
                            else:
                                if verbose:
                                    logger.info(
                                        f" Patience of {args['early_stopping_patience']} steps reached"
                                    )
                                    logger.info(" Training terminated.")
                                    train_iterator.close()
                                return global_step, tr_loss / global_step
                else:
                    if results[args[
                            "early_stopping_metric"]] - best_eval_metric > args[
                                "early_stopping_delta"]:
                        best_eval_metric = results[
                            args["early_stopping_metric"]]
                        self._save_model(args["best_model_dir"],
                                         optimizer,
                                         scheduler,
                                         model=model,
                                         results=results)
                        early_stopping_counter = 0
                    else:
                        if args["use_early_stopping"] and args[
                                "early_stopping_consider_epochs"]:
                            if early_stopping_counter < args[
                                    "early_stopping_patience"]:
                                early_stopping_counter += 1
                                if verbose:
                                    logger.info(
                                        f" No improvement in {args['early_stopping_metric']}"
                                    )
                                    logger.info(
                                        f" Current step: {early_stopping_counter}"
                                    )
                                    logger.info(
                                        f" Early stopping patience: {args['early_stopping_patience']}"
                                    )
                            else:
                                if verbose:
                                    logger.info(
                                        f" Patience of {args['early_stopping_patience']} steps reached"
                                    )
                                    logger.info(" Training terminated.")
                                    train_iterator.close()
                                return global_step, tr_loss / global_step

        return global_step, tr_loss / global_step
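# The two early-stopping branches above (metric minimized vs. maximized) repeat the
# same bookkeeping. A minimal sketch of that logic in one place -- the helper name
# `early_stopping_update` and its signature are illustrative assumptions, not part
# of the code above:
def early_stopping_update(current, best, minimize, delta, counter, patience):
    """Return (new_best, new_counter, stop) with roughly the semantics used above."""
    improved = (current - best < delta) if minimize else (current - best > delta)
    if improved:
        return current, 0, False
    counter += 1
    return best, counter, counter > patience
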
def train(args, train_dataset, model, tokenizer, config):
    """ Train the model """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    args.warmup_steps = int(t_total * args.warmup_proportion)
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    global_step = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path
                      ) and "checkpoint" in args.model_name_or_path:
        # set global_step to global_step of last saved checkpoint from model path
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    if args.do_adv:
        fgm = FGM(model, emb_name=args.adv_name, epsilon=args.adv_epsilon)
    model.zero_grad()
    seed_everything(
        args.seed
    )  # Added here for reproducibility (even between Python 2 and 3)
    # train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])

    best_eval_p = 0.0
    best_eval_r = 0.0
    best_eval_f1 = 0.0

    for _ in range(int(args.num_train_epochs)):
        # pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                # XLM and RoBERTa don't use segment_ids
                inputs["token_type_ids"] = (batch[2] if args.model_type
                                            in ["bert", "xlnet"] else None)

            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in pytorch-transformers (see doc)
            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if args.do_adv:
                fgm.attack()
                loss_adv = model(**inputs)[0]
                if args.n_gpu > 1:
                    loss_adv = loss_adv.mean()
                loss_adv.backward()
                fgm.restore()
            # pbar(step, {'loss': loss.item()})
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

        # best eval
        results = evaluate(args, model, tokenizer)
        if results["f1"] > best_eval_f1:
            best_eval_p = results['acc']
            best_eval_r = results['recall']
            best_eval_f1 = results["f1"]

            ## Save the best model
            args.best_eval_output_dir = args.output_dir
            if not os.path.exists(args.best_eval_output_dir):
                os.makedirs(args.best_eval_output_dir)
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Take care of distributed/parallel training
            model_to_save.save_pretrained(args.best_eval_output_dir)
            torch.save(
                args,
                os.path.join(args.best_eval_output_dir, "training_args.bin"))
            logger.info("eval results: p:{:.4f} r:{:.4f} f1:{:.4f}".format(
                best_eval_p, best_eval_r, best_eval_f1))
            logger.info("Saving step:{} as best model to {}".format(
                global_step, args.best_eval_output_dir))
            tokenizer.save_vocabulary(args.best_eval_output_dir)
            torch.save(optimizer.state_dict(),
                       os.path.join(args.best_eval_output_dir, "optimizer.pt"))
            torch.save(scheduler.state_dict(),
                       os.path.join(args.best_eval_output_dir, "scheduler.pt"))
            # logger.info("Saving optimizer and scheduler states to %s", args.best_eval_output_dir)
            # config_file = os.path.join(args.best_eval_output_dir, "best_config.json")
            # json.dump(config, config_file)

        logger.info("\n")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
    if args.do_predict:
        config_class, model_class, tokenizer_class = MODEL_CLASSES[
            args.model_type]
        model = model_class.from_pretrained(args.best_eval_output_dir,
                                            config=config)
        model.to(args.device)
        predict(args, model, tokenizer)

    return global_step, tr_loss / global_step
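# The FGM object used above (fgm.attack() / fgm.restore()) is constructed from
# args.adv_name and args.adv_epsilon but is not defined in this snippet. A minimal
# sketch of the usual embedding-level Fast Gradient Method perturbation -- an
# illustration of the pattern, not necessarily the exact class referenced above:
import torch

class FGM:
    def __init__(self, model, emb_name="word_embeddings", epsilon=1.0):
        self.model = model
        self.emb_name = emb_name  # substring used to select the embedding weights
        self.epsilon = epsilon
        self.backup = {}

    def attack(self):
        # Perturb embedding weights along the (normalized) gradient direction.
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(self.epsilon * param.grad / norm)

    def restore(self):
        # Undo the perturbation before the real optimizer step.
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
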
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset
    )  # don't use DistributedSampler due to OOM. Will manually split dataset later.
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    def is_train_param(name, verbose=False):
        if name.endswith(".embeddings.word_embeddings.weight"):
            if verbose:
                logger.info(f'freezing {name}')
            return False
        if name.startswith("cross_encoder"):
            return False

        return True

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay) and is_train_param(n)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay) and is_train_param(n)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
    logger.info('Number of trainable params: {:,}'.format(
        sum(p.numel() for n, p in model.named_parameters()
            if is_train_param(n, verbose=False))))

    # Check if saved optimizer or scheduler states exist
    if args.load_dir:
        if os.path.isfile(os.path.join(
                args.load_dir, "optimizer.pt")) and os.path.isfile(
                    os.path.join(args.load_dir, "scheduler.pt")):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(args.load_dir, "optimizer.pt"),
                           map_location=torch.device('cpu')))
            scheduler.load_state_dict(
                torch.load(os.path.join(args.load_dir, "scheduler.pt"),
                           map_location=torch.device('cpu')))
            logger.info(f'optimizer and scheduler loaded from {args.load_dir}')

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.load_dir:
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.load_dir.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    # Added here for reproducibility
    set_seed(args)

    for ep_idx, _ in enumerate(train_iterator):
        logger.info(f"\n[Epoch {ep_idx+1}]")
        if args.pbn_size > 0 and ep_idx + 1 > args.pbn_tolerance:
            if hasattr(model, 'module'):
                model.module.init_pre_batch(args.pbn_size)
            else:
                model.init_pre_batch(args.pbn_size)
            logger.info(
                f"Initialize pre-batch of size {args.pbn_size} for Epoch {ep_idx+1}"
            )

        # Skip batch (use a separate name so the epoch-level train_iterator is not shadowed)
        batch_iterator = iter(train_dataloader)
        initial_step = steps_trained_in_current_epoch
        if steps_trained_in_current_epoch > 0:
            train_dataloader.dataset.skip = True  # used for LazyDataloader
            while steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                next(batch_iterator)
            train_dataloader.dataset.skip = False
        assert steps_trained_in_current_epoch == 0

        epoch_iterator = tqdm(batch_iterator,
                              initial=initial_step,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        # logger.setLevel(logging.DEBUG)
        start_time = time()
        for step, batch in enumerate(epoch_iterator):
            logger.debug(
                f'1) {time()-start_time:.3f}s: on-the-fly pre-processing')
            start_time = time()

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
                "input_ids_": batch[8],
                "attention_mask_": batch[9],
                "token_type_ids_": batch[10],
                # "neg_input_ids": batch[12], # should be commented out if none
                # "neg_attention_mask": batch[13],
                # "neg_token_type_ids": batch[14],
            }

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]
            epoch_iterator.set_description(
                f"Loss={loss.item():.3f}, lr={scheduler.get_lr()[0]:.6f}")

            logger.debug(f'2) {time()-start_time:.3f}s: forward')
            start_time = time()

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            logger.debug(f'3) {time()-start_time:.3f}s: backward')
            start_time = time()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                logger.debug(f'4) {time()-start_time:.3f}s: optimize')
                start_time = time()

                # Log metrics
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        # Validation acc
                        logger.setLevel(logging.WARNING)
                        results, _ = evaluate(args,
                                              model,
                                              tokenizer,
                                              prefix=global_step)
                        wandb.log(
                            {
                                "Eval EM": results['exact'],
                                "Eval F1": results['f1']
                            },
                            step=global_step,
                        )
                        logger.setLevel(logging.INFO)

                    wandb.log(
                        {
                            "lr": scheduler.get_lr()[0],
                            "loss":
                            (tr_loss - logging_loss) / args.logging_steps
                        },
                        step=global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)

                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(
                        model, "module") else model

                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    return global_step, tr_loss / global_step
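# Resuming above turns a "checkpoint-<global_step>" directory name back into
# (epochs already finished, optimizer updates already taken in the current epoch).
# A small sketch of that arithmetic; the helper name is hypothetical:
def resume_position(checkpoint_dir, batches_per_epoch, accumulation_steps):
    global_step = int(checkpoint_dir.rstrip("/").split("-")[-1])
    updates_per_epoch = batches_per_epoch // accumulation_steps
    epochs_trained = global_step // updates_per_epoch
    steps_in_current_epoch = global_step % updates_per_epoch
    return global_step, epochs_trained, steps_in_current_epoch

# e.g. resume_position("out/checkpoint-2500", batches_per_epoch=1000, accumulation_steps=2)
# returns (2500, 5, 0): five full epochs done, nothing to skip in the sixth.
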
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    # Added here for reproducibility
    set_seed(args)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if args.model_type in [
                    "xlm", "roberta", "distilbert", "camembert"
            ]:
                del inputs["token_type_ids"]

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                if hasattr(model, "config") and hasattr(
                        model.config, "lang2id"):
                    inputs.update({
                        "langs":
                        (torch.ones(batch[0].shape, dtype=torch.int64) *
                         args.lang_id).to(args.device)
                    })

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
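# Every train() above derives the scheduler length the same way: one optimizer
# update per gradient_accumulation_steps batches. A small sketch of that arithmetic
# (the function name is illustrative only):
def plan_schedule(batches_per_epoch, num_epochs, accumulation_steps, max_steps=0):
    updates_per_epoch = batches_per_epoch // accumulation_steps
    if max_steps > 0:
        return max_steps, max_steps // updates_per_epoch + 1
    return updates_per_epoch * num_epochs, num_epochs

# e.g. plan_schedule(batches_per_epoch=800, num_epochs=3, accumulation_steps=4)
# returns (600, 3): 200 optimizer updates per epoch, 600 warmup/decay steps in total.
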
class Trainer(object):
    """
    a trainer modified from ```https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py```
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def _post_step(self, args, outputs):
        pass

    def _forward(self, args, inputs, labels, masker, model, backprop=True):
        outputs = model(inputs,
                        masked_lm_labels=labels) if args.mlm else model(
                            inputs, labels=labels)
        loss = outputs[
            0]  # model outputs are always tuple in transformers (see doc)
        if backprop:
            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self._post_step(args, outputs)
            return loss
        else:
            return loss

    def _train_writer(self, logging_steps):
        self.tb_writer.add_scalar('lr',
                                  self.scheduler.get_lr()[0], self.global_step)
        self.tb_writer.add_scalar('loss', (self.tr_loss - self.logging_loss) /
                                  logging_steps, self.global_step)
        self.logging_loss = self.tr_loss

    def _to(self, args, tensor):
        if isinstance(tensor, torch.Tensor):
            output = tensor.to(args.device)
        elif isinstance(tensor[0], torch.Tensor):
            output = [_t.to(args.device) for _t in tensor]
        else:
            output = [(_t[0].to(args.device), _t[1].to(args.device))
                      for _t in tensor]
        return output

    def _post_training(self):
        pass

    def _train_batch(self, args, step, inputs, labels, masker, eval_dataset,
                     eval_masker, model):
        inputs = self._to(args, inputs)
        labels = self._to(args, labels)

        model.train()
        loss = self._forward(args,
                             inputs,
                             labels,
                             masker,
                             model,
                             backprop=True)

        self.tr_loss += loss.item()
        if (step + 1) % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer), args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            self.optimizer.step()
            self.scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            self._post_training()
            self.global_step += 1

            if args.local_rank in [
                    -1, 0
            ] and args.logging_steps > 0 and self.global_step % args.logging_steps == 0:
                # Log metrics
                if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                    results = self.evaluate(args, eval_dataset, eval_masker,
                                            model)
                    for key, value in results.items():
                        self.tb_writer.add_scalar('eval_{}'.format(key), value,
                                                  self.global_step)
                self._train_writer(args.logging_steps)

            if args.local_rank in [
                    -1, 0
            ] and args.save_steps > 0 and self.global_step % args.save_steps == 0:
                checkpoint_prefix = 'checkpoint'
                # Save model checkpoint
                output_dir = os.path.join(
                    args.output_dir, '{}-{}'.format(checkpoint_prefix,
                                                    self.global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = model.module if hasattr(
                    model, 'module'
                ) else model  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)

                self.tokenizer.save_pretrained(output_dir)
                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                logger.info("Saving model checkpoint to %s", output_dir)

                torch.save(self.optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(self.scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

                self._rotate_checkpoints(args, checkpoint_prefix)

    def _train_epoch(self, args, epoch_id, epoch_iterator, train_masker,
                     eval_dataset, eval_masker, model):
        for step, batch in enumerate(epoch_iterator):
            # consuming this example also consumes randomness.
            inputs, labels = train_masker.mask_tokens(
                batch, args.mlm_probability) if args.mlm else (batch, batch)
            if epoch_id < self.epochs_trained:
                logger.info("Continue training: skip epoch %d", epoch_id)
                break
            if self.steps_trained_in_current_epoch > 0:
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    self.steps_trained_in_current_epoch -= 1
                continue

            if step % 5000 == 0:
                input_ids = inputs[-1].tolist()
                logger.info("domain %s", input_ids[0])
                logger.info(
                    "one exmaple is like %s", " ".join(
                        train_masker.tokenizer.convert_ids_to_tokens(
                            input_ids[1:])))
                logger.info("one label is like %s", str(labels[-1].tolist()))

            self._train_batch(args, step, inputs, labels, train_masker,
                              eval_dataset, eval_masker, model)
            if args.max_steps > 0 and self.global_step > args.max_steps:
                epoch_iterator.close()
                break

    def _build_train_data_loader(self, args, train_dataset, model):
        train_sampler = RandomSampler(
            train_dataset) if args.local_rank == -1 else DistributedSampler(
                train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        return train_sampler, train_dataloader

    def _init_logging(self):
        self.tr_loss, self.logging_loss = 0.0, 0.0

    def train(self, args, train_dataset, train_masker, eval_dataset,
              eval_masker, model):
        """ Train the model """
        if args.local_rank in [-1, 0]:
            self.tb_writer = SummaryWriter(args.output_dir)

        args.train_batch_size = args.per_gpu_train_batch_size * max(
            1, args.n_gpu)
        train_sampler, train_dataloader = self._build_train_data_loader(
            args, train_dataset, model)

        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps // (
                len(train_dataloader) // args.gradient_accumulation_steps) + 1
        else:
            t_total = len(
                train_dataloader
            ) // args.gradient_accumulation_steps * args.num_train_epochs
        self.t_total = t_total

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]

        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=args.learning_rate,
                               correct_bias=False)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)

        if (args.model_name_or_path and os.path.isfile(
                os.path.join(args.model_name_or_path, "optimizer.pt"))
                and os.path.isfile(
                    os.path.join(args.model_name_or_path, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(
                    os.path.join(args.model_name_or_path, "optimizer.pt")))
            self.scheduler.load_state_dict(
                torch.load(
                    os.path.join(args.model_name_or_path, "scheduler.pt")))

        if args.fp16:
            model, self.optimizer = amp.initialize(
                model, self.optimizer, opt_level=args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_sampler))
        logger.info("  Num Epochs = %d", args.num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d",
                    args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            args.train_batch_size * args.gradient_accumulation_steps *
            (torch.distributed.get_world_size()
             if args.local_rank != -1 else 1))
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epochs_trained = 0
        self.steps_trained_in_current_epoch = 0
        if args.model_name_or_path and os.path.exists(args.model_name_or_path):
            try:
                # set global_step to global_step of last saved checkpoint from model path
                checkpoint_suffix = args.model_name_or_path.split(
                    "-")[-1].split("/")[0]
                self.global_step = int(checkpoint_suffix)
                self.epochs_trained = self.global_step // (
                    len(train_dataloader) // args.gradient_accumulation_steps)
                self.steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) // args.gradient_accumulation_steps)

                logger.info("  Continuing training from global step %d",
                            self.global_step)
                logger.info("  Will skip the first %d steps in epoch %d",
                            self.steps_trained_in_current_epoch,
                            self.epochs_trained)
            except ValueError:
                logger.info("  Starting fine-tuning.")

        self._init_logging()
        model.zero_grad()
        train_iterator = trange(int(args.num_train_epochs),
                                desc="Epoch",
                                disable=args.local_rank not in [-1, 0])
        set_seed(
            args
        )  # Added here for reproducibility (even between python 2 and 3)
        for epoch_id, _ in enumerate(train_iterator):
            # seek to the current epoch.
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=args.local_rank not in [-1, 0])
            self._train_epoch(args, epoch_id, epoch_iterator, train_masker,
                              eval_dataset, eval_masker, model)

            if args.max_steps > 0 and self.global_step > args.max_steps:
                train_iterator.close()
                break

        if args.local_rank in [-1, 0]:
            self.tb_writer.close()

        return self.global_step, self.tr_loss  #/ self.global_step

    def _eval(self, args, eval_dataloader, eval_masker, model):
        eval_loss = 0.0
        nb_eval_steps = 0

        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            inputs, labels = eval_masker.mask_tokens(
                batch, 0.15) if args.mlm else (batch, batch)
            inputs = self._to(args, inputs)
            labels = self._to(args, labels)

            with torch.no_grad():
                lm_loss = self._forward(args,
                                        inputs,
                                        labels,
                                        eval_masker,
                                        model,
                                        backprop=False)
                eval_loss += lm_loss.mean().item()
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps

        perplexity = torch.exp(torch.tensor(eval_loss))

        return {
            "perplexity": perplexity,
            "loss": eval_loss,
        }

    def evaluate(self, args, eval_dataset, eval_masker, model, prefix=""):
        # Loop to handle MNLI double evaluation (matched, mis-matched)
        eval_output_dir = args.output_dir

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(
            1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(
            eval_dataset) if args.local_rank == -1 else DistributedSampler(
                eval_dataset)
        eval_dataloader = DataLoader(eval_dataset,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size,
                                     drop_last=True)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        model.eval()

        result = self._eval(args, eval_dataloader, eval_masker, model)

        output_eval_file = os.path.join(eval_output_dir, prefix,
                                        "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

        return result

    def _rotate_checkpoints(self, args, checkpoint_prefix, use_mtime=False):
        if not args.save_total_limit:
            return
        if args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        glob_checkpoints = glob.glob(
            os.path.join(args.output_dir, '{}-*'.format(checkpoint_prefix)))
        if len(glob_checkpoints) <= args.save_total_limit:
            return

        ordering_and_checkpoint_path = []
        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append(
                    (os.path.getmtime(path), path))
            else:
                regex_match = re.match(
                    '.*{}-([0-9]+)'.format(checkpoint_prefix), path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append(
                        (int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [
            checkpoint[1] for checkpoint in checkpoints_sorted
        ]
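        # Keep the most recent checkpoints (by step number, or by mtime when use_mtime is set) and delete the oldest.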
        number_of_checkpoints_to_delete = max(
            0,
            len(checkpoints_sorted) - args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[
            :number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(
                "Deleting older checkpoint [{}] due to args.save_total_limit".
                format(checkpoint))
            shutil.rmtree(checkpoint)


def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
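        # Pad each batch to its longest sequence; fall back to zero padding when the tokenizer has no pad token (e.g. GPT-2).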
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
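    # Either way, t_total is the total number of optimizer updates the LR scheduler will be stepped for.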

    model = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
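    # Exclude biases and LayerNorm weights from weight decay, as is conventional for transformer fine-tuning.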
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint (directories are named "checkpoint-<global_step>")
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])

        if args.local_rank != -1:
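            # Tell the DistributedSampler which epoch this is so it reshuffles the data differently each epoch.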
            train_sampler.set_epoch(epoch)

        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer,
                                         args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs,
                            masked_lm_labels=labels) if args.mlm else model(
                                inputs, labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
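                # Scale the loss so gradients accumulated over several micro-batches average to one effective batch.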
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
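                    # Note: newer PyTorch versions prefer scheduler.get_last_lr()[0] over calling get_lr() outside of step().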
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step


def main(args):

    START_TAG = "<START_TAG>"
    END_TAG = "<END_TAG>"
    O = "O"
    BLOC = "B-LOC"
    ILOC = "I-LOC"
    BORG = "B-ORG"
    IORG = "I-ORG"
    BPER = "B-PER"
    IPER = "I-PER"
    PAD = "<PAD>"
    UNK = "<UNK>"
    token2idx = {PAD: 0, UNK: 1}
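    # BIO tag set for LOC/ORG/PER entities, plus the special start/end tags required by the CRF.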
    tag2idx = {
        START_TAG: 0,
        END_TAG: 1,
        O: 2,
        BLOC: 3,
        ILOC: 4,
        BORG: 5,
        IORG: 6,
        BPER: 7,
        IPER: 8
    }
    # tb_writer = SummaryWriter(args.model_name)

    id2tag = {v: k for k, v in tag2idx.items()}
    if not os.path.exists(args.model_name):
        os.makedirs(args.model_name)
    save_parser(args, os.path.join(args.model_name, "parser_config.json"))

    # set cuda device and seed
    use_cuda = torch.cuda.is_available() and args.is_cuda
    cuda_device = ":{}".format(args.cuda_device)
    device = torch.device('cuda' + cuda_device if use_cuda else 'cpu')
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    print("Loading Datasets")
    train_set = crfDataset(
        args.train_data_path)  # os.path.join(args.data_path, "train_data"))
    test_set = crfDataset(
        args.test_data_path)  # os.path.join(args.data_path, "test_data"))
    train_loader = DataLoader(train_set,
                              args.batch_size,
                              num_workers=0,
                              pin_memory=True)
    test_loader = DataLoader(test_set,
                             args.batch_size,
                             num_workers=0,
                             pin_memory=True)

    print("Building models")
    print("model_add: {}".format(args.bert_model_path))
    model = BertCRF(args.bert_model_path,
                    len(tag2idx),
                    tag2idx,
                    START_TAG,
                    END_TAG,
                    with_lstm=args.with_lstm,
                    lstm_layers=args.rnn_layer,
                    bidirection=args.lstm_bidirectional,
                    lstm_hid_size=args.lstm_hid_size,
                    dropout=args.dropout)
    if args.load_chkpoint:
        print("==Loading Model from checkpoint: {}".format(
            args.chkpoint_model))
        model.load_state_dict(torch.load(args.chkpoint_model))
    print(model)
    model.to(device)

    crf_params = list(map(id, model.crf.parameters()))
    base_params = filter(lambda p: id(p) not in crf_params, model.parameters())

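    # Train the CRF parameters with their own learning rate (args.crf_lr); all other parameters use args.lr.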
    optimizer = AdamW([{
        "params": base_params
    }, {
        "params": model.crf.parameters(),
        "lr": args.crf_lr
    }],
                      lr=args.lr)
    if args.load_chkpoint:
        print("==Loading optimizer from checkpoint: {}".format(
            args.chkpoint_optim))
        optimizer.load_state_dict(torch.load(args.chkpoint_optim))

    tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer_path)

    print("Training", datetime.datetime.now())
    print("Cuda Usage: {}, device: {}".format(use_cuda, device))

    step = 0

    best_f1 = 0
    patience = 0
    early_stop = False

    for eidx in range(1, args.epochs + 1):
        if eidx == 2:
            model.debug = True
        if early_stop:
            print("Early stop. epoch {} step {} best f1 {}".format(
                eidx, step, best_f1))
            break
            # sys.exit(0)
        print("Start epoch {}".format(eidx).center(60, "="))

        for bidx, batch in enumerate(train_loader):
            model.train()
            x_batch, y_batch = batch[0], batch[1]
            input_ids, segment_ids, mask = prepare_xbatch_for_bert(
                x_batch,
                tokenizer,
                max_len=args.max_len,
                batch_first=True,
                device=device)

            y_batch = _prepare_data(y_batch,
                                    tag2idx,
                                    END_TAG,
                                    device,
                                    max_len=args.max_len,
                                    batch_first=True)

            optimizer.zero_grad()
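            # CRF training objective: negative log-likelihood of the gold tag sequence given the encoder emissions.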
            loss = model.neg_log_likelihood(input_ids, segment_ids, mask,
                                            y_batch)
            batch_size = input_ids.size(0)  # batch_first=True, so dim 0 is the batch size
            loss /= batch_size
            # print(loss)
            loss.backward()
            optimizer.step()
            # break
            step += 1
            if step % args.log_interval == 0:
                print("epoch {} step {} batch {} loss {}".format(
                    eidx, step, bidx, loss))
            if step % args.save_interval == 0:
                torch.save(model.state_dict(),
                           os.path.join(args.model_name, "newest_model"))
                torch.save(optimizer.state_dict(),
                           os.path.join(args.model_name, "newest_optimizer"))
            if step % args.valid_interval == 0:
                f1, precision, recall = bert_evaluate(model,
                                                      test_loader,
                                                      tokenizer,
                                                      START_TAG,
                                                      END_TAG,
                                                      id2tag,
                                                      device=device,
                                                      mtype="crf")
                # tb_writer.add_scalar("eval/f1", f1, step)
                # tb_writer.add_scalar("eval/precision", precision, step)
                # tb_writer.add_scalar("eval/recall", recall, step)
                print("[valid] epoch {} step {} f1 {} precision {} recall {}".
                      format(eidx, step, f1, precision, recall))
                if f1 > best_f1:
                    patience = 0
                    best_f1 = f1
                    torch.save(model.state_dict(),
                               os.path.join(args.model_name, "best_model"))
                    torch.save(optimizer.state_dict(),
                               os.path.join(args.model_name, "best_optimizer"))
                else:
                    patience += 1
                    if patience == args.patience:
                        early_stop = True