Example #1
0
def train(train_data,dev_data,model_file):
    """Train an EvidenceClassifier, evaluating on dev data every epoch and
    saving the best model (by dev F1) to `model_file`.

    Args:
        train_data / dev_data: dicts with keys 'input_sent' and 'labels'.
        model_file: path the best model's state_dict is written to.

    Relies on module-level config (BATCH_SIZE, EPOCH, RANDOM_SEED,
    model_name, early_stop_count) and helpers (get_batch_data, load_ckp,
    save_ckp, predict, set_random_seeds, logging).
    """
    input_sent = train_data['input_sent']
    target = train_data['labels']
    # NOTE(review): dev_input_ids/dev_target are unused -- predict() below
    # consumes dev_data directly.
    dev_input_ids = dev_data['input_sent']
    dev_target = dev_data['labels']
    train_size = len(input_sent)
    batch_size = BATCH_SIZE
    # BUGFIX: was int(math.ceil(train_size)/batch_size), which floors and
    # silently drops the final partial batch every epoch.
    batch_count = int(math.ceil(train_size / batch_size))
    model = EvidenceClassifier()
    weights = torch.tensor([1.0])

    if torch.cuda.is_available():
        model = nn.DataParallel(model, device_ids=[0, 1])
        model.to(f'cuda:{model.device_ids[0]}')
        weights = weights.cuda()
    criterion = nn.BCEWithLogitsLoss(reduction='mean', pos_weight=weights)
    optimizer = AdamW(model.parameters(), lr=1e-05, correct_bias=False)

    logging(optimizer)

    best_dev_acc = -1
    best_epoch_idx = -1
    best_epoch_seed = -1
    start_epoch = 0
    # Resume from the last checkpoint if one exists.
    ckp_path = os.path.join('checkpoint', model_name + '_checkpoint.pt')
    best_epoch_idx, best_epoch_seed, best_dev_acc, start_epoch, model, optimizer = load_ckp(ckp_path, model, optimizer)

    for epoch_idx in range(start_epoch, EPOCH):
        model.train()
        model.zero_grad()
        logging('Epoch:', epoch_idx + 1)
        cur_seed = RANDOM_SEED + epoch_idx + 1
        set_random_seeds(cur_seed)
        # NOTE(review): `cur_shuffled_train_data` is not defined in this
        # function; it must be a module-level list or this raises NameError.
        # Shuffling it also does not reorder input_sent/target used below --
        # verify this is intentional.
        random.shuffle(cur_shuffled_train_data)
        start_time = datetime.datetime.now()
        train_loss_val = 0
        is_best = False

        for batch_idx in tqdm(range(0, batch_count)):
            batch_start = batch_idx * batch_size
            batch_end = min(train_size, batch_start + batch_size)
            data = get_batch_data(input_sent[batch_start:batch_end], target[batch_start:batch_end])
            batch_sent_ids = torch.tensor(data['sent_ids']).to(torch.int64)
            batch_sent_attention = torch.tensor(data['sent_attention']).to(torch.int64)
            batch_sent_mask = torch.tensor(data['sent_mask']).float()
            batch_evi_targets = torch.tensor(data['targets']).to(torch.int64)

            if torch.cuda.is_available():
                batch_sent_ids = batch_sent_ids.cuda()
                batch_sent_attention = batch_sent_attention.cuda()
                batch_sent_mask = batch_sent_mask.cuda()
                batch_evi_targets = batch_evi_targets.cuda()
            # (autograd.Variable wrapping removed -- it has been a no-op
            # since PyTorch 0.4.)

            outputs = model(batch_sent_ids, batch_sent_attention, batch_sent_mask, is_training=True)
            # Flatten (batch, seq) -> (batch*seq) for element-wise BCE.
            y_pred = outputs.view((outputs.shape[0] * outputs.shape[1]))
            labels = batch_evi_targets.view((batch_evi_targets.shape[0] * batch_evi_targets.shape[1])).float()
            # Label 2 marks padded/ignored positions: force the logit to
            # -inf (sigmoid -> 0) and the target to 0 so they add no loss.
            y_pred[labels == 2] = -np.Inf
            labels[labels == 2] = 0

            loss = criterion(y_pred, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
            optimizer.step()
            # BUGFIX: gradients were only cleared once per epoch, so every
            # batch's gradient accumulated over the whole epoch.
            optimizer.zero_grad()
            train_loss_val += loss.item()
        train_loss_val /= batch_count
        end_time = datetime.datetime.now()
        logging('\nTraining_loss: ', train_loss_val)
        logging('Time: ', end_time - start_time)
        logging('\nDev_Results\n')

        acc, F1 = predict(dev_data, model)
        logging('Dev acc:', round(acc, 3))
        logging('Dev F1:', round(F1, 3))
        # Despite the name, best_dev_acc tracks the best dev F1.
        if F1 > best_dev_acc:
            best_epoch_idx = epoch_idx + 1
            best_epoch_seed = cur_seed
            logging("model saved ...")
            best_dev_acc = F1
            torch.save(model.state_dict(), model_file)
            is_best = True
        checkpoint = {
            'best_epoch_idx': best_epoch_idx,
            'best_epoch_seed': best_epoch_seed,
            'best_dev_acc': best_dev_acc,
            'epoch': epoch_idx + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        save_ckp(checkpoint, is_best, 'checkpoint', 'checkpoint') # uncomment this when running for long time
        logging('\ncheckpoint updated\n\n')
        # Early stopping: no dev-F1 improvement for `early_stop_count` epochs.
        if epoch_idx + 1 - best_epoch_idx >= early_stop_count:
            break

    logging("*" * 50)
    logging('Best epoch', best_epoch_idx)
    logging('Best Epoch Seed:', best_epoch_seed)
Example #2
0
def main(parser):
    """Train a KoBERT-CRF named-entity tagger end to end.

    Parses CLI args, builds the vocabulary/tokenizer and training
    DataLoader, then runs the optimization loop with gradient
    accumulation, TensorBoard logging and periodic checkpointing.

    Returns:
        (global_step, average training loss, best_steps)
    """
    # Config
    args = parser.parse_args()
    data_dir = Path(args.data_dir)
    model_dir = Path(args.model_dir)

    # data_config = Config(json_path=data_dir / 'config.json')
    model_config = Config(json_path=model_dir / 'config.json')

    # Vocab & Tokenizer
    tok_path = get_tokenizer() # ./tokenizer_78b3253a26.model
    ptr_tokenizer = SentencepieceTokenizer(tok_path)

    _, vocab_of_gluonnlp = get_pytorch_kobert_model()
    token_to_idx = vocab_of_gluonnlp.token_to_idx

    model_config.vocab_size = len(token_to_idx)
    vocab = Vocabulary(token_to_idx=token_to_idx)

    print("len(token_to_idx): ", len(token_to_idx))
    # Dump the token->index map alongside the model for later inspection.
    with open(model_dir / "token2idx_vocab.json", 'w', encoding='utf-8') as f:
        json.dump(token_to_idx, f, ensure_ascii=False, indent=4)

    # save vocab & tokenizer
    with open(model_dir / "vocab.pkl", 'wb') as f:
        pickle.dump(vocab, f)

    # load vocab & tokenizer
    # NOTE(review): immediately reloading the vocab just written is a
    # pickle round-trip check; functionally a no-op.
    with open(model_dir / "vocab.pkl", 'rb') as f:
        vocab = pickle.load(f)

    tokenizer = Tokenizer(vocab=vocab, split_fn=ptr_tokenizer, pad_fn=keras_pad_fn, maxlen=model_config.maxlen)
    ner_formatter = NamedEntityRecognitionFormatter(vocab=vocab, tokenizer=tokenizer, maxlen=model_config.maxlen, model_dir=model_dir)

    # Train & Val Datasets
    cwd = Path.cwd()
    data_in = cwd / "data_in"
    train_data_dir = data_in / "NER-master" / "말뭉치 - 형태소_개체명"
    tr_clf_ds = NamedEntityRecognitionDataset(train_data_dir=train_data_dir, model_dir=model_dir)
    tr_clf_ds.set_transform_fn(transform_source_fn=ner_formatter.transform_source_fn, transform_target_fn=ner_formatter.transform_target_fn)
    tr_clf_dl = DataLoader(tr_clf_ds, batch_size=model_config.batch_size, shuffle=True, num_workers=4, drop_last=False)

    # Model
    model = KobertCRF(config=model_config, num_classes=len(tr_clf_ds.ner_to_index))
    model.train()

    # optim
    train_examples_len = len(tr_clf_ds)
    param_optimizer = list(model.named_parameters())
    # Bias/LayerNorm parameters are conventionally exempt from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    # num_train_optimization_steps = int(train_examples_len / model_config.batch_size / model_config.gradient_accumulation_steps) * model_config.epochs
    t_total = len(tr_clf_dl) // model_config.gradient_accumulation_steps * model_config.epochs
    optimizer = AdamW(optimizer_grouped_parameters, lr=model_config.learning_rate, eps=model_config.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=model_config.warmup_steps, t_total=t_total)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    n_gpu = torch.cuda.device_count()
    # if n_gpu > 1:
    #     model = torch.nn.DataParallel(model)
    model.to(device)

    # save
    tb_writer = SummaryWriter('{}/runs'.format(model_dir))
    checkpoint_manager = CheckpointManager(model_dir)
    summary_manager = SummaryManager(model_dir)
    # NOTE(review): best_val_loss is initialized but never updated below.
    best_val_loss = 1e+10
    best_train_acc = 0

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(tr_clf_ds))
    logger.info("  Num Epochs = %d", model_config.epochs)
    logger.info("  Instantaneous batch size per GPU = %d", model_config.batch_size)
    # logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
    #                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", model_config.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    best_dev_acc, best_dev_loss = 0.0, 99999999999.0
    best_steps = 0
    model.zero_grad()
    set_seed()  # Added here for reproducibility (even between python 2 and 3)

    # Train
    train_iterator = trange(int(model_config.epochs), desc="Epoch")
    for _epoch, _ in enumerate(train_iterator):
        epoch_iterator = tqdm(tr_clf_dl, desc="Iteration") # , disable=args.local_rank not in [-1, 0]
        epoch = _epoch
        for step, batch in enumerate(epoch_iterator):
            model.train()
            x_input, token_type_ids, y_real = map(lambda elm: elm.to(device), batch)
            log_likelihood, sequence_of_tags = model(x_input, token_type_ids, y_real)

            # loss: negative log-likelihood
            loss = -1 * log_likelihood

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if model_config.gradient_accumulation_steps > 1:
                loss = loss / model_config.gradient_accumulation_steps

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), model_config.max_grad_norm)
            tr_loss += loss.item()

            # Optimizer step only every `gradient_accumulation_steps` batches.
            if (step + 1) % model_config.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                with torch.no_grad():
                    # NOTE(review): presumably the CRF decode returns a
                    # (nested) list of tag ids, converted here for the
                    # element-wise comparison below -- confirm the resulting
                    # shape matches y_real.
                    sequence_of_tags = torch.tensor(sequence_of_tags)
                    print("sequence_of_tags: ", sequence_of_tags)
                    print("y_real: ", y_real)
                    print("loss: ", loss)
                    print("(sequence_of_tags == y_real): ", (sequence_of_tags == y_real))

                    # Token accuracy over non-padding positions only.
                    mb_acc = (sequence_of_tags == y_real).float()[y_real != vocab.PAD_ID].mean()

                tr_acc = mb_acc.item()
                tr_loss_avg = tr_loss / global_step
                tr_summary = {'loss': tr_loss_avg, 'acc': tr_acc}

                # if step % 50 == 0:
                print('epoch : {}, global_step : {}, tr_loss: {:.3f}, tr_acc: {:.2%}'.format(epoch + 1, global_step,
                                                                                             tr_summary['loss'],
                                                                                             tr_summary['acc']))

                if model_config.logging_steps > 0 and global_step % model_config.logging_steps == 0:
                    # Log metrics
                    if model_config.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        pass
                    # NOTE(review): scheduler.get_lr() is deprecated in newer
                    # PyTorch in favour of get_last_lr().
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) / model_config.logging_steps, global_step)
                    logger.info("Average loss: %s at global step: %s",
                                str((tr_loss - logging_loss) / model_config.logging_steps), str(global_step))
                    logging_loss = tr_loss

                if model_config.save_steps > 0 and global_step % model_config.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(model_config.output_dir, 'epoch-{}'.format(epoch + 1))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    logger.info("Saving model checkpoint to %s", output_dir)

                    state = {'global_step': global_step + 1,
                             'model_state_dict': model.state_dict(),
                             'opt_state_dict': optimizer.state_dict()}
                    summary = {'train': tr_summary}
                    summary_manager.update(summary)
                    summary_manager.save('summary.json')

                    is_best = tr_acc >= best_train_acc  # keyed on train acc (should really compare val acc)
                    # Save
                    if is_best:
                        best_train_acc = tr_acc
                        checkpoint_manager.save_checkpoint(state,
                                                           'best-epoch-{}-step-{}-acc-{:.3f}.bin'.format(epoch + 1,
                                                                                                         global_step,
                                                                                                         tr_acc))
                    else:
                        torch.save(state, os.path.join(output_dir,
                                                       'model-epoch-{}-step-{}-acc-{:.3f}.bin'.format(epoch + 1,
                                                                                                      global_step,
                                                                                                      tr_acc)))

    tb_writer.close()
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss / global_step)

    return global_step, tr_loss / global_step, best_steps
    def train(self):
        """Run the KoBERT-BiLSTM-CRF training loop.

        Mirrors `main` above but reads configuration, datasets and logger
        from `self`, and evaluates on the validation loader at each save
        step (producing a classification report / confusion matrix for
        new best checkpoints).

        NOTE(review): this `def` sits at `main`'s body indentation, after
        `main` has already returned -- as written it is unreachable dead
        code, and its use of `self` suggests it belongs on a trainer
        class. Confirm intended placement.
        """
        # Model
        model = KobertBiLSTMCRF(config=self.model_config,
                                num_classes=len(self.tr_ds.ner_to_index))
        model.train()

        # optim
        train_examples_len = len(self.tr_ds)
        param_optimizer = list(model.named_parameters())
        # Bias/LayerNorm parameters are conventionally exempt from weight decay.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        # num_train_optimization_steps = int(train_examples_len / model_config.batch_size / model_config.gradient_accumulation_steps) * model_config.epochs
        t_total = len(
            self.tr_dl
        ) // self.model_config.gradient_accumulation_steps * self.model_config.epochs
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.model_config.learning_rate,
                          eps=self.model_config.adam_epsilon)
        scheduler = WarmupLinearSchedule(
            optimizer,
            warmup_steps=self.model_config.warmup_steps,
            t_total=t_total)

        device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')

        n_gpu = torch.cuda.device_count()
        # if n_gpu > 1:
        #     model = torch.nn.DataParallel(model)
        model.to(device)

        # save
        tb_writer = SummaryWriter('{}/runs'.format(self.model_dir))
        checkpoint_manager = CheckpointManager(self.model_dir)
        summary_manager = SummaryManager(self.model_dir)
        # NOTE(review): best_val_loss is initialized but never updated below.
        best_val_loss = 1e+10
        best_train_acc = 0

        # Train!
        self.logger.info("***** Running training *****")
        self.logger.info("  Num examples = %d", len(self.tr_ds))
        self.logger.info("  Num Epochs = %d", self.model_config.epochs)
        self.logger.info("  Instantaneous batch size per GPU = %d",
                         self.model_config.batch_size)
        # logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
        #                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
        self.logger.info("  Gradient Accumulation steps = %d",
                         self.model_config.gradient_accumulation_steps)
        self.logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        best_dev_acc, best_dev_loss = 0.0, 99999999999.0
        best_steps = 0
        model.zero_grad()
        self.set_seed(
        )  # Added here for reproducibility (even between python 2 and 3)

        # Train
        train_iterator = trange(int(self.model_config.epochs), desc="Epoch")
        for _epoch, _ in enumerate(train_iterator):
            epoch_iterator = tqdm(
                self.tr_dl,
                desc="Iteration")  # , disable=args.local_rank not in [-1, 0]
            epoch = _epoch
            for step, batch in enumerate(epoch_iterator):
                model.train()
                x_input, token_type_ids, y_real = map(
                    lambda elm: elm.to(device), batch)
                log_likelihood, sequence_of_tags = model(
                    x_input, token_type_ids, y_real)

                # loss: negative log-likelihood
                loss = -1 * log_likelihood

                if n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if self.model_config.gradient_accumulation_steps > 1:
                    loss = loss / self.model_config.gradient_accumulation_steps

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               self.model_config.max_grad_norm)
                tr_loss += loss.item()

                # Optimizer step only every `gradient_accumulation_steps` batches.
                if (step + 1
                    ) % self.model_config.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    with torch.no_grad():
                        # NOTE(review): presumably the CRF decode returns a
                        # (nested) list of tag ids, converted here for the
                        # comparison below -- confirm shape matches y_real.
                        sequence_of_tags = torch.tensor(sequence_of_tags)
                        print("sequence_of_tags: ", sequence_of_tags)
                        print("y_real: ", y_real)
                        print("loss: ", loss)
                        print("(sequence_of_tags == y_real): ",
                              (sequence_of_tags == y_real))
                        # Drop the leading axis before comparing to gold tags.
                        _tags = torch.squeeze(sequence_of_tags, dim=0)
                        # Token accuracy over non-padding positions only.
                        mb_acc = (_tags == y_real).float()[
                            y_real != self.vocab.PAD_ID].mean()
                        #mb_acc = (sequence_of_tags == y_real).float()[y_real != self.vocab.PAD_ID].mean()

                    tr_acc = mb_acc.item()
                    tr_loss_avg = tr_loss / global_step
                    tr_summary = {'loss': tr_loss_avg, 'acc': tr_acc}

                    # if step % 50 == 0:
                    print(
                        'epoch : {}, global_step : {}, tr_loss: {:.3f}, tr_acc: {:.2%}'
                        .format(epoch + 1, global_step, tr_summary['loss'],
                                tr_summary['acc']))

                    if self.model_config.logging_steps > 0 and global_step % self.model_config.logging_steps == 0:
                        # Log metrics
                        if self.model_config.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                            pass
                        # NOTE(review): scheduler.get_lr() is deprecated in
                        # newer PyTorch in favour of get_last_lr().
                        tb_writer.add_scalar('lr',
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                             self.model_config.logging_steps,
                                             global_step)
                        self.logger.info(
                            "Average loss: %s at global step: %s",
                            str((tr_loss - logging_loss) /
                                self.model_config.logging_steps),
                            str(global_step))
                        logging_loss = tr_loss

                    if self.model_config.save_steps > 0 and global_step % self.model_config.save_steps == 0:

                        eval_summary, list_of_y_real, list_of_pred_tags = self.evaluate(
                            model, self.val_dl)

                        # Save model checkpoint
                        output_dir = os.path.join(self.model_config.output_dir,
                                                  'epoch-{}'.format(epoch + 1))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        self.logger.info("Saving model checkpoint to %s",
                                         output_dir)

                        state = {
                            'global_step': global_step + 1,
                            'model_state_dict': model.state_dict(),
                            'opt_state_dict': optimizer.state_dict()
                        }
                        summary = {'train': tr_summary}
                        summary_manager.update(summary)
                        summary_manager.save('summary.json')

                        is_best = tr_acc >= best_train_acc  # keyed on train acc (should really compare val acc)
                        # Save
                        if is_best:
                            best_train_acc = tr_acc
                            checkpoint_manager.save_checkpoint(
                                state,
                                'best-epoch-{}-step-{}-acc-{:.3f}.bin'.format(
                                    epoch + 1, global_step, tr_acc))

                            # NOTE(review): the message/paths below format
                            # best_dev_acc (never updated here) rather than
                            # the tr_acc used for the .bin name -- verify.
                            print(
                                "Saving model checkpoint as best-epoch-{}-step-{}-acc-{:.3f}.bin"
                                .format(epoch + 1, global_step, best_dev_acc))

                            # print classification report and save confusion matrix
                            cr_save_path = self.model_dir + '/best-epoch-{}-step-{}-acc-{:.3f}-cr.csv'.format(
                                epoch + 1, global_step, best_dev_acc)
                            cm_save_path = self.model_dir + '/best-epoch-{}-step-{}-acc-{:.3f}-cm.png'.format(
                                epoch + 1, global_step, best_dev_acc)

                            self.save_cr_and_cm(list_of_y_real,
                                                list_of_pred_tags,
                                                cr_save_path=cr_save_path,
                                                cm_save_path=cm_save_path)
                        else:
                            torch.save(
                                state,
                                os.path.join(
                                    output_dir,
                                    'model-epoch-{}-step-{}-acc-{:.3f}.bin'.
                                    format(epoch + 1, global_step, tr_acc)))

        tb_writer.close()
        self.logger.info(" global_step = %s, average loss = %s", global_step,
                         tr_loss / global_step)
Example #4
0
class MaskGenerator(object):
    def __init__(self,
                 args,
                 model_base,
                 config,
                 tokenizer,
                 training=True,
                 base=False):
        """Build a mask generator around `model_base`.

        Args:
            args: experiment namespace (masking type, output dir, RL
                hyper-parameters, local_rank, ...).
            model_base: the neural mask-scoring model; may be None for
                heuristic masking types.
            config: model configuration.
                NOTE(review): `config` is not used in this constructor.
            tokenizer: tokenizer used to interpret token ids.
            training: if True, create the optimizer and TensorBoard
                writer for RL training of the generator.
            base: marks this as the "base" generator used for self-play.
        """
        self.args = args
        self.pytorch_version = args.pytorch_version
        self.masking_type = args.masking
        # self.gap_acc_reward = args.gap_acc_reward

        self.tokenizer = tokenizer
        # Construct Stopword list: punctuation plus BERT special tokens.
        stop_words = []
        for c in string.punctuation:
            stop_words.append(c)
        self.stop_words = list(set(stop_words))
        self.stop_words.append("[CLS]")
        self.stop_words.append("[SEP]")
        self.stop_words.append("[UNK]")
        self.log_file = os.path.join(args.output_dir, "training_logs.txt")

        # Entity masking needs a spaCy pipeline for NER.
        if self.masking_type == 'entity':
            self.nlp = spacy.load("en_core_web_sm")
        # Indicates whether this mask generator is the base one.
        self.base = base
        self.training = training

        # Set Mask Generating model inside the mask generator
        self.model = model_base

        # Lock the bert parameters (frozen backbone unless continual mode).
        if self.model is not None and not args.continual:
            for p in self.model.bert.parameters():
                p.requires_grad = False

        if self.model is not None and self.training:
            # Set Optimizer for training neural mask generator.
            # Only non-BERT head parameters are optimized; weights get
            # weight decay, biases do not.
            no_decay = ['bias', 'LayerNorm.weight']
            learnable_weight = []
            learnable_bias = []
            for n, p in self.model.named_parameters():
                if 'bert' not in n and 'weight' in n:
                    learnable_weight.append(n)
                if 'bert' not in n and 'bias' in n:
                    learnable_bias.append(n)
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in self.model.named_parameters()
                    if n in learnable_weight
                ],
                'weight_decay':
                args.weight_decay
            }, {
                'params': [
                    p for n, p in self.model.named_parameters()
                    if n in learnable_bias
                ],
                'weight_decay':
                0.0
            }]
            self.optimizer = AdamW(optimizer_grouped_parameters,
                                   lr=args.mask_learning_rate,
                                   eps=args.adam_epsilon)

        # Reward bookkeeping; only the last `reward_limit` rewards are kept
        # (see initialize_history).
        self.adv_rewards = []
        self.adv_losses = []
        self.rewards = []
        self.reward_limit = 10

        # History for RL training
        self.initialize_history()
        # Accuracy is shared across episodes.
        self.former_accuracy = None
        self.acc_history = []
        self.base_acc_history = []
        self.rnd_acc_history = []
        self.best_diff = -100
        self.best_improve_indicator = 0
        self.improve_indicator = 0
        self.rand_self_indicator = 0
        self.regret = 0

        self.replay_memory = []
        self.state_containers = []

        self.global_step = 0

        self.episode = 0
        # Tensorboard writer for mask generator
        if self.training:
            self.tb_writer = SummaryWriter(
                os.path.join(args.output_dir, "logs"))

        # Initialize mask checkpoint directory.
        # Layout: <output>/mask[_base]/{last,checkpoint,<episode>}; only
        # rank -1 (single process) or rank 0 creates directories.

        if self.base and args.self_play == "learning":
            self.mask_base_dir = os.path.join(args.output_dir, "mask_base")
        else:
            self.mask_base_dir = os.path.join(args.output_dir, "mask")

        if not os.path.exists(self.mask_base_dir) and args.local_rank in [
                -1, 0
        ]:
            os.makedirs(self.mask_base_dir)

        self.mask_last_dir = os.path.join(self.mask_base_dir, "last")
        if not os.path.exists(self.mask_last_dir) and args.local_rank in [
                -1, 0
        ]:
            os.makedirs(self.mask_last_dir)

        self.checkpoint = os.path.join(self.mask_base_dir, "checkpoint")
        if not os.path.exists(self.checkpoint) and args.local_rank in [-1, 0]:
            os.makedirs(self.checkpoint)

        self.profile_dir = os.path.join(args.output_dir, "replay_profile_dir")
        os.makedirs(self.profile_dir, exist_ok=True)

    def save(self, episode):
        """Persist the mask model: always to the `last` dir, every 50
        episodes to a numbered snapshot, and every 10 episodes to the
        `checkpoint` dir together with the optimizer/episode state.
        No-op for non-neural masking types.
        """
        if "neural" not in self.masking_type:
            return

        # Unwrap DataParallel-style containers before saving.
        target = getattr(self.model, 'module', self.model)

        logger.info("Saving last mask model checkpoint to %s",
                    self.mask_last_dir)
        target.save_pretrained(self.mask_last_dir)

        if episode % 50 == 0:
            snapshot_dir = os.path.join(self.mask_base_dir, str(episode))
            if self.args.local_rank in (-1, 0) and not os.path.exists(snapshot_dir):
                os.makedirs(snapshot_dir)
            logger.info("Saving %d-th mask model checkpoint to %s" %
                        (episode, snapshot_dir))
            target.save_pretrained(snapshot_dir)

        if episode % 10 == 0:
            target.save_pretrained(self.checkpoint)
            state = {
                'optim': self.optimizer.state_dict(),
                'episode': self.episode,
            }
            torch.save(state, os.path.join(self.checkpoint, "last.ckpt"))

    def initialize_history(self):
        """Reset the per-episode RL history buffers.

        `self.rewards` survives across episodes but is trimmed in place
        (oldest entries first) so it never holds more than
        `self.reward_limit` items.
        """
        overflow = len(self.rewards) - self.reward_limit
        if overflow > 0:
            del self.rewards[:overflow]

        self.masked_log_probs = []
        self.log_probs = []
        self.non_mask_ratios = []
        self.entropy = []
        self.action_history = []
        self.tmp_replay_memory = []

    def set_masking_type(self, masking_type):
        """Switch the strategy `mask` dispatches on (e.g. 'random', 'neural')."""
        self.masking_type = masking_type

    def mask(self, inputs, tokenizer, args, visualize=False, model=None):
        """Apply the currently selected masking strategy to `inputs`.

        The 'neural' strategy additionally receives the downstream
        `model`; every other strategy shares the
        (inputs, tokenizer, args, visualize) signature.
        """
        kind = self.masking_type
        if kind == "neural":
            outputs = self.neural_mask_tokens(inputs, tokenizer, args,
                                              visualize, model)
        else:
            dispatch = {
                "random": self.random_mask_tokens,
                "whole": self.whole_random_mask_tokens,
                "span": self.span_random_mask_tokens,
                "entity": self.entity_random_mask_tokens,
                "punc": self.punc_random_mask_tokens,
            }
            if kind in dispatch:
                outputs = dispatch[kind](inputs, tokenizer, args, visualize)

        # An unrecognised masking type leaves `outputs` unbound, matching
        # the original fall-through behaviour (UnboundLocalError).
        return outputs

    def neural_mask_tokens(self,
                           inputs,
                           tokenizer,
                           args,
                           visualize=False,
                           model=None):
        """Prepare masked tokens using the learned (neural) mask generator.

        The policy network ``self.model`` scores every non-padding position;
        positions are sampled (training) or taken greedily via top-k
        (evaluation), then replaced BERT-style (80% [MASK] / 10% random /
        10% unchanged per the two Bernoulli draws below). During training
        with ``args.save_history`` the chosen (state, action, log-prob,
        value) tuples are buffered in ``self.tmp_replay_memory`` for the RL
        update.

        Args:
            inputs: LongTensor of token ids, shape (seq,) or (batch, seq).
            tokenizer: provides ``mask_token``/``convert_tokens_to_ids`` and
                vocabulary size via ``len()``.
            args: run configuration (device, pad_token, masking_prob, ...).
            visualize: when True, also return the per-position masking
                probabilities of the first sequence and the last action.
            model: downstream task model; only used when ``args.continual``
                to re-encode inputs into features.

        Returns:
            (masked_inputs, labels) — labels hold -100 at unmasked
            positions; plus (mask_prob, action) when ``visualize``.
        """
        if args.device2 is not None: device = args.device2
        else: device = args.device

        self.device = device

        # Accept a single unbatched sequence.
        if len(inputs.shape) == 1:
            inputs = inputs.unsqueeze(0)

        labels = inputs.clone()
        inputs = inputs.to(device)
        raw_inputs = inputs.clone()
        input_mask = ~inputs.eq(args.pad_token)

        # NOTE(review): masked_inputs ALIASES raw_inputs (no clone) — the
        # in-place writes near the end of the loop also mutate raw_inputs.
        # Replay entries read raw_inputs[i] before row i is mutated, so the
        # statement order matters; confirm the aliasing is intentional.
        masked_inputs = raw_inputs

        # seq_length/_input_mask are re-derived below; these bindings are
        # effectively unused here.
        _input_mask = input_mask
        seq_length = inputs.shape[-1]

        self.model.eval()

        if args.continual:
            # In the continual setup the policy consumes encoder features,
            # not raw token ids: run the task model's LM encoder once.
            original_device = next(model.parameters()).device
            # Temporarily send to extract feature
            model.to(device)
            with torch.no_grad():
                model.eval()
                if hasattr(model, 'bert'):
                    language_model = model.bert
                elif hasattr(model, 'distilbert'):
                    language_model = model.distilbert
                else:
                    raise NotImplementedError
                outputs = language_model(inputs, attention_mask=input_mask)
            inputs = outputs[0]
            model.to(original_device)

        # Score all positions with the frozen policy (no grad: REINFORCE
        # recomputes log-probs at update time from the replay memory).
        with torch.no_grad():
            outputs = self.model(inputs,
                                 attention_mask=input_mask,
                                 visualize=visualize,
                                 args=args)

        logit, value = outputs[0], outputs[1]
        # Zero out padding logits, then push them to -inf so softmax
        # assigns them zero probability.
        logit = logit * input_mask.to(torch.float)
        logit[logit == 0] = -float('inf')

        prob = F.softmax(logit, 1)
        log_prob = F.log_softmax(logit, 1)

        # Per-sequence masking budget: masking_prob of non-pad tokens, >= 1.
        num_samples = [
            max(int(i.sum().item() * args.masking_prob), 1) for i in input_mask
        ]
        entropy = []

        # Policy entropy over the non-padding prefix, for logging.
        for i, num in enumerate(input_mask):
            non_mask_num = int(num.sum().item())
            entropy.append(-(log_prob[i][:non_mask_num] *
                             prob[i][:non_mask_num]).sum().item())

        batch_log_probs = []
        for i in range(len(num_samples)):
            nmn = int(input_mask[i].sum().item())
            _prob = prob[i][:nmn] + EPS

            if self.training:
                # Inverse-sqrt token-frequency weights, stored as sub_p so
                # replay sampling can down-weight frequent tokens.
                normalize_counter = Counter()
                normalize_counter.update(raw_inputs[i][:nmn].tolist())
                normalize_factor = \
                        [1 / np.sqrt(normalize_counter[w]) for w in raw_inputs[i][:nmn].tolist()]

            # Explore (multinomial sample) while training, unless this is a
            # frozen self-play opponent; otherwise exploit via top-k.
            if self.training and (args.self_play != "frozen" or not self.base):
                action = _prob.multinomial(num_samples=num_samples[i])
            else:
                action = _prob.topk(k=num_samples[i], dim=0)[1]

            self.action_history.append(action.tolist())

            if self.training and args.save_history:
                if args.self_play == "learning" or (args.self_play
                                                    in ["frozen", "random"]
                                                    and not self.base):
                    # Self-play: group all actions of one context so that
                    # rewards can later be assigned per context.
                    replay_memory_same_context = []
                    for a in action:
                        feature = None
                        v = value[i]
                        sub_p = normalize_factor[a]
                        replay_memory_same_context.append(
                            MemoryEntry(
                                inputs[i].cpu(),
                                a.cpu(),
                                log_prob[i][a].detach().cpu(),
                                v.detach().cpu(),
                                feature=feature,
                                sub_p=sub_p,
                                raw_input=raw_inputs[i].cpu(),
                            ))
                    self.tmp_replay_memory.append(replay_memory_same_context)

                elif not self.base:
                    # Non-self-play: flat replay memory, no sub_p/raw_input.
                    for a in action:
                        feature = None
                        v = value[i]
                        self.tmp_replay_memory.append(
                            MemoryEntry(inputs[i].cpu(),
                                        a.cpu(),
                                        log_prob[i][a].detach().cpu(),
                                        v.detach().cpu(),
                                        feature=feature))

            # Dead code: deliberately disabled via `and False`.
            if self.training and False:
                log_prob_ = log_prob[i].gather(0, action)
                batch_log_probs.append(log_prob_)

            # Scatter chosen actions into a boolean mask; unmasked positions
            # get the ignore label -100.
            masked_indices = torch.zeros(labels[i].shape).to(device).scatter_(
                0, action, 1).bool()
            labels[i][~masked_indices] = -100

            # BERT-style replacement on the chosen positions:
            # 80% [MASK], then half of the remainder (10% overall) random.
            _input_mask = (~masked_inputs[i].eq(args.pad_token)).to(
                torch.float)
            indices_replaced = torch.bernoulli(
                _input_mask * 0.8).bool() & masked_indices
            masked_inputs[i][
                indices_replaced] = tokenizer.convert_tokens_to_ids(
                    tokenizer.mask_token)

            indices_random = torch.bernoulli(
                _input_mask * 0.5).bool() & masked_indices & ~indices_replaced
            # NOTE(review): .cuda(device) assumes CUDA is available; this
            # path would fail on a CPU-only run — confirm.
            random_words = torch.randint(len(tokenizer),
                                         labels[i].shape,
                                         dtype=torch.long).cuda(device)
            masked_inputs[i][indices_random] = random_words[indices_random]

        if self.training:
            # Disabled branch (matches the dead code above).
            if False:
                self.log_probs.append(torch.cat(batch_log_probs))
            self.entropy.append(sum(entropy) / len(entropy))

        if visualize:
            # Only meaningful for batch size 1: prob[0] and the last `action`
            # from the loop are returned.
            non_mask_num = int(input_mask.sum().item())
            mask_prob = prob[0][:non_mask_num].cpu().numpy()
            return masked_inputs, labels, mask_prob, action
        return masked_inputs, labels

    def random_mask_tokens(self, inputs, tokenizer, args, visualize=False):
        """Prepare masked token inputs/labels for masked language modeling.

        Uniform BERT-style masking: each non-padding token is selected with
        probability ``args.masking_prob``; selected positions are replaced
        with 80% [MASK], 10% a random vocabulary token, 10% left unchanged.
        ``inputs`` is modified in place.

        Args:
            inputs: LongTensor of token ids, shape (seq,) or (batch, seq);
                padding id is assumed to be 0.
            tokenizer: provides ``mask_token``/``convert_tokens_to_ids`` and
                vocabulary size via ``len()``.
            args: must carry ``masking_prob``.
            visualize: unused; kept for interface parity with the other
                masking methods.

        Returns:
            (inputs, labels): labels keep the original ids at masked
            positions and -100 (the ignore index) everywhere else.
        """
        if len(inputs.shape) == 1:
            inputs = inputs.unsqueeze(0)

        labels = inputs.clone()

        # Padding (id 0) must never be masked. (Removed an unused `puncs`
        # list that was computed here but never read.)
        input_mask = (~inputs.eq(0)).to(torch.float)
        masking_prob = args.masking_prob
        # masked position = 1 or 0; Bernoulli over non-pad positions only.
        masked_indices = torch.bernoulli(input_mask * masking_prob).bool()
        labels[~masked_indices] = -100

        # 80% of the selected positions -> [MASK].
        indices_replaced = torch.bernoulli(
            input_mask * 0.8).bool() & masked_indices
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(
            tokenizer.mask_token)

        # Half of the remainder (10% overall) -> random vocabulary token;
        # the rest stay unchanged.
        indices_random = torch.bernoulli(
            input_mask * 0.5).bool() & masked_indices & ~indices_replaced

        random_words = torch.randint(len(tokenizer),
                                     labels.shape,
                                     dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The base (random) player records its actions for self-play
        # reward assignment.
        if self.base:
            for i in range(masked_indices.shape[0]):
                action = masked_indices[i].nonzero().view(-1).tolist()
                self.action_history.append(action)

        return inputs, labels

    def whole_random_mask_tokens(self,
                                 inputs,
                                 tokenizer,
                                 args,
                                 visualize=False):
        """Whole-word masking for MLM: all sub-word pieces of a word are
        masked together, up to ~masking_prob of the non-padding tokens.
        Returns (inputs, labels) with -100 in labels at unmasked positions.
        """
        device = args.device2 if args.device2 is not None else args.device

        all_labels = []
        all_inputs = []
        special = ("[CLS]", "[SEP]", "[PAD]")
        # Process one sequence of the batch at a time.
        for seq in inputs:
            seq_labels = seq.clone()
            pad_mask = ~seq.eq(0)
            token_strs = tokenizer.convert_ids_to_tokens(seq.tolist())

            # Group "##" sub-word pieces with their leading piece; skip
            # special tokens (and stop words when stopword masking is on).
            word_groups = []
            for pos, tok in enumerate(token_strs):
                if tok in special:
                    continue
                if args.stopword_masking and tok in self.stop_words:
                    continue
                if word_groups and tok.startswith(
                        "##") and not args.stopword_masking:
                    word_groups[-1].append(pos)
                else:
                    word_groups.append([pos])

            random.shuffle(word_groups)

            # Masking budget over non-padding tokens, at least one.
            budget = max(
                1, int(round(pad_mask.sum().item() * args.masking_prob)))

            chosen = []
            seen = set()
            for group in word_groups:
                if len(chosen) >= budget:
                    break
                # A whole word must fit inside the remaining budget.
                if len(chosen) + len(group) > budget:
                    continue
                if any(pos in seen for pos in group):
                    continue

                for pos in group:
                    seen.add(pos)
                    # 80% [MASK], 10% keep the original, 10% random token.
                    if random.random() < 0.8:
                        new_id = tokenizer.convert_tokens_to_ids(
                            tokenizer.mask_token)
                    elif random.random() < 0.5:
                        new_id = tokenizer.convert_tokens_to_ids(
                            token_strs[pos])
                    else:
                        new_id = random.randint(0, len(tokenizer) - 1)
                    token_strs[pos] = new_id
                    chosen.append((pos, new_id))

            assert len(chosen) <= budget

            # Write replacements; every untouched position gets the ignore
            # label -100.
            replacement = dict(chosen)
            for pos in range(seq.shape[0]):
                if pos in replacement:
                    seq[pos] = replacement[pos]
                else:
                    seq_labels[pos] = -100

            all_inputs.append(seq)
            all_labels.append(seq_labels)

        labels = torch.stack(all_labels, 0)
        inputs = torch.stack(all_inputs, 0)

        return inputs, labels

    def span_random_mask_tokens(self,
                                inputs,
                                tokenizer,
                                args,
                                visualize=False):
        """Prepare masked tokens for MLM with SpanBERT-style span masking.

        Contiguous spans of whole words (span length ~ Geometric(0.2),
        clipped at 10) are selected until roughly ``masking_prob`` of the
        non-padding tokens are masked; all tokens within one span share a
        single replacement policy (80% [MASK] / 10% random / 10% unchanged).

        Returns:
            (inputs, labels) — labels hold -100 at unmasked positions;
            when ``visualize``, additionally the list of masked indices of
            the (single) sequence.
        """
        if args.device2 is not None: device = args.device2
        else: device = args.device
        labels = []
        new_inputs = []
        # Geometric distribution for sampling span length (p=0.2, per the
        # SpanBERT paper). Hoisted out of the loop — constructing the
        # distribution consumes no RNG state, so behavior is unchanged.
        geometric = torch.distributions.geometric.Geometric(0.2)
        # Doing operation per batch
        for input_raw in inputs:
            label_raw = input_raw.clone()

            input_mask = (~input_raw.eq(0))

            tokens = tokenizer.convert_ids_to_tokens(input_raw.tolist())

            cand_indexes = []
            # Aggregate subtokens ("##...") into whole words.
            for (i, token) in enumerate(tokens):
                if token == "[CLS]" or token == "[SEP]" or token == "[PAD]":
                    continue

                if len(cand_indexes) >= 1 and token.startswith("##"):
                    cand_indexes[-1].append(i)

                else:
                    cand_indexes.append([i])

            num_to_predict = max(
                1, int(round(input_mask.sum().item() * args.masking_prob)))

            masked_lms = []
            covered_indexes = set()

            tolerance = 0

            while len(masked_lms) < num_to_predict:
                tolerance += 1
                # Break a potential infinite loop (spans may never sum to
                # exactly num_to_predict).
                if tolerance > 1000000:
                    print("  I CAN'T TOLERATE ANY MORE!!!  ")
                    break

                # Randomly pick the starting word index.
                start = int(random.random() * len(cand_indexes))

                # If start index is already masked, continue
                if cand_indexes[start][0] in covered_indexes:
                    continue

                # Sample the span length (in words).
                l = geometric.sample().item() + 1
                if l > 10:
                    l = 10  # Clipping length following the spanBERT paper

                l = int(l)
                masked_token_policy = None
                # One replacement policy for the whole span.
                if random.random() < 0.8:
                    masked_token_policy = "mask"
                else:
                    if random.random() < 0.5:
                        masked_token_policy = "random"
                    else:
                        masked_token_policy = "original"
                # Slicing silently truncates spans that run past the end.
                sampled_span = cand_indexes[start:start + l]
                index_set = []
                for idx in sampled_span:
                    index_set += idx
                # Reject spans that overshoot the budget or overlap an
                # already-masked position.
                if len(masked_lms) + len(index_set) > num_to_predict:
                    continue
                is_any_index_covered = False
                for index in index_set:
                    if index in covered_indexes:
                        is_any_index_covered = True
                        break
                if is_any_index_covered:
                    continue
                tolerance = 0
                for index in index_set:
                    covered_indexes.add(index)
                    masked_token = None
                    if masked_token_policy == "mask":
                        masked_token = tokenizer.convert_tokens_to_ids(
                            tokenizer.mask_token)
                    elif masked_token_policy == "original":
                        masked_token = tokenizer.convert_tokens_to_ids(
                            tokens[index])
                    elif masked_token_policy == "random":
                        masked_token = random.randint(0, len(tokenizer) - 1)
                    else:
                        assert False
                    masked_lms.append((index, masked_token))

            assert len(masked_lms) <= num_to_predict

            masked_lms = sorted(masked_lms, key=lambda x: x[0])
            masked_lms_pointer = 0
            for i in range(input_raw.shape[0]):
                if masked_lms_pointer < len(
                        masked_lms) and i == masked_lms[masked_lms_pointer][0]:
                    input_raw[i] = masked_lms[masked_lms_pointer][1]
                    masked_lms_pointer += 1
                else:
                    label_raw[i] = -100
            new_inputs.append(input_raw)
            labels.append(label_raw)

        labels = torch.stack(labels, 0)
        inputs = torch.stack(new_inputs, 0)

        if visualize:
            action = []
            for idx, label in enumerate(labels.squeeze(0)):
                # BUGFIX: unmasked positions are marked with -100 above (not
                # -1), so the old `!= -1` test selected every position.
                if label.item() != -100:
                    action.append(idx)
            return inputs, labels, action

        return inputs, labels

    def entity_random_mask_tokens(self,
                                  inputs,
                                  tokenizer,
                                  args,
                                  visualize=False):
        """Prepare masked tokens for MLM with entity-first whole-word masking.

        Words belonging to named entities (detected via ``self.nlp``, an
        NER pipeline with char-offset ``.ents``) are placed at the front of
        the candidate list so they are masked preferentially; the remaining
        words fill the budget. Replacement follows the usual 80% [MASK] /
        10% unchanged / 10% random scheme.

        Returns:
            (inputs, labels) — labels hold -100 at unmasked positions.
        """
        if args.device2 is not None: device = args.device2
        else: device = args.device
        labels = []
        new_inputs = []
        # Doing operation per batch
        for input_raw in inputs:
            label_raw = input_raw.clone()

            input_mask = (~input_raw.eq(0))
            tokens = tokenizer.convert_ids_to_tokens(input_raw.tolist())

            # Group sub-word pieces into words (non-padding region only).
            cand_tokens = []
            for token in tokens[:input_mask.sum().item()]:
                if token == "[CLS]" or token == "[SEP]" or token == "[PAD]":
                    continue
                if len(cand_tokens) >= 1 and token.startswith("##"):
                    cand_tokens[-1].append(token)
                else:
                    cand_tokens.append([token])

            # Map every character of the reconstructed sentence to its word
            # index, so entity char offsets can be translated to words.
            char_to_idx = []
            # It matches to cand_indexes
            sentence = ''
            for (i, t) in enumerate(cand_tokens):
                word = ' '.join(t).replace(" ##", "").replace("##", "")
                sentence += word
                for _ in range(len(word)):
                    char_to_idx.append(i)
                sentence += ' '
                char_to_idx.append(i)

            cand_indexes = []
            # Aggregate subtokens ("##...") into whole words.
            for (i, token) in enumerate(tokens):
                if token == "[CLS]" or token == "[SEP]" or token == "[PAD]":
                    continue
                if len(cand_indexes) >= 1 and token.startswith("##"):
                    cand_indexes[-1].append(i)
                else:
                    cand_indexes.append([i])

            doc = self.nlp(sentence)
            entity_idx = []
            for ent in doc.ents:
                # NOTE(review): spaCy's end_char is exclusive (one past the
                # last character); indexing char_to_idx with it lands on the
                # trailing space, which maps back to the entity's last word,
                # so this works out — but verify for entities at the very
                # end of the sentence.
                start_idx = char_to_idx[ent.start_char]
                end_idx = char_to_idx[ent.end_char]
                for i in range(start_idx, end_idx + 1):
                    entity_idx.append(i)

            # Entities first, then the rest, each group shuffled.
            entity_cand_indexes = []
            non_cand_indexes = []
            for num, idxes in enumerate(cand_indexes):
                if num in entity_idx:
                    entity_cand_indexes.append(idxes)
                else:
                    non_cand_indexes.append(idxes)

            random.shuffle(entity_cand_indexes)
            random.shuffle(non_cand_indexes)
            cand_indexes = entity_cand_indexes + non_cand_indexes

            # Ignore padding
            num_to_predict = max(
                1, int(round(input_mask.sum().item() * args.masking_prob)))

            masked_lms = []
            covered_indexes = set()
            for index_set in cand_indexes:
                if len(masked_lms) >= num_to_predict:
                    break

                # A whole word must fit inside the remaining budget.
                if len(masked_lms) + len(index_set) > num_to_predict:
                    continue

                is_any_index_covered = False
                for index in index_set:
                    if index in covered_indexes:
                        is_any_index_covered = True
                        break

                if is_any_index_covered:
                    continue

                for index in index_set:
                    covered_indexes.add(index)

                    masked_token = None

                    # 80% of time, replace with [MASK]
                    if random.random() < 0.8:
                        masked_token = tokenizer.convert_tokens_to_ids(
                            tokenizer.mask_token)
                    else:
                        if random.random() < 0.5:
                            masked_token = tokenizer.convert_tokens_to_ids(
                                tokens[index])
                        else:
                            masked_token = random.randint(
                                0,
                                len(tokenizer) - 1)
                    tokens[index] = masked_token

                    masked_lms.append((index, tokens[index]))

            assert len(masked_lms) <= num_to_predict

            masked_lms = sorted(masked_lms, key=lambda x: x[0])

            masked_lms_pointer = 0
            for i in range(input_raw.shape[0]):
                if masked_lms_pointer < len(
                        masked_lms) and i == masked_lms[masked_lms_pointer][0]:
                    input_raw[i] = masked_lms[masked_lms_pointer][1]
                    masked_lms_pointer += 1
                else:
                    # BUGFIX: use -100 (the ignore index used by every other
                    # masking method and by CrossEntropyLoss by default)
                    # instead of -1, which would NOT be ignored by the loss.
                    label_raw[i] = -100

            new_inputs.append(input_raw)
            labels.append(label_raw)

        labels = torch.stack(labels, 0)
        inputs = torch.stack(new_inputs, 0)

        return inputs, labels

    def punc_random_mask_tokens(self,
                                inputs,
                                tokenizer,
                                args,
                                visualize=False):
        if args.device2 is not None: device = args.device2
        else: device = args.device

        labels = []
        new_inputs = []

        puncs = list(string.punctuation)

        for input_raw in inputs:
            label_raw = input_raw.clone()
            input_mask = (~input_raw.eq(0))
            tokens = tokenizer.convert_ids_to_tokens(input_raw.tolist())

            num_to_predict = max(
                1, int(round(input_mask.sum().item() * args.masking_prob)))

            punc_index = []
            non_punc_index = []
            for (i, t) in enumerate(tokens):
                if t in puncs:
                    punc_index.append(i)
                else:
                    non_punc_index.append(i)
            random.shuffle(punc_index)
            random.shuffle(non_punc_index)
            cand_indexes = punc_index + non_punc_index

            masked_lms = []
            covered_indexes = set()
            for index in cand_indexes:
                if len(masked_lms) >= num_to_predict:
                    break

                is_any_index_covered = False
                if index in covered_indexes:
                    is_any_index_covered = True
                    break

                if is_any_index_covered:
                    continue

                covered_indexes.add(index)
                masked_token = None

                if random.random() < 0.8:
                    masked_token = tokenizer.convert_tokens_to_ids(
                        tokenizer.mask_token)

                else:
                    if random.random() < 0.5:
                        masked_token = tokenizer.convert_tokens_to_ids(
                            tokens[index])
                    else:
                        masked_token = random.randint(0, len(tokenizer) - 1)
                tokens[index] = masked_token
                masked_lms.append((index, tokens[index]))

            assert len(masked_lms) <= num_to_predict

            masked_lms = sorted(masked_lms, key=lambda x: x[0])
            masked_lms_pointer = 0

            for i in range(input_raw.shape[0]):
                if masked_lms_pointer < len(
                        masked_lms) and i == masked_lms[masked_lms_pointer][0]:
                    input_raw[i] = masked_lms[masked_lms_pointer][1]
                    masked_lms_pointer += 1
                else:
                    label_raw[i] = -1

            new_inputs.append(input_raw)
            labels.append(label_raw)

        labels = torch.stack(labels, 0)
        inputs = torch.stack(new_inputs, 0)

        return inputs, labels

    # op = opponent, pl = player
    def append_reward_selfplay(self,
                               op_acc,
                               pl_acc,
                               op_history,
                               pl_history=None,
                               rnd_acc=None,
                               rnd_history=None):
        """Assign self-play rewards to the buffered replay entries.

        The sign of (player accuracy - opponent accuracy) becomes the
        reward (+1/0/-1) for player actions that the opponent did NOT also
        take; actions shared with the opponent (or the random baseline) are
        skipped, since they cannot explain the accuracy difference. Rewarded
        entries are moved into ``self.replay_memory`` with their state
        deduplicated through reference-counted ``StateContainer``s.

        NOTE(review): the ``pl_history`` parameter is immediately
        overwritten with ``self.action_history`` below, so the argument is
        effectively unused — confirm whether callers rely on passing it.
        """
        if "neural" not in self.masking_type:
            return

        pl_history = self.action_history

        # Sign of the accuracy difference is the scalar reward.
        reward = pl_acc - op_acc

        if reward > 0:
            reward = 1
            self.improve_indicator += 1
        elif reward == 0:
            reward = 0
        else:
            reward = -1
            self.improve_indicator -= 1
            self.regret += 1

        # Optional second baseline: a random-masking player.
        if rnd_acc is not None:
            rnd_reward = pl_acc - rnd_acc
            rnd_reward = np.sign(rnd_reward)
        else:
            rnd_reward = None

        assert len(op_history) == len(pl_history)
        if rnd_history is not None:
            assert len(rnd_history) == len(pl_history)

        tmp_replay_memory = []
        # One `transitions` list holds all actions taken in one context.
        for i, transitions in enumerate(self.tmp_replay_memory):
            op_actions = op_history[i]
            pl_actions = pl_history[i]
            rnd_actions = rnd_history[
                i] if rnd_history is not None else pl_actions

            # Actions unique to the player vs. each baseline.
            pl_actions_disjoint_op = list(set(pl_actions) - set(op_actions))
            pl_actions_disjoint_rnd = list(set(pl_actions) - set(rnd_actions))

            base_data = None
            for j, entry in enumerate(transitions):
                if entry.action in pl_actions_disjoint_op and entry.action in pl_actions_disjoint_rnd:
                    # Take minimum reward for both disjoint circumstances
                    entry.update_reward(
                        min(reward, rnd_reward
                            ) if rnd_history is not None else reward)

                elif entry.action in pl_actions_disjoint_op and entry.action not in pl_actions_disjoint_rnd:
                    entry.update_reward(reward)

                elif entry.action not in pl_actions_disjoint_op and entry.action in pl_actions_disjoint_rnd:
                    entry.update_reward(rnd_reward)

                else:
                    # Joint action - do not learn
                    continue

                # Share one reference-counted StateContainer among all
                # entries of the same context to deduplicate the state.
                if base_data is None:
                    assert entry.state.__class__.__name__ == "Tensor"
                    base_data = StateContainer(entry.state)
                    self.state_containers.append(base_data)

                entry.state = base_data
                base_data.counter_add()

                entry.z = 1
                tmp_replay_memory.append(entry)

        self.replay_memory += tmp_replay_memory

        start = time.time()
        # FIFO eviction down to capacity; keep the refcounts in sync.
        while len(self.replay_memory) > self.args.memory_capacity:
            pop_entry = self.replay_memory.pop(0)
            pop_entry.state.counter_minus()

        # Clean memory with no reference
        # NOTE(review): popping from the list while iterating it skips the
        # element after each removal — some zero-count containers survive
        # until a later call. Confirm whether this is acceptable here.
        delete_count = 0
        for idx, sc in enumerate(self.state_containers):
            if sc.counter == 0:
                self.state_containers.pop(idx)
                delete_count += 1

        # Hard cap on distinct states: keep evicting oldest replay entries
        # until enough containers drop to zero references.
        while len(self.state_containers) > 5000:
            pop_entry = self.replay_memory.pop(0)
            pop_entry.state.counter_minus()

            sc = self.state_containers[0]
            if sc.counter == 0:
                self.state_containers.pop(0)
                delete_count += 1

        print("Elapsed Time for Pop memory: {}".format(time.time() - start))
        print("Memory len: {}".format(len(self.replay_memory)))

        print("Delete {} states among {}".format(delete_count,
                                                 len(self.state_containers)))

        # Only the learning player logs rewards/accuracies.
        if not self.base:
            if rnd_reward is not None:
                self.rewards.append(rnd_reward)
            else:
                self.rewards.append(reward)
            self.acc_history.append(pl_acc)
            self.base_acc_history.append(op_acc)

        if rnd_acc is not None and not self.base:
            self.rnd_acc_history.append(rnd_acc)
            if rnd_acc > pl_acc:
                self.rand_self_indicator -= 1
            elif rnd_acc < pl_acc:
                self.rand_self_indicator += 1
    def train_replay(self, args):
        """Run ``args.replay_step`` policy-gradient updates on replay memory.

        Samples minibatches from ``self.replay_memory`` with a priority
        distribution mixed from softmaxes over reward (R), |advantage| (A)
        and value (V) as selected by ``args.sampling_strategy``, recomputes
        the action log-probs under the current policy, and optimizes a
        REINFORCE or actor-critic objective, optionally with PPO-style
        importance weighting and an entropy regularizer. No-op unless this
        generator is neural and enough experience has been collected.
        """
        if "neural" not in self.masking_type:
            return

        if args.device2 is not None: device = args.device2
        else: device = args.device

        # Not enough experience yet — just log and wait.
        if len(self.replay_memory) < args.replay_start:
            if not self.base: self.write_logs(rewards=self.rewards)
            return

        total_loss = 0
        total_policy_loss = 0
        total_value_loss = 0

        self.model.train()
        self.model.zero_grad()
        for _ in tqdm(range(args.replay_step),
                      desc="ReinforcementLearning",
                      position=0):
            if args.actor_critic:
                # Prioritized sampling: average the selected softmax
                # distributions, then down-weight frequent tokens by sub_p.
                r = np.array([e.reward for e in self.replay_memory])
                V = np.array([e.value for e in self.replay_memory])
                _sampling_p = []
                if "R" in args.sampling_strategy:
                    _sampling_p.append(np.exp(r) / np.sum(np.exp(r)))
                if "A" in args.sampling_strategy:
                    _sampling_p.append(
                        np.exp(abs(r - V)) / np.sum(np.exp(abs(r - V))))
                if "V" in args.sampling_strategy:
                    _sampling_p.append(np.exp(V) / np.sum(np.exp(V)))

                if len(_sampling_p) > 0:
                    sampling_p = sum(_sampling_p) / len(_sampling_p)
                else:
                    # No strategy selected: fall back to uniform.
                    sampling_p = np.array([1 for e in self.replay_memory])
                sampling_p = sampling_p / sum(sampling_p)

                sub_p = np.array([e.sub_p for e in self.replay_memory])
                sampling_p *= sub_p
                sampling_p = sampling_p / sum(sampling_p)
                sampled_batch = np.random.choice(self.replay_memory,
                                                 args.replay_batch_size,
                                                 replace=False,
                                                 p=sampling_p)
            else:
                # Reward-only prioritization in the plain REINFORCE case.
                r = np.array([e.reward for e in self.replay_memory])
                sampling_p = np.exp(r) / np.sum(np.exp(r))
                sampled_batch = np.random.choice(self.replay_memory,
                                                 args.replay_batch_size,
                                                 replace=False,
                                                 p=sampling_p)
            inputs = torch.stack([e.state.data
                                  for e in sampled_batch]).to(device)
            action = torch.stack([e.action
                                  for e in sampled_batch]).view(-1,
                                                                1).to(device)
            reward = torch.tensor([e.reward for e in sampled_batch],
                                  device=device,
                                  dtype=torch.float)
            z = torch.tensor([e.z for e in sampled_batch],
                             device=device,
                             dtype=torch.float)

            if args.continual:
                # In the continual setup the stored state is an encoded
                # feature; recover padding from the raw token ids instead.
                raw_inputs = torch.stack([e.raw_input
                                          for e in sampled_batch]).to(device)
                input_mask = (~raw_inputs.eq(args.pad_token)).to(torch.float)
            else:
                input_mask = (~inputs.eq(args.pad_token)).to(torch.float)
            outputs = self.model(inputs, attention_mask=input_mask, args=args)
            if args.actor_critic:
                logit, value = outputs[0], outputs[1]
            else:
                logit = outputs[0]
            # Zero padding logits, then push them to -inf before softmax.
            logit = logit * input_mask
            logit[logit == 0] = -float('inf')

            prob = F.softmax(logit, 1)
            # Guard against a degenerate all-zero softmax row.
            if 0 in [p for p in prob.sum(-1)]:
                log_prob = torch.log(prob + EPS)
            else:
                log_prob = F.log_softmax(logit, 1)

            action_log_prob = log_prob.gather(1, action).view(-1)

            # Refresh the cached value estimates. BUGFIX: `value` is only
            # bound in actor-critic mode; previously this loop ran
            # unconditionally and raised NameError when args.actor_critic
            # was False.
            if args.actor_critic:
                for batch_idx, entry in enumerate(sampled_batch):
                    entry.value = value[batch_idx].detach().cpu()

            if args.actor_critic:
                adv = reward * z - value.detach()
            else:
                adv = reward * z

            if args.importance_sampling:
                old_log_prob = torch.stack([e.log_prob
                                            for e in sampled_batch]).to(device)
                ratio = torch.exp(action_log_prob - old_log_prob)
                if args.ppo_policy:
                    policy_loss = -(adv * ratio).mean()
                else:
                    policy_loss = -(adv * ratio.detach() *
                                    action_log_prob).mean(0)
            else:
                policy_loss = -(adv * action_log_prob).mean(0)

            # Entropy regularizer (actor-critic only) encourages exploration.
            entropy_regularizer = 0
            for i, num in enumerate(input_mask):
                non_mask_num = int(num.sum().item())
                _entropy = (-(log_prob[i][:non_mask_num] *
                              prob[i][:non_mask_num]).mean())
                entropy_regularizer += _entropy

            if not args.actor_critic:
                entropy_regularizer = 0

            value_loss = 0
            if args.actor_critic:
                value_loss = 0.5 * (reward - value).pow(2).mean()
                total_value_loss += value_loss.item()

            loss = 0.5 * value_loss + policy_loss - (args.entropy_coeff *
                                                     entropy_regularizer)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           args.max_grad_norm)
            self.optimizer.step()
            self.model.zero_grad()

            total_loss += loss.item()
            total_policy_loss += policy_loss.item()

        tqdm.write(str(total_loss))
        tqdm.write(str(total_policy_loss))
        tqdm.write(str(total_value_loss))
        if not self.base:
            self.write_logs(loss=total_loss / args.replay_step,
                            rewards=self.rewards,
                            policy_loss=total_policy_loss / args.replay_step,
                            value_loss=total_value_loss / args.replay_step)
    def write_logs(self,
                   policy_loss=None,
                   loss=None,
                   rewards=None,
                   value_loss=None):
        """Emit the per-episode scalars to TensorBoard and advance the
        episode counter. Called at the end of each episode; every argument
        is optional and only logged when provided.
        """
        writer = self.tb_writer
        ep = self.episode

        if self.acc_history:
            writer.add_scalar('accuracy', self.acc_history[-1], ep)
        if self.base_acc_history:
            writer.add_scalar('base_accuracy', self.base_acc_history[-1], ep)
            writer.add_scalar(
                'accuracy_diff',
                self.acc_history[-1] - self.base_acc_history[-1], ep)
        if self.rnd_acc_history:
            writer.add_scalar('rnd_accuracy', self.rnd_acc_history[-1], ep)
            writer.add_scalar('rand_self_indicator', self.rand_self_indicator,
                              ep)

        writer.add_scalar('improve_indicator', self.improve_indicator, ep)

        # Optional loss scalars, logged only when supplied.
        for tag, scalar in (('policy_loss', policy_loss),
                            ('value_loss', value_loss), ('loss', loss)):
            if scalar is not None:
                writer.add_scalar(tag, scalar, ep)

        if rewards:
            writer.add_scalar('reward', sum(rewards) / self.reward_limit, ep)
        if self.entropy:
            writer.add_scalar('entropy',
                              sum(self.entropy) / len(self.entropy), ep)

        writer.add_scalar('cumulative_regret', self.regret, ep)
        self.episode += 1
Example #5
0
def train(args, train_iter, dev, test, src_field, tgt_field, tag_field,
          checkpoint):
    """Train the joint simile Classify_Extractor model.

    The loop uses gradient accumulation over ``args.decay`` mini-batches.
    After every epoch the model is validated on ``dev``; the best classifier
    F1 and the best extractor F1 are tracked independently, each saving its
    own checkpoint and keeping its own patience counter.  Training stops
    early once BOTH patience counters exceed ``args.patience``.  Finally the
    two best checkpoints are reloaded and jointly evaluated on ``test`` (the
    classifier's prediction gates which extractor outputs count).

    Args:
        args: hyper-parameter namespace (optimizer, lr, decay, maximum_steps,
            maxepoch, patience, model_path, model, resume, ...).
        train_iter, dev, test: torchtext-style iterators whose batches carry
            ``src`` (lists of tokens), ``tgt`` and ``tag``.
        src_field, tgt_field, tag_field: dataset fields; only
            ``tgt_field.vocab`` is consulted here.
        checkpoint: optional dict to warm-start / resume from, or None.
    """
    # srcpadid = src_field.vocab.stoi['<pad>']
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

    model = Classify_Extractor(args, tgt_field)

    if torch.cuda.is_available():
        model.cuda()

    print_params(model)

    # NOTE: ``decay`` doubles as the gradient-accumulation factor in the
    # training loop below.
    decay = args.decay

    if args.optimizer == 'bert':
        weight_decay = 0.0
        # Standard BERT fine-tuning recipe: exclude biases and LayerNorm
        # weights from weight decay.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': weight_decay
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0
        }]
        opt = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=1e-8)
        # One throwaway pass over train_iter just to count batches/epoch,
        # needed to size the linear warmup schedule.
        totalnum = sum(1 for _ in train_iter)
        t_total = totalnum // decay * args.maximum_steps
        scheduler = WarmupLinearSchedule(opt, warmup_steps=0, t_total=t_total)
    else:
        opt = torch.optim.Adadelta(model.parameters(), lr=args.lr)
        # BUG FIX: the original called scheduler.step() unconditionally in
        # the update step, raising NameError in this (Adadelta) branch.
        scheduler = None

    best_e = 0.0            # best extractor F1 on dev
    best_c = 0.0            # best classifier F1 on dev
    best_epoch_for_c = 0
    best_epoch_for_e = 0
    offset = 0.0            # iteration offset restored from a checkpoint
    pre_epoch = 0           # epoch offset restored from a checkpoint
    patience_c = 0
    patience_e = 0

    if checkpoint is not None:
        print('model.load_state_dict(checkpoint[model])')
        model.load_state_dict(checkpoint['model'])
        if args.resume:
            opt.load_state_dict(checkpoint['optim'])

            # BUG FIX: checkpoints saved by this function only contain
            # 'model'/'optim'/'args', so indexing 'f'/'iters'/'epoch'
            # directly crashed on resume; fall back to fresh defaults.
            best_f = checkpoint.get('f', 0.0)
            offset = checkpoint.get('iters', 0.0)
            pre_epoch = checkpoint.get('epoch', 0)

            print('*************************************')
            print('resume from {} epoch {} iters and best_f {}'.format(
                pre_epoch, offset, best_f))
            print('*************************************')

    print("**************start training****************")
    start = time.time()

    for epoch in range(args.maxepoch):
        train_iter.init_epoch()
        epoch += pre_epoch

        for iters, train_batch in enumerate(train_iter):
            iters += offset
            model.train()
            t1 = time.time()

            # Convert raw token lists to ids, right-pad to the batch max
            # length, and build the matching attention mask.
            batch_src = train_batch.src
            src = [tokenizer.convert_tokens_to_ids(s) for s in batch_src]
            maxlen = max([len(s) for s in batch_src])

            src_mask = []
            padded_sents = []
            for s in src:
                new_s = s + [0] * (maxlen - len(s))
                padded_sents.append(new_s)
                mask = [1] * len(s) + [0] * (maxlen - len(s))
                src_mask.append(mask)
            # B x T
            src = torch.tensor(padded_sents).long().cuda()
            # B x T
            src_mask = torch.tensor(src_mask).byte().cuda()
            tgt = prepare_tgt(train_batch.tgt)
            tag = train_batch.tag

            loss = model(src, src_mask, tgt, tag)

            # Scale the loss so accumulated gradients average correctly.
            if decay > 1:
                loss = loss / decay

            loss.backward()

            # Update parameters every ``decay`` mini-batches.
            if (iters + 1) % decay == 0:
                opt.step()
                if scheduler is not None:
                    scheduler.step()  # Update learning rate schedule
                opt.zero_grad()

            t2 = time.time()

            loss = loss.item()

            print("epoch:{} iters:{} src:({},{}) tgt:({},{}) "
                  "loss:{:.2f} t:{:.2f}".format(epoch + 1, iters + 1,
                                                *src.size(), *tgt.size(), loss,
                                                t2 - t1))

        if (epoch + 1) % 1 == 0:  # validate every epoch
            print("=============validate model==============")
            with torch.no_grad():
                dev.init_epoch()
                model.eval()
                sents = []
                cy_true = []
                cy_pred = []
                for j, dev_batch in enumerate(dev):
                    t1 = time.time()
                    # Same tokenize / pad / mask preparation as in training.
                    batch_src = dev_batch.src
                    src = [
                        tokenizer.convert_tokens_to_ids(s) for s in batch_src
                    ]
                    maxlen = max([len(s) for s in batch_src])

                    src_mask = []
                    padded_sents = []
                    for s in src:
                        new_s = s + [0] * (maxlen - len(s))
                        padded_sents.append(new_s)
                        mask = [1] * len(s) + [0] * (maxlen - len(s))
                        src_mask.append(mask)
                    # B x T
                    src = torch.tensor(padded_sents).long().cuda()
                    # B x T
                    src_mask = torch.tensor(src_mask).byte().cuda()

                    tgt = prepare_tgt(dev_batch.tgt)
                    tag = dev_batch.tag.squeeze(-1)
                    _, pre_tag = model.component_extraction(src, src_mask)
                    pre_ctag = model.simile_classify(src, src_mask)
                    cy_true.extend(tag.tolist())
                    cy_pred.extend(pre_ctag.tolist())

                    # On dev, extraction is scored only on sentences whose
                    # GOLD tag marks a simile (c_tags comes from ``tag``).
                    for sen, tags, p_tags, c_tags in zip(
                            src, tgt, pre_tag, tag):
                        sen = sen[:len(p_tags)].tolist()
                        tags = tags[:len(p_tags)].tolist()
                        if c_tags == 1:
                            sents.append([
                                sen, [tgt_field.vocab.itos[t] for t in tags],
                                [tgt_field.vocab.itos[t] for t in p_tags]
                            ])
                    print('dev iters: {}, t:{}'.format(j, time.time() - t1))

                _, eprecision, erecall, ef1 = evaluate(sents)

                cprecision = precision_score(cy_true, cy_pred)
                crecall = recall_score(cy_true, cy_pred)
                cf1 = f1_score(cy_true, cy_pred)

                print(
                    'epoch: {} classify--> precision: {} recall: {} f1: {} best:{}'
                    .format(epoch + 1, cprecision, crecall, cf1, best_c))
                print('extractor--> precision: {} recall: {} f1: {} best: {}'.
                      format(eprecision, erecall, ef1, best_e))

                # Track classifier / extractor bests independently, each
                # with its own checkpoint file and patience counter.
                if cf1 > best_c:
                    best_c = cf1
                    best_epoch_for_c = epoch + 1

                    print(
                        'save best classifier model at epoch={}'.format(epoch +
                                                                        1))
                    checkpoint = {
                        'model': model.state_dict(),
                        'optim': opt.state_dict(),
                        'args': args
                    }
                    torch.save(
                        checkpoint, '{}/{}.classify.best.pt'.format(
                            args.model_path, args.model))
                    patience_c = 0
                else:
                    patience_c += 1

                if ef1 > best_e:
                    best_e = ef1
                    best_epoch_for_e = epoch + 1

                    print(
                        'save best extractor model at epoch={}'.format(epoch +
                                                                       1))
                    checkpoint = {
                        'model': model.state_dict(),
                        'optim': opt.state_dict(),
                        'args': args
                    }
                    torch.save(
                        checkpoint, '{}/{}.extractor.best.pt'.format(
                            args.model_path, args.model))
                    patience_e = 0
                else:
                    patience_e += 1

        # Early-stop only when BOTH tasks have stopped improving.
        if patience_c > args.patience and patience_e > args.patience:
            print("early stop at {}".format(epoch))
            break

        if args.decay:
            # NOTE(review): this multiplies the LR by ``args.decay``, the
            # same value used as the accumulation factor above; if decay > 1
            # this INCREASES the LR every epoch -- confirm intent.
            opt.param_groups[0]['lr'] = opt.param_groups[0]['lr'] * args.decay

    print('*******Done********{}'.format(
        time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
    minutes = (time.time() - start) // 60
    if minutes < 60:
        print(
            'best_c:{}, best_e:{} best_epoch_c:{}, best_epoch_e:{}, time:{} mins'
            .format(best_c, best_e, best_epoch_for_c, best_epoch_for_e,
                    minutes))
    else:
        hours = minutes / 60
        print(
            'best_c:{}, best_e:{} best_epoch_c:{}, best_epoch_e:{}, time:{:.1f} hours'
            .format(best_c, best_e, best_epoch_for_c, best_epoch_for_e, hours))

    print('*******Testing************')
    # Reload the best classifier (model1) and the best extractor (model2)
    # and evaluate them jointly on the test set.
    model1 = Classify_Extractor(args, tgt_field)
    model1.cuda()
    load_from = '{}/{}.classify.best.pt'.format(args.model_path, args.model)
    print('load the best model {}'.format(load_from))
    checkpoint = torch.load(load_from, map_location='cpu')
    print('load parameters')
    model1.load_state_dict(checkpoint['model'])

    model2 = Classify_Extractor(args, tgt_field)
    model2.cuda()
    load_from = '{}/{}.extractor.best.pt'.format(args.model_path, args.model)
    print('load the best model {}'.format(load_from))
    checkpoint = torch.load(load_from, map_location='cpu')
    print('load parameters')
    model2.load_state_dict(checkpoint['model'])
    with torch.no_grad():
        test.init_epoch()
        model1.eval()
        model2.eval()
        sents = []
        cy_true = []
        cy_pred = []
        for j, test_batch in enumerate(test):
            t1 = time.time()
            batch_src = test_batch.src
            src = [tokenizer.convert_tokens_to_ids(s) for s in batch_src]
            maxlen = max([len(s) for s in batch_src])

            src_mask = []
            padded_sents = []
            for s in src:
                new_s = s + [0] * (maxlen - len(s))
                padded_sents.append(new_s)
                mask = [1] * len(s) + [0] * (maxlen - len(s))
                src_mask.append(mask)
            # B x T
            src = torch.tensor(padded_sents).long().cuda()
            # B x T
            src_mask = torch.tensor(src_mask).byte().cuda()

            tgt = prepare_tgt(test_batch.tgt)
            tag = test_batch.tag.squeeze(-1)
            _, pre_tag = model2.component_extraction(src, src_mask)
            pre_ctag = model1.simile_classify(src, src_mask)
            cy_true.extend(tag.tolist())
            cy_pred.extend(pre_ctag.tolist())

            # On test, gating uses the PREDICTED class (pre_ctag): predicted
            # non-similes contribute an all-'O' extraction instead.
            for sen, tags, p_tags, c_tags in zip(src, tgt, pre_tag, pre_ctag):
                sen = sen[:len(p_tags)].tolist()
                tags = tags[:len(p_tags)].tolist()
                if c_tags == 1:
                    sents.append([
                        sen, [tgt_field.vocab.itos[t] for t in tags],
                        [tgt_field.vocab.itos[t] for t in p_tags]
                    ])
                elif c_tags == 0:
                    sents.append([
                        sen, [tgt_field.vocab.itos[t] for t in tags],
                        ['O' for t in p_tags]
                    ])

            print('test iters: {}, t:{}'.format(j, time.time() - t1))

        _, eprecision, erecall, ef1 = evaluate(sents)

        cprecision = precision_score(cy_true, cy_pred)
        crecall = recall_score(cy_true, cy_pred)
        cf1 = f1_score(cy_true, cy_pred)

        print('Testing classify--> precision: {} recall: {} f1: {}'.format(
            cprecision, crecall, cf1))
        print('extractor--> precision: {} recall: {} f1: {}'.format(
            eprecision, erecall, ef1))
Example #6
0
def main():    
    """CLI entry point for the scalar-coupling-constant prediction model.

    Modes (selected by flags):
      * default       -- train a self-attention (BERT-config) model,
                         validating every ``CFG.test_freq`` epochs and
                         checkpointing after each validation.
      * --eval        -- load a checkpoint, predict on the test set and write
                         ``output/submission_<seed>.csv``, then exit.
      * --gen_pseudo  -- (with --eval) write ``pseudo_<seed>.csv`` pseudo
                         labels instead of a submission, then exit.
      * --pseudo      -- append pseudo-labelled test rows (from
                         ``--pseudo_path``) to the training data.
    """
    parser = argparse.ArgumentParser("")
    parser.add_argument("--model", type=str, default='')    
    parser.add_argument("--resume", action='store_true')
    parser.add_argument("--eval", action='store_true')
    parser.add_argument("--batch_size", type=int, default=CFG.batch_size)
    parser.add_argument("--nepochs", type=int, default=CFG.num_train_epochs)    
    parser.add_argument("--wsteps", type=int, default=CFG.warmup_steps)
    parser.add_argument("--nlayers", type=int, default=CFG.num_hidden_layers)
    parser.add_argument("--nahs", type=int, default=CFG.num_attention_heads)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--lr", type=float, default=CFG.learning_rate)
    parser.add_argument("--dropout", type=float, default=CFG.dropout)
    parser.add_argument("--types", nargs='+', type=str, 
                        default=['1JHC', '1JHN', '2JHC', '2JHH', '2JHN', '3JHC', '3JHH', '3JHN'], 
                        help='3JHC,2JHC,1JHC,3JHH,2JHH,3JHN,2JHN,1JHN')
    parser.add_argument("--train_file", default="train_mute_cp")
    parser.add_argument("--test_file", default="test_mute_cp")
    parser.add_argument("--pseudo_path", default="")
    parser.add_argument("--pseudo", action='store_true')
    parser.add_argument("--gen_pseudo", action='store_true')
    parser.add_argument("--use_all", action='store_true')
    parser.add_argument("--structure_file", default="structures_mu")
    parser.add_argument("--contribution_file", default="scalar_coupling_contributions")        
    args = parser.parse_args()
    print(args) 
    
    # Copy CLI overrides back into the global CFG so the rest of the
    # pipeline reads a single source of truth.
    CFG.batch_size=args.batch_size
    CFG.num_train_epochs=args.nepochs
    CFG.warmup_steps=args.wsteps
    CFG.num_hidden_layers=args.nlayers
    CFG.num_attention_heads=args.nahs
    CFG.learning_rate=args.lr
    CFG.dropout=args.dropout
    CFG.seed =  args.seed
    print(CFG.__dict__)
    
    # Seed python/numpy/torch for reproducibility.
    random.seed(CFG.seed)
    np.random.seed(CFG.seed)
    torch.manual_seed(CFG.seed)
    
    # NOTE(review): the eval guard is disabled, so training data is loaded
    # and split even in --eval mode (presumably to keep the valid split
    # identical across runs -- confirm).
    #if not args.eval:    
    if True:
        train_df = load_csv(args.train_file)
        
        # Center each molecule's coordinates at its own mean.
        structures_df = load_csv(args.structure_file)  
        structures_df[['x', 'y', 'z']] -= structures_df.groupby('molecule_name')[['x', 'y', 'z']].transform('mean')        
        
        contributions_df = load_csv(args.contribution_file)
        train_df = train_df.merge(contributions_df, how='left')   
        train_df = normalize_cols(train_df, ['scalar_coupling_constant', 'fc', 'sd', 'pso', 'dso'])        
        train_df = add_extra_features(train_df, structures_df)
        train_df = train_df.fillna(1e08)
        n_mols = train_df['molecule_name'].nunique()
        train_df, valid_df = train_test_split(train_df, 5000 )
        
        # only molecules with the args.types
        print(train_df['molecule_name'].nunique())
        mol_names_with_at = train_df[train_df['type'].isin(args.types)]['molecule_name'].unique()
        train_df = train_df[train_df['molecule_name'].isin(mol_names_with_at)].reset_index(drop=True)
        print(train_df['molecule_name'].nunique())
        
        # Print the 5 rows of valid_df to verify whether the valid_df is the same as the previous experiment.
        print(valid_df.head(5))
        
        if args.pseudo:        
            # Graft pseudo labels onto the test rows (aligned by id) and
            # append them to the training set.
            test_df = load_csv(args.test_file)
            logger.info(f'loading dataset - {args.pseudo_path} ...')
            test_pseudo_df = pd.read_csv(args.pseudo_path)
            #mol_names_jhn = train_df[test_df['type'].isin(['1JHN', '2JHN', '3JHN'])]['molecule_name'].unique()
            #test_df = test_df[test_df['molecule_name'].isin(mol_names_jhn)].reset_index(drop=True)        
            test_df = add_extra_features(test_df, structures_df)
            test_df = test_df.set_index('id')
            test_pseudo_df = test_pseudo_df.set_index('id')
            test_df[['scalar_coupling_constant',  'fc', 'sd', 'pso', 'dso']] = test_pseudo_df[['scalar_coupling_constant',  'fc', 'sd', 'pso', 'dso']]
            test_df = test_df.reset_index()            
            #test_df = normalize_target(test_df)
            test_df = normalize_cols(test_df, ['scalar_coupling_constant', 'fc', 'sd', 'pso', 'dso'])
            #test_df = test_df.assign(fc=1e08, sd=1e08, pso=1e08, dso=1e08)
            train_df['weight'] = 1.0
            valid_df['weight'] = 1.0
            test_df['weight'] = 1.0
            n_mols = test_df['molecule_name'].nunique()            
            train_df = train_df.append(test_df).reset_index(drop=True)
        else:
            train_df['weight'] = 1.0
            valid_df['weight'] = 1.0
        
        # Optionally fold the validation split back into training.
        if args.use_all:
            train_df = train_df.append(valid_df) 
        
        print(f' n_train:{len(train_df)}, n_valid:{len(valid_df)}')
    
    config = BertConfig(            
            3, # not used
            hidden_size=CFG.hidden_size,
            num_hidden_layers=CFG.num_hidden_layers,
            num_attention_heads=CFG.num_attention_heads,
            intermediate_size=CFG.intermediate_size,
            hidden_dropout_prob=CFG.dropout,
            attention_probs_dropout_prob=CFG.dropout,
        )    
    model = cust_model.SelfAttn(config)
    if args.model != "":
        # Warm-start (and possibly resume) from an existing checkpoint.
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        CFG.start_epoch = checkpoint['epoch']        
        model.load_state_dict(checkpoint['state_dict'])        
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.model, checkpoint['epoch']))
    model.cuda()
    
    def count_parameters(model):
        # Number of trainable parameters.
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('parameters: ', count_parameters(model))
    
    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    
    # to produce the submission.csv
    if args.eval:
        test_df = load_csv(args.test_file)
        structures_df = load_csv(args.structure_file)
        structures_df[['x', 'y', 'z']] -= structures_df.groupby('molecule_name')[['x', 'y', 'z']].transform('mean')        
        test_df = add_extra_features(test_df, structures_df)
        # Sentinel 1e08 marks contributions as unknown for the test set.
        test_df = test_df.assign(fc=1e08, sd=1e08, pso=1e08, dso=1e08) 
        test_df['scalar_coupling_constant'] = 0
        test_df['weight'] = 1.0
        test_db = db.MolDB(test_df, CFG.max_seq_length)
        test_loader = DataLoader(
            test_db, batch_size=CFG.batch_size, shuffle=False,
            num_workers=CFG.num_workers)
        res_df = validate(test_loader, model, args.types)        
        res_df = unnormalize_cols(res_df, cols=['fc', 'sd', 'pso', 'dso'])
        res_df = unnormalize_target(res_df, 'prediction1')
        if args.gen_pseudo:
            # Emit pseudo labels for later --pseudo training runs.
            res_df['scalar_coupling_constant'] = res_df['prediction1']
            res_df = res_df[res_df['id']>-1].sort_values('id')
            res_df[['id', 'scalar_coupling_constant', 'fc', 'sd', 'pso', 'dso']].to_csv(f'pseudo_{CFG.seed}.csv', index=False)
            return
        # Final prediction = mean of the direct head (prediction1) and the
        # sum-of-contributions head (prediction4).
        res_df['prediction4']= res_df[['fc', 'sd', 'pso', 'dso']].sum(1)
        res_df['prediction']= res_df[['prediction1','prediction4']].mean(1)        
        res_df['scalar_coupling_constant'] = res_df['prediction']
        res_df = res_df[res_df['id']>-1].sort_values('id')
        os.makedirs('output', exist_ok=True)
        res_df[['id', 'scalar_coupling_constant']].to_csv(f'output/submission_{CFG.seed}.csv', index=False)        
        return
    
    train_db = db.MolDB(train_df, CFG.max_seq_length)    
    print('preloading dataset ...')
    train_db = db.MolDB_FromDB(train_db, 10)    
    valid_db = db.MolDB(valid_df, CFG.max_seq_length)    
    num_train_optimization_steps = int(
        len(train_db) / CFG.batch_size / CFG.gradient_accumulation_steps) * (CFG.num_train_epochs-CFG.start_epoch)
    print('num_train_optimization_steps', num_train_optimization_steps)      

    train_loader = DataLoader(
        train_db, batch_size=CFG.batch_size, shuffle=True,
        num_workers=CFG.num_workers, pin_memory=True)
    val_loader = DataLoader(
        valid_db, batch_size=CFG.batch_size, shuffle=False,
        num_workers=CFG.num_workers)
    
    # Prepare optimizer
    # Biases and LayerNorm parameters are excluded from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    
    optimizer = AdamW(optimizer_grouped_parameters,
                           lr=CFG.learning_rate,
                           weight_decay=CFG.weight_decay,                           
                           )
    scheduler = WarmupLinearSchedule(optimizer, CFG.warmup_steps,
                                        t_total=num_train_optimization_steps
                                     )
    
    def get_lr():
        # Current learning rate from the scheduler.
        return scheduler.get_lr()[0]
    
    if args.model != "":
        if args.resume:
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
        #for param_group in optimizer.param_groups:
        #    param_group['lr'] = CFG.learning_rate
        mae_log_df = checkpoint['mae_log']
        del checkpoint
    else:
        mae_log_df = pd.DataFrame(columns=(['EPOCH']+['LR']+args.types + ['OVERALL']) )     
    os.makedirs('log', exist_ok=True)
    
    # Baseline validation pass before any (further) training.
    res_df = validate(val_loader, model, args.types)        
    res_df = unnormalize_cols(res_df, cols=['scalar_coupling_constant', 'fc', 'sd', 'pso', 'dso'])
    res_df = unnormalize_target(res_df, 'prediction1')            
    res_df['prediction4']= res_df[['fc', 'sd', 'pso', 'dso']].sum(1)
    res_df['prediction']= res_df[['prediction1','prediction4']].mean(1)
    res_df.to_csv(f'log/valid_df_{"_".join(args.types)}.csv', index=False)
    overall_mae, maes = metric(res_df, args.types)
    print(overall_mae, maes)    
    
    
    curr_lr = get_lr()
    print(f'initial learning rate:{curr_lr}')
    for epoch in range(CFG.start_epoch, CFG.num_train_epochs):
        # train for one epoch
                
        #print(adjust_learning_rate(optimizer, epoch))    
        train(train_loader, model, optimizer, epoch, args.types, scheduler)
       
        if epoch % CFG.test_freq == 0:
            # Validate, log per-type MAEs, and checkpoint.
            res_df = validate(val_loader, model, args.types)        
            res_df = unnormalize_cols(res_df, cols=['scalar_coupling_constant', 'fc', 'sd', 'pso', 'dso'])
            res_df = unnormalize_target(res_df, 'prediction1')            
            res_df['prediction4']= res_df[['fc', 'sd', 'pso', 'dso']].sum(1)
            res_df['prediction']= res_df[['prediction1','prediction4']].mean(1)
            res_df.to_csv(f'log/valid_df_{"_".join(args.types)}.csv', index=False)
            overall_mae, maes = metric(res_df, args.types)
            
            # write log file
            mae_row = dict([(typ, [mae]) for typ, mae in maes.items() if typ in args.types])
            mae_row.update({'EPOCH':(epoch),'OVERALL':overall_mae, 'LR':curr_lr})
            mae_log_df = mae_log_df.append(pd.DataFrame(mae_row), sort=False)
            print(mae_log_df.tail(20))        
            mae_log_df.to_csv(f'log/{"_".join(args.types)}.csv', index=False)
            
            #scheduler.step(overall_mae)
            curr_lr = get_lr()
            print(f'set the learning_rate: {curr_lr}')
            
            # evaluate on validation set
            # Checkpoint filename encodes the hyper-parameters of this run.
            batch_size = CFG.batch_size            
            pseudo_path = '' if not args.pseudo else '_' + args.pseudo_path 
            curr_model_name = (f'b{batch_size}_l{config.num_hidden_layers}_'
                               f'mh{config.num_attention_heads}_h{config.hidden_size}_'
                               f'd{CFG.dropout}_'
                               f'ep{epoch}_{"_".join(args.types)}_s{CFG.seed}{pseudo_path}.pt')
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the cust_model it-self    
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': 'transformer',
                'state_dict': model_to_save.state_dict(),
                'mae_log': mae_log_df,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                },
                FINETUNED_MODEL_PATH, curr_model_name
            )                                                
                                         
    print('done')
Example #7
0
def main(parser):

    args = parser.parse_args()

    if args.fp16 == True:
        from apex import amp

    data_dir = Path(args.data_dir)
    model_dir = Path(args.model_dir)
    model_config = Config(json_path=model_dir / 'config.json')
    model_config.learning_rate = args.lr
    model_config.batch_size = args.batch_size

    model_config.vocab_size = len(vocab)
    print("vocabulary length: ", len(vocab))

    # Train & Val Datasets
    train_data_dir = "../data/NER-master/말뭉치 - 형태소_개체명"
    tr_ds = NamedEntityRecognitionDataset(train_data_dir=train_data_dir, vocab=vocab, \
                                          tokenizer=bert_tokenizer, maxlen=model_config.maxlen, model_dir=model_dir)
    tr_dl = DataLoader(tr_ds,
                       batch_size=model_config.batch_size,
                       shuffle=True,
                       num_workers=2,
                       drop_last=False)

    val_data_dir = "../data/NER-master/validation_set"
    val_ds = NamedEntityRecognitionDataset(train_data_dir=val_data_dir, vocab=vocab, \
                                           tokenizer=bert_tokenizer, maxlen=model_config.maxlen, model_dir=model_dir)
    val_dl = DataLoader(val_ds,
                        batch_size=model_config.batch_size,
                        shuffle=True,
                        num_workers=2,
                        drop_last=False)

    # Model
    model = BertMulti_CRF(config=model_config,
                          num_classes=len(tr_ds.ner_to_index),
                          vocab=vocab)
    #model = BertMulti_Only(config=model_config, num_classes=len(tr_ds.ner_to_index), vocab=vocab)
    #model = BiLSTM(config=model_config, num_classes=len(tr_ds.ner_to_index), vocab=vocab)
    #model = BiLSTM_CRF(config=model_config, num_classes=len(tr_ds.ner_to_index))
    model.train()

    # optim
    train_examples_len = len(tr_ds)
    val_examples_len = len(val_ds)
    print("num of train: {}, num of val: {}".format(train_examples_len,
                                                    val_examples_len))

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    # num_train_optimization_steps = int(train_examples_len / model_config.batch_size / model_config.gradient_accumulation_steps) * model_config.epochs
    t_total = len(
        tr_dl
    ) // model_config.gradient_accumulation_steps * model_config.epochs
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=model_config.learning_rate,
                      eps=model_config.adam_epsilon)
    #optimizer = torch.optim.Adam(model.parameters(), model_config.learning_rate)
    if args.lr_schedule:
        scheduler = WarmupLinearSchedule(
            optimizer, warmup_steps=model_config.warmup_steps, t_total=t_total)
        #lmbda = lambda epoch: 0.5
        #scheduler = LambdaLR(optimizer, lr_lambda=lmbda)

    #Create model output directory
    output_dir = os.path.join(
        model_dir,
        '{}-lr{}-bs{}'.format(model.name, model_config.learning_rate,
                              model_config.batch_size))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    #checkpoint_manager = CheckpointManager(model_dir)
    summary_manager = SummaryManager(output_dir)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    '''
    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    '''
    model.to(device)

    if args.fp16:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    if args.continue_train:
        revert_to_best(model, optimizer, output_dir)
        logging.info("==== continue training: %s ====", '{}-lr{}-bs{}' \
                    .format(model.name, model_config.learning_rate, model_config.batch_size))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(tr_ds))
    logger.info("  Num Epochs = %d", model_config.epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                model_config.batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                model_config.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    log_file = open('{}/log.tsv'.format(output_dir), 'at')
    print('{}\t{}\t{}\t{}\t{}\t{}\t{}'.format('epoch', 'train loss', 'eval_loss', 'eval global accuracy', \
                                              'micro_f1_score', 'macro_f1_score', 'learning_rate'), file=log_file)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    best_dev_acc, best_dev_loss = 0.0, 99999999999.0
    best_epoch = 0
    best_steps = 0
    patience = args.patience
    f_scores = []
    model.zero_grad()
    set_seed()
    criterion = nn.CrossEntropyLoss()

    train_begin = datetime.now()
    '''
    train_iterator = trange(int(model_config.epochs), desc="Epoch")  
    for _epoch, _ in enumerate(train_iterator):
    '''
    for _epoch in range(model_config.epochs):
        #epoch_iterator = tqdm(tr_dl, desc="Iteration")
        epoch_iterator = tr_dl
        epoch = _epoch

        for step, batch in enumerate(epoch_iterator):

            model.train()
            #print(batch)

            x_input, token_type_ids, y_real = map(lambda elm: elm.to(device),
                                                  batch)
            #print(x_input.size(), token_type_ids.size(), y_real.size()) #都是batch_size*max_len
            #print(y_real)
            if model.name == "BertMulti_Only":
                y_out = model(x_input, token_type_ids, y_real)
                y_out.requires_grad_()
                y_out.contiguous()
                y_real.contiguous()
                y_real_ = y_real.view(-1)
                y_out_ = y_out.view(-1, len(tr_ds.ner_to_index))
                loss = criterion(y_out_, y_real_)
                _, sequence_of_tags = F.softmax(y_out, dim=2).max(2)
            elif model.name == "BiLSTM":
                y_out = model(x_input, token_type_ids, y_real)
                y_out.requires_grad_()
                y_out.contiguous()
                y_real.contiguous()

                y_out1 = F.log_softmax(y_out, dim=2)
                y_out1 = y_out1.view(-1, len(tr_ds.ner_to_index))

                y_real_ = y_real.view(-1)
                mask = (y_real_ != 1).float()
                #print(len(mask))
                original_len = int(torch.sum(mask))
                #print(x_input[0], y_real[0], original_len, '\n')
                y_out1 = y_out1[range(y_out1.shape[0]), y_real_] * mask
                loss = -torch.sum(y_out1) / original_len

                _, sequence_of_tags = F.softmax(y_out, dim=2).max(2)
            else:
                log_likelihood, sequence_of_tags = model(
                    x_input, token_type_ids, y_real)
                loss = -1 * log_likelihood

            if model_config.gradient_accumulation_steps > 1:
                loss = loss / model_config.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           model_config.max_grad_norm)
            tr_loss += loss.item()

            if (step + 1) % model_config.gradient_accumulation_steps == 0:
                optimizer.step()
                if args.lr_schedule:
                    scheduler.step()  # Update learning rate schedule
                    #print(scheduler.state_dict())
                model.zero_grad()
                global_step += 1

                with torch.no_grad():
                    sequence_of_tags = torch.tensor(sequence_of_tags).to(
                        device)
                    #print(sequence_of_tags.size(), y_real.size())
                    mb_acc = (sequence_of_tags == y_real
                              ).float()[y_real != vocab['[PAD]']].mean()

                tr_acc = mb_acc.item()
                tr_loss_avg = tr_loss / global_step
                tr_summary = {'loss': tr_loss_avg, 'acc': tr_acc}

                if (step + 1) % 20 == 0:
                    logging.info('epoch : {}, global_step : {}, tr_loss: {:.3f}, tr_acc: {:.2%}' \
                                 .format(epoch + 1, global_step, tr_summary['loss'], tr_summary['acc']))

                # evaluation and save model
                if model_config.logging_steps > 0 and global_step % model_config.logging_steps == 0:

                    eval_summary = evaluate(model, val_dl)

                    f_scores.append(eval_summary['macro_f1_score'])

                    # Save model checkpoint
                    summary = {'train': tr_summary, 'eval': eval_summary}
                    summary_manager.update(summary)
                    summary_manager.save('summary.json')

                    # Save
                    is_best = eval_summary[
                        "macro_f1_score"] >= best_dev_acc  # acc 기준 (원래는 train_acc가 아니라 val_acc로 해야)
                    is_best_str = 'BEST' if is_best else '< {:.4f}'.format(
                        max(f_scores))
                    logging.info(
                        '[Los trn]  [Los dev]  [global acc]  [micro f1]  [macro f1]     [global step]    [LR]'
                    )
                    logging.info('{:8.2f}  {:9.2f}  {:9.2f}  {:11.4f}  {:9.4f} {:4}  {:9}  {:14.8f}' \
                                 .format((tr_loss - logging_loss) / model_config.logging_steps, eval_summary['eval_loss'], \
                                         eval_summary['eval_global_acc'], eval_summary['micro_f1_score'], \
                                         eval_summary['macro_f1_score'], is_best_str, global_step, model_config.learning_rate))
                    print('{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(epoch, tr_loss, \
                                                              eval_summary['eval_loss'], eval_summary['eval_global_acc'], \
                                                              eval_summary['micro_f1_score'], eval_summary['macro_f1_score'], \
                                                              model_config.learning_rate), file=log_file)
                    log_file.flush()

                    logging_loss = tr_loss

                    if is_best:
                        best_dev_acc = eval_summary["macro_f1_score"]
                        best_dev_loss = eval_summary["eval_loss"]
                        best_steps = global_step
                        best_epoch = epoch
                        #checkpoint_manager.save_checkpoint(state, 'best-epoch-{}-step-{}-acc-{:.3f}.bin'.format(epoch + 1, global_step, best_dev_acc))
                        #logging.info("Saving model checkpoint as best-epoch-{}-step-{}-acc-{:.3f}.bin".format(epoch + 1, global_step, best_dev_acc))
                        logging.info(
                            "Saving model at epoch{}, step{} in {}".format(
                                epoch, global_step, output_dir))
                        torch.save(model.state_dict(),
                                   '{}/model.state'.format(output_dir))
                        torch.save(optimizer.state_dict(),
                                   '{}/optim.state'.format(output_dir))
                        patience = args.patience

                    else:
                        revert_to_best(model, optimizer, output_dir)
                        patience -= 1
                        logging.info("==== revert to epoch[%d], step%d. F1 score: %.4f, patience: %d ====", \
                                     best_epoch, best_steps, max(f_scores), patience)

                        if patience == 0:
                            break

        else:

            continue

        break

    #print("global_step = {}, average loss = {}".format(global_step, tr_loss / global_step))

    train_end = datetime.now()
    train_elapsed = elapsed(train_end - train_begin)
    logging.info('==== training time elapsed: %s, epoch: %s ====',
                 train_elapsed, epoch)

    return global_step, tr_loss / global_step, best_steps
Example #8
0
class Trainer:
    """Training/evaluation driver for a sentence-pair (siamese) BERT classifier.

    Each batch carries two sentences and a label; both sentences are encoded
    with the same ``model``, the two embeddings are combined by
    ``model.get_logit`` and the result is scored with ``criterion``.
    Loss/accuracy curves are mirrored to TensorBoard and the best checkpoint
    (by validation accuracy) is kept under ``save_path``.
    """

    def __init__(self, args, config, model, criterion, train_dataloader,
                 valid_dataloader, logger, save_path, tb_writer):
        # Collaborators handed in by the caller.
        self.args = args
        self.config = config
        self.model = model
        self.criterion = criterion
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.logger = logger
        self.save_path = save_path
        self.tb_writer = tb_writer

        # Total optimizer iterations over the whole run; used by the LR
        # schedule and for progress logging.
        self.t_total = len(self.train_dataloader) * self.args.epoch
        self.device = self.config.device
        self.optimizer = AdamW(self.get_model_parameters(),
                               lr=self.config.learning_rate)
        # Linear warmup over the first 10% of steps, then linear decay.
        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              0.1 * self.t_total, self.t_total)

        self.global_step = 0
        # A checkpoint is only written once validation accuracy exceeds this.
        self.best_eval_acc = 0.2

    def get_model_parameters(self):
        """Build AdamW parameter groups: no weight decay for bias/LayerNorm."""
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.weight', 'LayerNorm.bias']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]

        return optimizer_grouped_parameters

    def train(self, do_eval=True, do_save=True):
        """Run the full loop: one train pass, one eval pass, one TensorBoard
        write and one checkpoint attempt per epoch.

        NOTE(review): ``do_eval``/``do_save`` are accepted but currently
        ignored — evaluation and saving always run; confirm whether they
        should gate the calls below.
        """
        for epoch in range(self.args.epoch):
            self.train_epoch(epoch)
            self.evaluation(epoch)
            self.write_to_tb()
            self.save_model(epoch)

        self.tb_writer.close()

    def transform_to_bert_input(self, batch):
        """Move one sentence's (ids, lengths, token types) onto ``self.device``.

        ``input_ids`` arrives as a numpy array, ``valid_length`` as a tensor
        and ``token_type_ids`` as an arbitrary sequence — TODO confirm against
        the dataloader.
        """
        input_ids, valid_length, token_type_ids = batch[0], batch[1], batch[2]

        input_ids = torch.from_numpy(input_ids).to(self.device)
        valid_length = valid_length.clone().detach().to(self.device)
        token_type_ids = torch.tensor(token_type_ids).long().to(self.device)

        return input_ids, valid_length, token_type_ids

    def compute_acc(self, y_hat, y, mean=True):
        """Accuracy of argmax predictions over the last dimension of ``y_hat``.

        Args:
            y_hat: class scores/logits; classes along the last dimension.
            y: integer target labels.
            mean: if True return mean accuracy (float tensor), otherwise the
                number of correct predictions (long tensor).
        """
        # [0] of max() is the max value, [1] is the argmax index.
        yhat = y_hat.max(dim=-1)[1]
        if mean:
            acc = (yhat == y).float().mean()
            return acc
        # Fixed: ``yhat`` was previously computed only inside the ``mean``
        # branch, so this path raised NameError.
        correct_count = (yhat == y).long().sum()
        return correct_count

    def train_epoch(self, epoch):
        """One pass over the training set with gradient accumulation."""
        self.model.to(self.device)
        self.model.train()

        tr_loss = 0.0
        # train_loader = tqdm(self.train_dataloader)
        train_loader = self.train_dataloader

        for step, batch in enumerate(train_loader):

            # NOTE(review): zeroing gradients every step defeats the
            # gradient-accumulation branch below (only the last batch's
            # gradients survive) — confirm intended behavior.
            self.model.zero_grad()

            # Encode both sentences with the shared encoder.
            sent1 = batch['sent1']
            input_1, valid_length_1, token_type_1 = self.transform_to_bert_input(
                sent1)
            embed1 = self.model(input_1, valid_length_1, token_type_1)

            sent2 = batch['sent2']
            input_2, valid_length_2, token_type_2 = self.transform_to_bert_input(
                sent2)
            embed2 = self.model(input_2, valid_length_2, token_type_2)

            label = batch['label']
            label = torch.tensor(label).long().to(self.device)

            pred = self.model.get_logit(embed1, embed2)
            loss = self.criterion(pred, label.view(-1))

            tr_loss += loss.item()
            loss.backward()

            # Apply the accumulated gradients every N steps.
            if step > 0 and (
                    step) % self.config.gradient_accumulation_steps == 0:
                self.global_step += self.config.gradient_accumulation_steps

                self.optimizer.step()
                self.optimizer.zero_grad()
                self.scheduler.step()

                with torch.no_grad():
                    accuracy = self.compute_acc(pred, label)

                # NOTE(review): these attributes are only set once this branch
                # runs; write_to_tb()/save_model() fail if an epoch has fewer
                # than gradient_accumulation_steps batches.
                self.tr_acc = accuracy.item()
                self.tr_avg_loss = tr_loss / self.global_step

                if self.global_step % 100 == 0:  #int(len(self.train_dataloader)/5) ==0:

                    self.logger.info(
                        'epoch : {} /{}, global_step : {} /{}, tr_avg_loss: {:.3f}, tr_acc: {:.2%}'
                        .format(epoch + 1, self.args.epoch, self.global_step,
                                self.t_total, self.tr_avg_loss, self.tr_acc))

    def evaluation(self, epoch):
        """Evaluate on the validation set; stores eval_avg_loss/eval_avg_acc."""
        self.model.eval()
        eval_loss = 0.0

        eval_acc = 0.0
        eval_step = 1

        self.logger.info('*****************Evaluation*****************')
        valid_loader = tqdm(self.valid_dataloader)
        for step, batch in enumerate(valid_loader):
            with torch.no_grad():

                sent1 = batch['sent1']
                input_1, valid_length_1, token_type_1 = self.transform_to_bert_input(
                    sent1)
                embed1 = self.model(input_1, valid_length_1, token_type_1)

                sent2 = batch['sent2']
                input_2, valid_length_2, token_type_2 = self.transform_to_bert_input(
                    sent2)
                embed2 = self.model(input_2, valid_length_2, token_type_2)

                label = batch['label']
                label = torch.tensor(label).long().to(self.device)
                pred = self.model.get_logit(embed1, embed2)

            loss = self.criterion(pred, label.view(-1))
            eval_loss += loss.item()

            acc = self.compute_acc(pred, label)
            eval_acc += acc.item()
            eval_step += 1.0

        # NOTE(review): eval_step starts at 1 and increments per batch, so the
        # averages divide by (num_batches + 1) — confirm intent.
        self.eval_avg_loss = eval_loss / eval_step
        self.eval_avg_acc = eval_acc / eval_step

        self.logger.info(
            'epoch : {} /{}, global_step : {} /{}, eval_loss: {:.3f}, eval_acc: {:.2%}'
            .format(epoch + 1, self.args.epoch, self.global_step, self.t_total,
                    self.eval_avg_loss, self.eval_avg_acc))

    def save_model(self, epoch):
        """Checkpoint model + optimizer when validation accuracy improves."""
        if self.eval_avg_acc > self.best_eval_acc:
            self.best_eval_acc = self.eval_avg_acc

            # Snapshot on CPU so the checkpoint loads on any device.
            self.model.to(torch.device('cpu'))
            state = {
                'epoch': epoch + 1,
                'model_state_dict': self.model.state_dict(),
                'opt_state_dict': self.optimizer.state_dict()
            }

            save_model_path = '{}/epoch_{}_step_{}_tr_acc_{:.3f}_tr_loss_{:.3f}_eval_acc_{:.3f}_eval_loss_{:.3f}.pt'.format(
                self.save_path, epoch + 1, self.global_step, self.tr_acc,
                self.tr_avg_loss, self.eval_avg_acc, self.eval_avg_loss)

            # Delete the previous checkpoint so only the best one is kept.
            if len(glob.glob(self.save_path + '/epoch*.pt')) > 0:
                os.remove(glob.glob(self.save_path + '/epoch*.pt')[0])
            torch.save(state, save_model_path)
            self.logger.info(' Model saved to {}'.format(save_model_path))

            # Empty marker directory recording the metrics of this checkpoint.
            os.mkdir(self.save_path +
                     '/epoch_{}_eval_loss_{:.3f}_eval_acc_{:.3f}'.format(
                         epoch + 1, self.eval_avg_loss, self.eval_avg_acc))

    def write_to_tb(self):
        """Push the latest train/val loss and accuracy pairs to TensorBoard."""
        self.tb_writer.add_scalars('loss', {
            'train': self.tr_avg_loss,
            'val': self.eval_avg_loss
        }, self.global_step)
        self.tb_writer.add_scalars('acc', {
            'train': self.tr_acc,
            'val': self.eval_avg_acc
        }, self.global_step)
Example #9
0
def train(args, train_dataset, val_dataset, model, tokenizer):
    """Train the adapter model on top of a frozen pretrained transformer.

    ``model`` is a pair ``(pretrained_model, adapter_model)``: only the
    adapter's parameters are optimized, while the pretrained model is kept in
    eval mode and used as a frozen feature extractor.  Supports gradient
    accumulation, fp16 via apex, DataParallel / DistributedDataParallel,
    resuming from a checkpoint, TensorBoard logging, periodic checkpointing
    (with pruning of old checkpoints) and — on single GPU only — periodic
    evaluation.

    Returns:
        (global_step, average training loss per optimization step)
    """
    pretrained_model = model[0]
    adapter_model = model[1]

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size // args.gradient_accumulation_steps)

    # Derive the number of optimization steps (t_total) from whichever of
    # max_steps / num_train_epochs the caller fixed.
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    # Only the adapter's parameters are optimized; bias/LayerNorm get no decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in adapter_model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in adapter_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        adapter_model, optimizer = amp.initialize(adapter_model, optimizer, opt_level=args.fp16_opt_level)
    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        pretrained_model = torch.nn.DataParallel(pretrained_model)
        adapter_model = torch.nn.DataParallel(adapter_model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        pretrained_model = torch.nn.parallel.DistributedDataParallel(pretrained_model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)
        adapter_model = torch.nn.parallel.DistributedDataParallel(adapter_model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num train examples = %d", len(train_dataset)) #logging.info(f"  Num train_examples = {len(train_examples)}")
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    # Resume optimizer/scheduler/adapter weights and the step counters from
    # the last saved checkpoint, if requested and present.
    logger.info("Try resume from checkpoint")
    if args.restore:
        if os.path.exists(os.path.join(args.output_dir, 'global_step.bin')):
            logger.info("Load last checkpoint data")
            global_step = torch.load(os.path.join(args.output_dir, 'global_step.bin'))
            output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
            logger.info("Load from output_dir {}".format(output_dir))

            optimizer.load_state_dict(torch.load(os.path.join(output_dir, 'optimizer.bin')))
            scheduler.load_state_dict(torch.load(os.path.join(output_dir, 'scheduler.bin')))
            # args = torch.load(os.path.join(output_dir, 'training_args.bin'))
            # 'module' attribute means the model is wrapped (DataParallel/DDP).
            if hasattr(adapter_model, 'module'):
                adapter_model.module.load_state_dict(torch.load(os.path.join(output_dir, 'pytorch_model.bin')))
            else:  # Take care of distributed/parallel training
                adapter_model.load_state_dict(torch.load(os.path.join(output_dir, 'pytorch_model.bin')))

            global_step += 1
            # NOTE(review): start_step mixes optimization steps (global_step)
            # with dataloader iterations, and the trailing "-1" looks like a
            # possible off-by-one — confirm the resume position.
            start_epoch = int(global_step / len(train_dataloader))
            start_step = global_step-start_epoch*len(train_dataloader)-1
            logger.info("Start from global_step={} epoch={} step={}".format(global_step, start_epoch, start_step))
            if args.local_rank in [-1, 0]:
                # purge_step drops TB events logged after the restored step.
                tb_writer = SummaryWriter(log_dir="runs/" + args.my_model_name, purge_step=global_step)

        else:
            global_step = 0
            start_epoch = 0
            start_step = 0
            if args.local_rank in [-1, 0]:
                tb_writer = SummaryWriter(log_dir="runs/" + args.my_model_name, purge_step=global_step)
            logger.info("Start from scratch")
    else:
        global_step = 0
        start_epoch = 0
        start_step = 0
        if args.local_rank in [-1, 0]:
            tb_writer = SummaryWriter(log_dir="runs/" + args.my_model_name, purge_step=global_step)
        logger.info("Start from scratch")

    tr_loss, logging_loss = 0.0, 0.0
    pretrained_model.zero_grad()
    adapter_model.zero_grad()

    set_seed(args)  # Added here for reproductibility (even between python 2 and 3)

    for epoch in range(start_epoch, int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            start = time.time()
            # Skip batches already consumed before the checkpoint was written.
            # NOTE(review): this condition holds in *every* epoch after a
            # restore, not just the resumed one — confirm intent.
            if args.restore and (step < start_step):
                continue
            # if args.restore and (flag_count < global_step):
            #     flag_count+=1
            #     continue
            # Pretrained backbone stays frozen in eval mode; only the adapter trains.
            pretrained_model.eval()
            adapter_model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids':      batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM and RoBERTa don't use segment_ids
                      'labels':         batch[3]}
            pretrained_model_outputs = pretrained_model(**inputs)
            outputs = adapter_model(pretrained_model_outputs,**inputs)

            loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            # epoch_iterator.set_description("loss {}".format(loss))
            logger.info("Epoch {}/{} - Iter {} / {}, loss = {:.5f}, time used = {:.3f}s".format(epoch, int(args.num_train_epochs),step,
                                                                                             len(train_dataloader),
                                                                                             loss.item(),
                                                                                             time.time() - start))
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(adapter_model.parameters(), args.max_grad_norm)


            tr_loss += loss.item()

            # One optimizer update per gradient_accumulation_steps batches.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # NOTE(review): scheduler.step() is called before
                # optimizer.step(); for torch >= 1.1 LR schedulers the
                # recommended order is optimizer first — confirm this matches
                # the pytorch-transformers schedule in use.
                scheduler.step()  # Update learning rate schedule
                optimizer.step()
                pretrained_model.zero_grad()
                adapter_model.zero_grad()
                global_step += 1
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)

                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = adapter_model.module if hasattr(adapter_model,
                                                            'module') else adapter_model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)  # save to pytorch_model.bin  model.state_dict()

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.bin'))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.bin'))
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    torch.save(global_step, os.path.join(args.output_dir, 'global_step.bin'))

                    logger.info("Saving model checkpoint, optimizer, global_step to %s", output_dir)
                    # Prune the oldest checkpoint once more than
                    # max_save_checkpoints are on disk.
                    if (global_step/args.save_steps) > args.max_save_checkpoints:
                        try:
                            shutil.rmtree(os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step-args.max_save_checkpoints*args.save_steps)))
                        except OSError as e:
                            print(e)
                if args.local_rank == -1 and args.evaluate_during_training and global_step %args.eval_steps== 0:  # Only evaluate when single GPU otherwise metrics may not average well
                    model = (pretrained_model, adapter_model)
                    results = evaluate(args, val_dataset, model, tokenizer)
                    for key, value in results.items():
                        tb_writer.add_scalar('eval_{}'.format(key), value, global_step)

            if args.max_steps > 0 and global_step > args.max_steps:
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #10
0
class SmallTalk:
    def __init__(self, name, model_name, model_type='gpt2', opt_level=None, lr=6.25e-5, lm_coef=1.0, mc_coef=1.0, gradient_accumulation_steps=8, max_norm=1.0, device='cuda:0'):
        """Build the conversational model: load tokenizer + double-heads model,
        set up the optimizer (optionally apex AMP) and the ignite engines."""
        # Training hyper-parameters and runtime configuration.
        self.lr = lr
        self.lm_coef = lm_coef
        self.mc_coef = mc_coef
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_norm = max_norm
        self.device = device

        # Run identity and pretrained checkpoint selection.
        self.name = name
        self.model_name = model_name
        self.model_type = model_type
        self.opt_level = opt_level

        self.logger, self.tb_logger, self.checkpoint_handler = stu.setup_training_loggers(self.name)

        self.verbose = False
        self.epoch = 0

        # TODO: Add logger statement here
        if self.model_type == 'gpt2':
            model_cls, tokenizer_cls = GPT2DoubleHeadsModel, GPT2Tokenizer
        else:
            model_cls, tokenizer_cls = OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer
        self.model = model_cls.from_pretrained(self.model_name).to(self.device)
        self.tokenizer = tokenizer_cls.from_pretrained(self.model_name)

        stu.add_special_tokens_(model=self.model, tokenizer=self.tokenizer)

        self.optimizer = AdamW(self.model.parameters(), lr=self.lr, correct_bias=True)

        # Wrap model/optimizer with apex AMP when an opt level was requested.
        if self.opt_level:
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=self.opt_level)

        self.trainer = Engine(self.update)
        self.evaluator = Engine(self.inference)

    def update(self, engine, batch):
        """One training step for the ignite trainer: forward both heads,
        backprop the combined loss, clip gradients, and apply the optimizer
        every ``gradient_accumulation_steps`` iterations.  Returns the scalar
        (accumulation-scaled) loss."""
        self.model.train()
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = (
            t.to(self.device) for t in batch)
        lm_loss, mc_loss, *_ = self.model(
            input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
            mc_labels=mc_labels, lm_labels=lm_labels
        )
        # Weighted sum of the two heads, scaled for gradient accumulation.
        loss = (lm_loss * self.lm_coef + mc_loss * self.mc_coef) / self.gradient_accumulation_steps

        if self.opt_level:
            # apex AMP path: scale the loss, then clip the master gradients.
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)

        # Step the optimizer only once every accumulation window.
        if engine.state.iteration % self.gradient_accumulation_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
        return loss.item()

    def inference(self, engine, batch):
        """Validation step for the ignite evaluator: forward without labels
        and return shifted LM logits/labels plus the multiple-choice outputs
        in the ``((y_pred...), (y...))`` shape ignite metrics expect."""
        self.model.eval()
        with torch.no_grad():
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = (
                t.to(self.device) for t in batch)

            if self.verbose:
                self.logger.info(self.tokenizer.decode(input_ids[0, -1, :].tolist()))

            # if we dont send labels to model, it doesnt return losses
            lm_logits, mc_logits, *_ = self.model(
                input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
            )

            # Shift so tokens < n predict token n, then flatten for the loss.
            flat_lm_logits = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            flat_lm_labels = lm_labels[..., 1:].contiguous().view(-1)

            return (flat_lm_logits, mc_logits), (flat_lm_labels, mc_labels)

    def train_model(self, n_epochs, train_loader, val_loader, eval_before_start=True):
        """Wire up the ignite engines and run training for ``n_epochs``.

        Attaches: per-epoch evaluation and epoch-counter update, a linear LR
        decay schedule, running-loss / NLL / accuracy / perplexity metrics,
        progress bars, TensorBoard logging, and per-epoch checkpointing.
        After training, the final checkpoint is renamed to the standard
        weights filename so it can be re-loaded with ``from_pretrained``.
        """
        # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self.evaluator.run(val_loader))
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self.update_epoch())
        if eval_before_start:
            self.trainer.add_event_handler(Events.STARTED, lambda _: self.evaluator.run(val_loader))

        # Linearly decrease the learning rate from lr to zero
        scheduler = PiecewiseLinear(self.optimizer, "lr", [(0, self.lr), (n_epochs * len(train_loader), 0.0)])
        self.trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        # Prepare metrics: running loss on the trainer; NLL, accuracy and
        # perplexity (exp of NLL) on the evaluator.
        RunningAverage(output_transform=lambda x: x).attach(self.trainer, "loss")
        metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
                   "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
        metrics["average_ppl"] = MetricsLambda(math.exp, metrics["nll"])
        for name, metric in metrics.items():
            metric.attach(self.evaluator, name)

        # On the main process: add progress bar, tensorboard, checkpoints and save model
        pbar = ProgressBar(persist=True)
        pbar.attach(self.trainer, metric_names=["loss"])

        if not self.verbose:
            pbar_eval = ProgressBar(persist=False)
            pbar_eval.attach(self.evaluator)

        self.evaluator.add_event_handler(Events.STARTED, lambda _: self.logger.info(f'Beginning validation for epoch {self.epoch}...'))
        self.evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(self.evaluator.state.metrics)))

        self.tb_logger.attach(self.trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        self.tb_logger.attach(self.trainer, log_handler=OptimizerParamsHandler(self.optimizer), event_name=Events.ITERATION_STARTED)
        self.tb_logger.attach(self.evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=self.trainer),
                              event_name=Events.EPOCH_COMPLETED)

        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.checkpoint_handler,
                                       {'mymodel': getattr(self.model, 'module', self.model)})  # "getattr" takes care of distributed encapsulation

        # Run the training
        self.trainer.run(train_loader, max_epochs=n_epochs)

        # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
        if n_epochs > 0:
            os.rename(self.checkpoint_handler._saved[-1][1][-1], os.path.join(cfg.checkpoint_log_folder, self.name, WEIGHTS_NAME))
            self.tb_logger.close()

    def save(self, path, inference_only=False):
        """ Saves important components of model to be imported later. """
        save_dict = {
            'epoch': self.epoch,
            'model_state_dict': self.model.state_dict(),
            'model_name': self.model_name,
            'model_type': self.model_type,
            'opt_level': self.opt_level
        }

        if not inference_only:
            save_dict['optimizer_state_dict'] = self.optimizer.state_dict()

        torch.save(save_dict, path)

    def load(self, path):
        """ Loads important components of model back into memory to pick up where we left off.

        Restores model weights, the epoch counter, and (when present) the
        optimizer state saved by :meth:`save`, after verifying the checkpoint
        was produced by a compatible configuration.

        :param path: Path to a checkpoint file written by ``save``.
        :raises AssertionError: if the checkpoint's model type, name or
            opt_level do not match this instance.
        """
        # map_location lets a checkpoint written on a GPU machine be loaded on
        # CPU (and vice versa), consistent with load_checkpoint below. See
        # https://github.com/NVIDIA/apex/issues/242 for the fp16 caveats.
        checkpoint = torch.load(path, map_location=self.device)
        assert self.model_type == checkpoint['model_type'], f"Model types do not match, current model is {self.model_type} and loaded model is {checkpoint['model_type']}"
        assert self.model_name == checkpoint['model_name'], f"Model names do not match, current model is {self.model_name} and loaded model is {checkpoint['model_name']}"
        assert self.opt_level == checkpoint['opt_level'], f"Model opt_levels do not match, current model is {self.opt_level} and loaded model is {checkpoint['opt_level']}"

        self.model.load_state_dict(checkpoint['model_state_dict'])
        if 'optimizer_state_dict' in checkpoint:
            self.logger.info('Optimizer information saved for continued training. Loading into model.')
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            self.logger.info('Model previously saved for inference only.')

        self.epoch = checkpoint['epoch']

    def load_checkpoint(self, path):
        """Overwrite the current model weights with a full checkpoint from *path*."""
        state = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state)

    def update_epoch(self):
        """Advance the internal epoch counter by one."""
        self.epoch = self.epoch + 1

    def get_num_params(self, trainable_only=True):
        """Return the number of parameters in the wrapped model.

        :param trainable_only: when True (default), count only parameters
            with ``requires_grad`` set.
        """
        params = self.model.parameters()
        if trainable_only:
            params = (p for p in params if p.requires_grad)
        return sum(p.numel() for p in params)

    def interact(self, personality=None, max_history=2, max_length=20, min_length=1, temperature=0.7, top_k=0, top_p=0.9, no_sample=False, random_pause=None):
        """
        Run an interactive console chat with the bot (blocks on stdin).
        :param personality: Personality to condition the model on for chat. None pulls a random one from the training data set. List of several short sentences describing personality.
        :param max_history: Number of responses per speaker kept as context (in addition to the utterance being answered).
        :param max_length: Maximum length of output utterances
        :param min_length: Minimum length of output utterances
        :param temperature: Sampling softmax temperature. 1.0 is standard softmax; lower values sharpen the distribution and reduce output diversity.
        :param top_k: Filter top_k tokens before sampling (<=0 disables filtering)
        :param top_p: Nucleus filtering
        :param no_sample: If True, greedily pick the most likely token at each step instead of sampling
        :param random_pause: Optional (low, high) tuple; sleep a random time in that range before replying to seem more human.
        """
        personality = self.get_personality(personality=personality)
        self.logger.info(self.tokenizer.decode(list(chain(*personality))))

        self.model.eval()
        history = []

        self.logger.info('You may now begin talking to the bot. Don\'t be shy, say hello!')

        while True:
            # Prompt repeatedly until the user types something non-empty.
            user_text = input('>>> ')
            while not user_text:
                print('Please enter in a non-empty value.')
                user_text = input('>>> ')
            history.append(self.tokenizer.encode(user_text))

            if random_pause:
                assert len(random_pause) == 2, 'random_pause arg should be a tuple of length 2 if passed'
                low, high = random_pause
                time.sleep(low + random.random() * (high - low))

            with torch.no_grad():
                reply_ids = stu.sample_sequence(personality=personality, history=history, tokenizer=self.tokenizer, model=self.model, device=self.device,
                                                max_length=max_length, min_length=min_length, temperature=temperature, top_k=top_k, top_p=top_p, no_sample=no_sample)
            # Keep the bot's reply in context and trim to the history window.
            history.append(reply_ids)
            history = history[-(2 * max_history + 1):]
            print(self.tokenizer.decode(reply_ids, skip_special_tokens=True))

    def get_reply(self, conversation_history, personality, max_history=2, max_length=20, min_length=1, temperature=0.7, top_k=0, top_p=0.9, no_sample=False, random_pause=None):
        """
        Produce a single bot reply for use by the chatbot service.
        Mirrors :meth:`interact` (see its docstring for parameter details), but takes a ConversationHistory object to assemble context and returns one decoded
        reply instead of managing an entire console conversation.
        """
        self.model.eval()

        # Encode the latest exchanges as the model's conversational context.
        recent_messages = conversation_history.get_list_of_conversation_latest_n_exchanges(n=max_history)
        encoded_history = [self.tokenizer.encode(msg) for msg in recent_messages]

        # Generate reply token ids from the model.
        with torch.no_grad():
            reply_ids = stu.sample_sequence(personality=personality, history=encoded_history, tokenizer=self.tokenizer, model=self.model, device=self.device,
                                            max_length=max_length, min_length=min_length, temperature=temperature, top_k=top_k, top_p=top_p, no_sample=no_sample)

        return self.tokenizer.decode(reply_ids, skip_special_tokens=True)

    def get_personality(self, personality=None):
        """
        Return a tokenized personality for the model.
        :param personality: List of short raw-text sentences; when None a random personality is pulled instead.
        """
        if personality is None:
            return stu.get_random_personality(self)
        return [self.tokenizer.encode(sent) for sent in personality]

    def print_personality(self, personality):
        """Print the decoded personality sentences as a single line of text."""
        flattened_ids = chain(*personality)
        print(self.tokenizer.decode(flattened_ids))
def main():
    """Entry point: parse CLI arguments, build the seq2seq model, optimizer
    and LR schedule, optionally recover from a previous checkpoint, then run
    training with per-epoch checkpointing and (optionally) dev-set evaluation.
    """
    my_parser = argparse.ArgumentParser()

    # Required parameters
    my_parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    my_parser.add_argument("--src_file",
                           default=None,
                           type=str,
                           help="The input data file name.")
    my_parser.add_argument("--model_type",
                           default=None,
                           type=str,
                           required=True,
                           help="Model type selected in the list: " +
                           ", ".join(MODEL_CLASSES.keys()))
    my_parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    my_parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    my_parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        help="The output directory where the log will be written.")
    my_parser.add_argument("--model_recover_path",
                           default=None,
                           type=str,
                           help="The file of fine-tuned pretraining model.")
    my_parser.add_argument("--optim_recover_path",
                           default=None,
                           type=str,
                           help="The file of pretraining optimizer.")
    my_parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    my_parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")

    # Other parameters
    my_parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    my_parser.add_argument('--max_position_embeddings',
                           type=int,
                           default=None,
                           help="max position embeddings")
    my_parser.add_argument("--do_train",
                           action='store_true',
                           help="Whether to run training.")
    my_parser.add_argument("--do_eval",
                           action='store_true',
                           help="Whether to run eval on the dev set.")
    my_parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    my_parser.add_argument("--train_batch_size",
                           default=32,
                           type=int,
                           help="Total batch size for training.")
    my_parser.add_argument("--eval_batch_size",
                           default=64,
                           type=int,
                           help="Total batch size for eval.")
    my_parser.add_argument("--learning_rate",
                           default=5e-5,
                           type=float,
                           help="The initial learning rate for Adam.")
    # NOTE: help text was previously a copy-paste of the learning-rate help.
    my_parser.add_argument("--label_smoothing",
                           default=0.1,
                           type=float,
                           help="Label smoothing value for the loss.")
    my_parser.add_argument("--weight_decay",
                           default=0.01,
                           type=float,
                           help="The weight decay rate for Adam.")
    my_parser.add_argument("--adam_epsilon",
                           default=1e-8,
                           type=float,
                           help="Epsilon for Adam optimizer.")
    my_parser.add_argument("--max_grad_norm",
                           default=1.0,
                           type=float,
                           help="Max gradient norm.")
    my_parser.add_argument("--num_train_epochs",
                           default=3.0,
                           type=float,
                           help="Total number of training epochs to perform.")
    my_parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    my_parser.add_argument("--hidden_dropout_prob",
                           default=0.1,
                           type=float,
                           help="Dropout rate for hidden states.")
    my_parser.add_argument("--attention_probs_dropout_prob",
                           default=0.1,
                           type=float,
                           help="Dropout rate for attention probabilities.")
    my_parser.add_argument("--no_cuda",
                           action='store_true',
                           help="Whether not to use CUDA when available")
    my_parser.add_argument("--local_rank",
                           type=int,
                           default=-1,
                           help="local_rank for distributed training on gpus")
    my_parser.add_argument('--seed',
                           type=int,
                           default=42,
                           help="random seed for initialization")
    my_parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    my_parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    my_parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    my_parser.add_argument('--tokenized_input',
                           action='store_true',
                           help="Whether the input is tokenized.")
    my_parser.add_argument(
        '--max_len_a',
        type=int,
        default=0,
        help="Truncate_config: maximum length of segment A.")
    my_parser.add_argument(
        '--max_len_b',
        type=int,
        default=0,
        help="Truncate_config: maximum length of segment B.")
    my_parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    my_parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    my_parser.add_argument(
        "--mask_prob",
        default=0.20,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    my_parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    my_parser.add_argument('--max_pred',
                           type=int,
                           default=69,
                           help="Max tokens of prediction.")
    my_parser.add_argument("--num_workers",
                           default=0,
                           type=int,
                           help="Number of workers for the data loader.")

    my_parser.add_argument('--mask_source_words',
                           action='store_true',
                           help="Whether to mask source words for training")
    my_parser.add_argument('--skipgram_prb',
                           type=float,
                           default=0.0,
                           help='prob of ngram mask')
    my_parser.add_argument('--skipgram_size',
                           type=int,
                           default=1,
                           help='the max size of ngram mask')
    my_parser.add_argument('--mask_whole_word',
                           action='store_true',
                           help="Whether masking a whole word.")

    args = my_parser.parse_args()

    # Treat a missing recovery file the same as no recovery path at all.
    if not (args.model_recover_path
            and Path(args.model_recover_path).exists()):
        args.model_recover_path = None

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
    # Persist the full run configuration alongside the outputs.
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    # Device setup: single-process (possibly multi-GPU) vs distributed.
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    my_logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # Per-step micro-batch size; the effective batch size is restored by
    # accumulating gradients over gradient_accumulation_steps steps.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case)
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [
            utils_seq2seq.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                # Was hard-coded to False, silently ignoring --mask_source_words.
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                tokenizer=data_tokenizer)
        ]

        file = os.path.join(args.data_dir,
                            args.src_file if args.src_file else 'train.tgt')
        train_dataset = utils_seq2seq.Seq2SeqDataset(
            file,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=utils_seq2seq.batch_list_to_batch_tensors,
            pin_memory=False)
        print("Loading dev dataset")
        dev_file = os.path.join(args.data_dir, 'dev_data.json')
        dev_dataset = utils_seq2seq.Seq2SeqDataset(
            dev_file,
            args.eval_batch_size,
            data_tokenizer,
            args.max_seq_length,
            bi_uni_pipeline=bi_uni_pipeline)
        dev_dataloader = torch.utils.data.DataLoader(
            dev_dataset,
            batch_size=args.eval_batch_size,
            collate_fn=utils_seq2seq.batch_list_to_batch_tensors,
            pin_memory=False,
            num_workers=args.num_workers)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    if args.do_train:
        t_total = int(
            len(train_dataloader) * args.num_train_epochs /
            args.gradient_accumulation_steps)
    else:
        # Eval-only runs never step the scheduler; a positive placeholder keeps
        # the scheduler construction below valid. (Previously this path crashed
        # with a NameError because train_dataloader was never created.)
        t_total = 1

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    global_step = 0
    if (recover_step is None) and (args.model_recover_path is None):
        model_recover = None
    else:
        if recover_step:
            my_logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            my_logger.info("***** Recover model: %s *****",
                           args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
    model = model_class.from_pretrained(args.model_name_or_path,
                                        state_dict=model_recover,
                                        config=config)
    if args.local_rank == 0:
        dist.barrier()

    model.to(device)

    # Prepare optimizer: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        # Was hard-coded to 0.01, silently ignoring --weight_decay
        # (the default is 0.01, so default behaviour is unchanged).
        'weight_decay': args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(args.warmup_proportion * t_total),
        num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # Wrap for distributed (DDP) or single-process multi-GPU (DataParallel).
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Recover optimizer / amp / scheduler state matching the recovered model.
    if recover_step:
        my_logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)

        if os.path.exists(
                os.path.join(args.output_dir,
                             "amp.{0}.bin".format(recover_step))):
            my_logger.info("***** Recover amp: %d *****", recover_step)
            amp_recover = torch.load(os.path.join(
                args.output_dir, "amp.{0}.bin".format(recover_step)),
                                     map_location='cpu')
            amp.load_state_dict(amp_recover)

        my_logger.info("***** Recover scheduler: %d *****", recover_step)
        scheduler_recover = torch.load(os.path.join(
            args.output_dir, "sched.{0}.bin".format(recover_step)),
                                       map_location='cpu')
        scheduler.load_state_dict(scheduler_recover)

    my_logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        my_logger.info("***** Running training *****")
        my_logger.info("  Batch size = %d", args.train_batch_size)
        my_logger.info("  Num steps = %d", t_total)

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            final_loss = 0
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                input_ids, segment_ids, answer_tag, input_mask, lm_label_ids, masked_pos, masked_weights, _ = batch
                # `is None` (identity), not `== None`: answer_tag may be a
                # tensor, and == would be an elementwise comparison.
                if answer_tag is None:
                    print("answer tag is none")
                masked_lm_loss = model(input_ids,
                                       segment_ids,
                                       answer_tag,
                                       input_mask,
                                       lm_label_ids,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights)
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    masked_lm_loss = masked_lm_loss.mean()
                loss = masked_lm_loss
                final_loss = loss.item()

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())

                # ensure that accumlated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    optimizer.zero_grad()
                    global_step += 1
            # Save a trained model
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                my_logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self
                output_model_file = os.path.join(
                    args.output_dir, "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)
                if args.fp16:
                    output_amp_file = os.path.join(
                        args.output_dir, "amp.{0}.bin".format(i_epoch))
                    torch.save(amp.state_dict(), output_amp_file)
                output_sched_file = os.path.join(
                    args.output_dir, "sched.{0}.bin".format(i_epoch))
                torch.save(scheduler.state_dict(), output_sched_file)

                my_logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()

            if args.do_eval:
                # Per-epoch validation loss on the dev set.
                model.eval()  # disable dropout for a deterministic dev loss
                iter_dev = tqdm(dev_dataloader,
                                desc='Iter (loss=X.XXX)',
                                disable=args.local_rank not in (-1, 0))
                val_losses = []
                for step, batch in enumerate(iter_dev):
                    with torch.no_grad():
                        batch = [
                            t.to(device) if t is not None else None
                            for t in batch
                        ]
                        input_ids, segment_ids, answer_tag, input_mask, lm_label_ids, masked_pos, masked_weights, _ = batch
                        masked_dev_loss = model(input_ids,
                                                segment_ids,
                                                answer_tag,
                                                input_mask,
                                                lm_label_ids,
                                                masked_pos=masked_pos,
                                                masked_weights=masked_weights)
                        val_losses.append(masked_dev_loss.item())
                val_loss = np.mean(val_losses)
                model.train()  # restore training mode for the next epoch
                print(
                    "Epoch {} - final loss : {:.4f} - val loss :{:.4f}".format(
                        i_epoch, final_loss, val_loss))
Example #12
0
class Trainer(object):
    def __init__(self, cfg):
        """Build the network, optionally move it to GPU / wrap it for
        multi-GPU, load pretrained weights if configured, and set up the
        label mapping plus the chance-level best-accuracy baseline."""
        self.cfg = cfg

        # Pick the model variant depending on whether extra features are used.
        net = FeatureModel(cfg) if cfg.feature else BasicModel(cfg)
        print(net)

        if cfg.cuda:
            net = net.cuda()
            if cfg.parallel and torch.cuda.device_count() > 1:
                print("Let's use", torch.cuda.device_count(), "GPUs!")
                net = nn.DataParallel(net)
        self.net = net

        self.start_epoch = 0
        if cfg.pretrained is not None:
            self.load_pretrained_net(pretrained=cfg.pretrained)

        # commonsense_qa has five answer choices; every other task is binary.
        if cfg.task == 'commonsense_qa':
            self.index2label = {i: c for i, c in enumerate('ABCDE')}
        else:
            self.index2label = {0: '1', 1: '2'}

        # Start the checkpointing bar at chance-level accuracy.
        self.best = 1. / len(self.index2label)

        # NOTE(review): 'winograde' looks like a typo for 'winogrande' --
        # confirm against the actual config values before renaming.
        if cfg.task == 'winograde':
            self.cfg.task += '_' + self.cfg.train_size

    def train(self, train_db, val_db):
        """Fine-tune self.net on train_db with periodic validation on val_db.

        Sets up tensorboard + tabular logging, snapshots the training code
        for reproducibility, builds an AdamW optimizer (BERT-style no-decay
        groups) with a linear warmup schedule, then trains with gradient
        accumulation, optional fp16 (apex), and best-accuracy checkpointing.
        """
        # Tensorboard writer: explicit dir if configured, otherwise a
        # per-model subdirectory under the log root.
        if self.cfg.tensorboard_logdir is not None:
            summary_writer = SummaryWriter(self.cfg.tensorboard_logdir)
        else:
            summary_writer = SummaryWriter(
                osp.join(self.cfg.log_dir, self.cfg.task, 'tensorboard',
                         self.cfg.model_name))

        log_dir = osp.join(self.cfg.log_dir, self.cfg.task,
                           self.cfg.model_name)
        if not osp.exists(log_dir):
            os.makedirs(log_dir)

        # Snapshot the training scripts next to the logs for reproducibility.
        code_dir = osp.join(log_dir, 'code')
        if not osp.exists(code_dir):
            os.makedirs(code_dir)
        shutil.copy('./train.py', osp.join(code_dir, 'train.py'))
        shutil.copy('./commonsense_dataset.py',
                    osp.join(code_dir, 'commonsense_dataset.py'))

        logz.configure_output_dir(log_dir)
        logz.save_config(self.cfg)

        train_loader = DataLoader(train_db,
                                  batch_size=self.cfg.batch_size,
                                  shuffle=True,
                                  num_workers=self.cfg.num_workers)

        # One optimizer step per `accumulation_steps` mini-batches.
        num_train_steps = int(
            len(train_loader) / self.cfg.accumulation_steps * self.cfg.epochs)
        num_warmup_steps = int(num_train_steps * self.cfg.warmup)

        # Standard BERT recipe: no weight decay on biases / LayerNorm weights.
        no_decay = ['bias', 'LayerNorm.weight']
        not_optim = []

        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.net.named_parameters()
                if (not any(nd in n for nd in no_decay)) and (not any(
                    nd in n for nd in not_optim))
            ],
            'weight_decay':
            self.cfg.weight_decay
        }, {
            'params': [
                p for n, p in self.net.named_parameters()
                if (any(nd in n
                        for nd in no_decay)) and (not any(nd in n
                                                          for nd in not_optim))
            ],
            'weight_decay':
            0.0
        }]

        if self.cfg.fix_emb:
            # Freeze the embedding table.
            for p in self.net.embedding.embeddings.parameters():
                p.requires_grad = False

        if self.cfg.ft_last_layer:
            # Fine-tune only the last encoder layers: freeze embeddings and
            # the first 10 encoder layers.
            for p in self.net.embedding.embeddings.parameters():
                p.requires_grad = False
            for i in range(10):
                # BUG FIX: a transformer layer module is not iterable; its
                # tensors must be reached via .parameters(). The previous
                # `for p in ...layer[i]:` raised TypeError.
                for p in self.net.embedding.encoder.layer[i].parameters():
                    p.requires_grad = False

        # SECURITY NOTE: eval() on a config string is unsafe if the config
        # can come from untrusted input; ast.literal_eval would be safer.
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=self.cfg.lr,
                               eps=self.cfg.adam_eps,
                               betas=eval(self.cfg.adam_betas))

        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=num_warmup_steps,
                                              t_total=num_train_steps)
        loss_func = nn.CrossEntropyLoss()

        if self.cfg.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            self.net, self.optimizer = amp.initialize(
                self.net, self.optimizer, opt_level=self.cfg.fp16_opt_level)

        torch.cuda.synchronize()
        self.start = time.time()
        self.net.zero_grad()
        self.batch_loss, self.batch_acc = [], []
        self.global_step = 0
        # Guards the end-of-epoch checkpoint path below: without this,
        # val_acc could be referenced before assignment when no validation
        # ran during the epoch.
        val_acc = self.best
        for epoch in range(self.start_epoch, self.cfg.epochs):

            print('Training...')
            torch.cuda.empty_cache()
            self.batch_loss, self.batch_acc = [], []
            for cnt, batch in tqdm(enumerate(train_loader)):
                self.net.train()

                input_ids, input_mask, segment_ids, features, fea_mask, labels, input_indexs = self.batch_data(
                    batch)
                logits, probabilities, preds = self.net(
                    input_ids, input_mask, segment_ids, features, fea_mask)
                loss = loss_func(logits, labels).mean()

                # Scale so the accumulated gradient averages over the
                # accumulation window.
                if self.cfg.accumulation_steps > 1:
                    loss = loss / self.cfg.accumulation_steps

                if self.cfg.fp16:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                    if self.cfg.max_grad_norm > 0.0:
                        # fp16: clip the fp32 master copies, not the model params.
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            self.cfg.max_grad_norm)
                else:
                    loss.backward()
                    if self.cfg.max_grad_norm > 0.0:
                        nn.utils.clip_grad_norm_(self.net.parameters(),
                                                 self.cfg.max_grad_norm)

                acc, _, _, _ = self.evaluate(preds, labels, input_indexs)

                self.batch_loss.append(loss.cpu().data.item() / len(input_ids))
                self.batch_acc.append(acc)

                # Log once before any optimizer step to record a baseline.
                if self.global_step == 0 and cnt == 0:
                    _ = self.update_log(summary_writer, epoch, val_db)

                # Step the optimizer only every accumulation_steps batches.
                if ((cnt + 1) % self.cfg.accumulation_steps) == 0:
                    self.optimizer.step()
                    self.scheduler.step()
                    self.net.zero_grad()
                    self.global_step += 1

                    if self.global_step % self.cfg.log_per_steps == 0:
                        val_acc = self.update_log(summary_writer, epoch,
                                                  val_db)
                        self.batch_loss, self.batch_acc = [], []

                        # Keep only checkpoints that beat the running best,
                        # and only after a quarter of training has elapsed.
                        if self.cfg.save_ckpt:
                            if epoch >= (self.cfg.epochs / 4):
                                if self.best < val_acc:
                                    print('Saving checkpoint...')
                                    self.save_checkpoint(epoch, acc=val_acc)
                                    self.best = val_acc

            ##################################################################
            ## Checkpoint
            ##################################################################
            # Flush stats left over from a partial logging interval.
            if len(self.batch_loss) > 0:
                val_acc = self.update_log(summary_writer, epoch, val_db)
                self.best = max(self.best, val_acc)
                self.batch_loss, self.batch_acc = [], []

            if self.cfg.save_ckpt:
                if epoch >= (self.cfg.epochs / 4):
                    print('Saving checkpoint...')
                    self.save_checkpoint(epoch, True, acc=val_acc)
            torch.cuda.empty_cache()

        summary_writer.close()

    def update_log(self, summary_writer, epoch, val_db, inds=None):
        """Log running train stats plus a fresh validation pass to tensorboard
        and the tabular logger; return the validation accuracy."""
        step = self.global_step
        train_loss = np.mean(self.batch_loss)
        train_acc = np.mean(self.batch_acc)
        summary_writer.add_scalar('train_loss', train_loss, step)
        summary_writer.add_scalar('train_acc', train_acc, step)

        # Full validation pass; eqs_ holds per-example correctness flags.
        val_loss, val_acc, val_wrong, val_wrong_answer, eqs_ = self.validate(
            val_db)
        summary_writer.add_scalar('val_loss', np.mean(val_loss), step)
        summary_writer.add_scalar('val_acc', val_acc, step)
        summary_writer.add_scalar('lr', self.scheduler.get_lr()[0], step)

        torch.cuda.synchronize()
        logz.log_tabular("Time", time.time() - self.start)
        logz.log_tabular("Iteration", epoch)
        logz.log_tabular("TrainAverageLoss", train_loss)
        logz.log_tabular("TrainAverageAccu", train_acc)
        logz.log_tabular("ValAverageLoss", np.mean(val_loss))
        logz.log_tabular("ValAverageAccu", val_acc)

        if inds is not None:
            # Reorder the correctness flags by `inds`, split in half, and
            # report per-half accuracy.
            total = len(eqs_)
            picked = [eqs_[i] for i in inds]
            half = int(total / 2)
            first = np.array(picked[:half])
            second = np.array(picked[half:])
            logz.log_tabular("ValAverageAccu0", first.sum() / len(first))
            logz.log_tabular("ValAverageAccu1", second.sum() / len(second))

        logz.dump_tabular()

        return val_acc

    def validate(self, val_db):
        """Run one full pass over val_db in eval mode.

        Returns (per-batch losses, accuracy, indices of wrong answers,
        the wrong predictions, per-example correctness flags)."""
        print('Validation...')
        torch.cuda.empty_cache()
        self.net.eval()

        criterion = nn.CrossEntropyLoss()

        loader = DataLoader(val_db,
                            batch_size=self.cfg.batch_size,
                            shuffle=False,
                            num_workers=self.cfg.num_workers)

        per_batch_loss = []
        all_preds, all_labels, all_indexs = [], [], []
        for _, batch in tqdm(enumerate(loader)):
            (input_ids, input_mask, segment_ids, features, fea_mask,
             labels, input_indexs) = self.batch_data(batch)

            with torch.no_grad():
                logits, probabilities, preds = self.net(
                    input_ids, input_mask, segment_ids, features, fea_mask)

                # Accumulate the whole epoch so accuracy is computed once
                # at the end.
                all_preds.extend(preds)
                all_labels.extend(labels)
                all_indexs.extend(input_indexs)

                batch_loss = criterion(logits, labels).mean()
                # Normalize by batch size so partial final batches compare
                # fairly.
                per_batch_loss.append(
                    batch_loss.cpu().data.item() / len(input_ids))

        val_acc, val_wrong, val_wrong_answer, eqs = self.evaluate(
            torch.Tensor(all_preds), torch.Tensor(all_labels),
            torch.Tensor(all_indexs))

        return per_batch_loss, val_acc, val_wrong, val_wrong_answer, eqs

    def test(self, test_db):
        """Predict labels for test_db; return a list of
        (question_id, label_string) pairs."""
        print('Validation...')
        torch.cuda.empty_cache()
        self.net.eval()

        loader = DataLoader(test_db,
                            batch_size=self.cfg.batch_size,
                            shuffle=False,
                            num_workers=self.cfg.num_workers)

        answer = []

        for _, batch in tqdm(enumerate(loader)):
            (input_ids, input_mask, segment_ids, features, fea_mask,
             labels, input_indexs) = self.batch_data(batch)

            with torch.no_grad():
                _, _, preds = self.net(
                    input_ids, input_mask, segment_ids, features, fea_mask)
            # Map example index -> question id and class index -> label text.
            indexes = list(np.array(input_indexs.cpu().data))
            predictions = list(np.array(preds.cpu().data))
            answer.extend(
                (test_db.index2qid[ind], self.index2label[pred])
                for ind, pred in zip(indexes, predictions))

        return answer

    def load_pretrained_net(self, pretrained):
        """Load model weights from a checkpoint file into the (unwrapped) net."""
        # Unwrap DataParallel so state-dict keys match the plain module.
        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net

        print('loading ckpt from ', pretrained)

        assert osp.exists(pretrained)
        if self.cfg.cuda:
            checkpoint = torch.load(pretrained)
        else:
            # Remap GPU tensors onto CPU storage when CUDA is unavailable.
            checkpoint = torch.load(pretrained,
                                    map_location=lambda storage, loc: storage)

        net.load_state_dict(checkpoint['net'])

    def save_checkpoint(self, epoch, force=False, acc=None):
        """Persist net/optimizer (and amp when fp16) state under a filename
        encoding the epoch, accuracy digits, and a '-end' marker for forced
        end-of-epoch saves."""
        print(" [*] Saving checkpoints...")
        # Unwrap DataParallel so the saved keys match a plain module.
        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net

        checkpoint_dir = osp.join(self.cfg.log_dir, self.cfg.task,
                                  self.cfg.model_name)
        if not osp.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        tail = ''
        if acc is not None:
            # Keep only the digits after "0." -- e.g. 0.81234 -> "81234".
            # NOTE(review): this assumes 0 <= acc < 1; acc == 1.0 would
            # produce a misleading suffix -- confirm the expected range.
            tail += '-' + str(round(acc, 5))[2:]
        if force:
            tail += '-end'
        model_name = "ckpt-%03d%s.pkl" % (epoch, tail)

        print('saving ckpt to ', checkpoint_dir)
        state = {
            'net': net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epoch': epoch,
        }
        if self.cfg.fp16:
            # assumes `amp` is in scope at module level -- TODO confirm;
            # train() only imports it locally.
            state['amp'] = amp.state_dict()
        torch.save(state, osp.join(checkpoint_dir, model_name))

    def batch_data(self, entry):
        features, fea_mask = None, None
        input_ids = entry['token_ids'].long()
        segment_ids = entry['segment_ids'].long()
        input_mask = entry['mask'].long()
        labels = entry['label_ids'].long()
        input_indexs = entry['index'].long()

        if self.cfg.feature:
            features = entry['feature'].float()
            fea_mask = entry['fea_mask'].long()

        # print(input_ids[0])
        # exit()

        if self.cfg.cuda:
            input_ids = input_ids.cuda()
            input_mask = input_mask.cuda()
            segment_ids = segment_ids.cuda()
            labels = labels.cuda()
            input_indexs = input_indexs.cuda()
            if self.cfg.feature:
                features = features.cuda()
                fea_mask = fea_mask.cuda()

        return input_ids, input_mask, segment_ids, features, fea_mask, labels, input_indexs

    def evaluate(self, pred, labels, input_indexs):
        eq = torch.eq(pred, labels)
        # print(labels.shape)
        wrong_indexs = list(np.array(input_indexs[~eq].cpu().data))
        wrong_answer = list(np.array(pred[~eq].cpu().data))
        correct = eq.sum().cpu().data.item()
        acc = correct / len(labels)

        return acc, wrong_indexs, wrong_answer, np.array(eq.cpu())