Example #1
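# Likely imports for this snippet (not shown in the original); the library module
# paths and the project-local names (Classify_Extractor, print_params, prepare_tgt,
# evaluate) are assumptions:
#   import time
#   import torch
#   from pytorch_transformers import BertTokenizer, AdamW, WarmupLinearSchedule
#   from sklearn.metrics import precision_score, recall_score, f1_score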
def train(args, train_iter, dev, test, src_field, tgt_field, tag_field,
          checkpoint):
    # srcpadid = src_field.vocab.stoi['<pad>']
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

    model = Classify_Extractor(args, tgt_field)

    if torch.cuda.is_available():
        model.cuda()

    print_params(model)

    decay = args.decay

    if args.optimizer == 'bert':
        weight_decay = 0.0
        no_decay = ['bias', 'LayerNorm.weight']
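        # Split parameters into a group that would receive weight decay and a group
        # (biases, LayerNorm weights) that never does; note both groups end up with
        # 0.0 here because weight_decay is set to 0.0 above.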
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            weight_decay
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]
        opt = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=1e-8)
        # Count the mini-batches in one epoch to size the warmup/linear-decay schedule.
        totalnum = sum(1 for _ in train_iter)
        t_total = totalnum // decay * args.maximum_steps
        scheduler = WarmupLinearSchedule(opt, warmup_steps=0, t_total=t_total)
    else:
        opt = torch.optim.Adadelta(model.parameters(), lr=args.lr)

    best_e = 0.0
    best_c = 0.0
    best_epoch_for_c = 0
    best_epoch_for_e = 0
    offset = 0.0
    pre_epoch = 0
    patience_c = 0
    patience_e = 0

    if checkpoint is not None:
        print('model.load_state_dict(checkpoint[model])')
        model.load_state_dict(checkpoint['model'])
        if args.resume:
            opt.load_state_dict(checkpoint['optim'])

            best_f = checkpoint['f']
            offset = checkpoint['iters']
            pre_epoch = checkpoint['epoch']

            print('*************************************')
            print('resume from {} epoch {} iters and best_f {}'.format(
                pre_epoch, offset, best_f))
            print('*************************************')

    print("**************start training****************")
    start = time.time()

    for epoch in range(args.maxepoch):
        train_iter.init_epoch()
        epoch += pre_epoch

        for iters, train_batch in enumerate(train_iter):
            iters += offset
            model.train()
            # model.zero_grad()
            # model.constrain_transition()
            t1 = time.time()
            batch_src = train_batch.src
            src = [tokenizer.convert_tokens_to_ids(s) for s in batch_src]
            maxlen = max([len(s) for s in batch_src])
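            # Pad each tokenized sentence to the batch maximum length and build a
            # matching 0/1 attention mask (1 = real token, 0 = padding).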

            src_mask = []
            padded_sents = []
            for s in src:
                new_s = s + [0] * (maxlen - len(s))
                padded_sents.append(new_s)
                mask = [1] * len(s) + [0] * (maxlen - len(s))
                src_mask.append(mask)
            # B T
            src = torch.tensor(padded_sents).long().cuda()
            # B T
            src_mask = torch.tensor(src_mask).byte().cuda()
            # src, src_mask = prepare_src(train_batch.src, srcpadid)
            tgt = prepare_tgt(train_batch.tgt)
            tag = train_batch.tag

            loss = model(src, src_mask, tgt, tag)

            # "update parameters"

            if decay > 1:
                loss = loss / decay

            loss.backward()

            # if args.grad_clip:
            #     torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

            # Gradient accumulation: step the optimizer only every `decay` mini-batches.
            if (iters + 1) % decay == 0:
                opt.step()
                if args.optimizer == 'bert':
                    scheduler.step()  # update the warmup schedule (only created for the BERT/AdamW branch)
                opt.zero_grad()

            # opt.step()

            t2 = time.time()

            loss = loss.item()

            print("epoch:{} iters:{} src:({},{}) tgt:({},{}) "
                  "loss:{:.2f} t:{:.2f}".format(epoch + 1, iters + 1,
                                                *src.size(), *tgt.size(), loss,
                                                t2 - t1))

        # if torch.cuda.is_available():
        #     torch.cuda.empty_cache()

        if (epoch + 1) % 1 == 0:
            print("=============validate model==============")
            with torch.no_grad():
                dev.init_epoch()
                model.eval()
                # model.constrain_transition()
                sents = []
                cy_true = []
                cy_pred = []
                for j, dev_batch in enumerate(dev):
                    t1 = time.time()
                    # src, src_mask = prepare_src(dev_batch.src, srcpadid)
                    batch_src = dev_batch.src
                    src = [
                        tokenizer.convert_tokens_to_ids(s) for s in batch_src
                    ]
                    maxlen = max([len(s) for s in batch_src])

                    src_mask = []
                    padded_sents = []
                    for s in src:
                        new_s = s + [0] * (maxlen - len(s))
                        padded_sents.append(new_s)
                        mask = [1] * len(s) + [0] * (maxlen - len(s))
                        src_mask.append(mask)
                    # B T
                    src = torch.tensor(padded_sents).long().cuda()
                    # B T
                    src_mask = torch.tensor(src_mask).byte().cuda()

                    tgt = prepare_tgt(dev_batch.tgt)
                    tag = dev_batch.tag.squeeze(-1)
                    _, pre_tag = model.component_extraction(src, src_mask)
                    pre_ctag = model.simile_classify(src, src_mask)
                    cy_true.extend(tag.tolist())
                    cy_pred.extend(pre_ctag.tolist())

                    for sen, tags, p_tags, c_tags in zip(
                            src, tgt, pre_tag, tag):
                        sen = sen[:len(p_tags)].tolist()
                        tags = tags[:len(p_tags)].tolist()
                        if c_tags == 1:
                            sents.append([
                                sen, [tgt_field.vocab.itos[t] for t in tags],
                                [tgt_field.vocab.itos[t] for t in p_tags]
                            ])
                    print('dev iters: {}, t:{}'.format(j, time.time() - t1))

                _, eprecision, erecall, ef1 = evaluate(sents)

                cprecision = precision_score(cy_true, cy_pred)
                crecall = recall_score(cy_true, cy_pred)
                cf1 = f1_score(cy_true, cy_pred)

                print(
                    'epoch: {} classify--> precision: {} recall: {} f1: {} best:{}'
                    .format(epoch + 1, cprecision, crecall, cf1, best_c))
                print('extractor--> precision: {} recall: {} f1: {} best: {}'.
                      format(eprecision, erecall, ef1, best_e))

                if cf1 > best_c:
                    best_c = cf1
                    best_epoch_for_c = epoch + 1

                    print(
                        'save best classifier model at epoch={}'.format(epoch +
                                                                        1))
                    checkpoint = {
                        'model': model.state_dict(),
                        'optim': opt.state_dict(),
                        'args': args
                    }
                    torch.save(
                        checkpoint, '{}/{}.classify.best.pt'.format(
                            args.model_path, args.model))
                    patience_c = 0
                else:
                    patience_c += 1

                if ef1 > best_e:
                    best_e = ef1
                    best_epoch_for_e = epoch + 1

                    print(
                        'save best extractor model at epoch={}'.format(epoch +
                                                                       1))
                    checkpoint = {
                        'model': model.state_dict(),
                        'optim': opt.state_dict(),
                        'args': args
                    }
                    torch.save(
                        checkpoint, '{}/{}.extractor.best.pt'.format(
                            args.model_path, args.model))
                    patience_e = 0
                else:
                    patience_e += 1

        if patience_c > args.patience and patience_e > args.patience:
            print("early stop at {}".format(epoch))
            break

        if args.decay:
            opt.param_groups[0]['lr'] = opt.param_groups[0]['lr'] * args.decay

    print('*******Done********{}'.format(
        time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
    minutes = (time.time() - start) // 60
    if minutes < 60:
        print(
            'best_c:{}, best_e:{} best_epoch_c:{}, best_epoch_e:{}, time:{} mins'
            .format(best_c, best_e, best_epoch_for_c, best_epoch_for_e,
                    minutes))
    else:
        hours = minutes / 60
        print(
            'best_c:{}, best_e:{} best_epoch_c:{}, best_epoch_e:{}, time:{:.1f} hours'
            .format(best_c, best_e, best_epoch_for_c, best_epoch_for_e, hours))

    print('*******Testing************')
    model1 = Classify_Extractor(args, tgt_field)
    model1.cuda()
    load_from = '{}/{}.classify.best.pt'.format(args.model_path, args.model)
    print('load the best model {}'.format(load_from))
    checkpoint = torch.load(load_from, map_location='cpu')
    print('load parameters')
    model1.load_state_dict(checkpoint['model'])

    model2 = Classify_Extractor(args, tgt_field)
    model2.cuda()
    load_from = '{}/{}.extractor.best.pt'.format(args.model_path, args.model)
    print('load the best model {}'.format(load_from))
    checkpoint = torch.load(load_from, map_location='cpu')
    print('load parameters')
    model2.load_state_dict(checkpoint['model'])
    with torch.no_grad():
        test.init_epoch()
        model1.eval()
        model2.eval()
        sents = []
        cy_true = []
        cy_pred = []
        for j, test_batch in enumerate(test):
            t1 = time.time()
            # src, src_mask = prepare_src(test_batch.src, srcpadid)
            batch_src = test_batch.src
            src = [tokenizer.convert_tokens_to_ids(s) for s in batch_src]
            maxlen = max([len(s) for s in batch_src])

            src_mask = []
            padded_sents = []
            for s in src:
                new_s = s + [0] * (maxlen - len(s))
                padded_sents.append(new_s)
                mask = [1] * len(s) + [0] * (maxlen - len(s))
                src_mask.append(mask)
            # B T
            src = torch.tensor(padded_sents).long().cuda()
            # B T
            src_mask = torch.tensor(src_mask).byte().cuda()

            tgt = prepare_tgt(test_batch.tgt)
            tag = test_batch.tag.squeeze(-1)
            _, pre_tag = model2.component_extraction(src, src_mask)
            pre_ctag = model1.simile_classify(src, src_mask)
            cy_true.extend(tag.tolist())
            cy_pred.extend(pre_ctag.tolist())

            # for sen, tags, p_tags in zip(src, tgt, pre_tag):
            #     sen = sen[:len(p_tags)].tolist()
            #     tags = tags[:len(p_tags)].tolist()
            #     sents.append([sen, [tgt_field.vocab.itos[t] for t in tags],
            #                  [tgt_field.vocab.itos[t] for t in p_tags]])
            for sen, tags, p_tags, c_tags in zip(src, tgt, pre_tag, pre_ctag):
                sen = sen[:len(p_tags)].tolist()
                tags = tags[:len(p_tags)].tolist()
                if c_tags == 1:
                    sents.append([
                        sen, [tgt_field.vocab.itos[t] for t in tags],
                        [tgt_field.vocab.itos[t] for t in p_tags]
                    ])
                elif c_tags == 0:
                    sents.append([
                        sen, [tgt_field.vocab.itos[t] for t in tags],
                        ['O' for t in p_tags]
                    ])

            print('test iters: {}, t:{}'.format(j, time.time() - t1))

        _, eprecision, erecall, ef1 = evaluate(sents)

        cprecision = precision_score(cy_true, cy_pred)
        crecall = recall_score(cy_true, cy_pred)
        cf1 = f1_score(cy_true, cy_pred)

        print('Testing classify--> precision: {} recall: {} f1: {}'.format(
            cprecision, crecall, cf1))
        print('extractor--> precision: {} recall: {} f1: {}'.format(
            eprecision, erecall, ef1))
Example #2
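# Likely imports for this snippet (not shown in the original); the library module
# paths and the project-local helpers (prepare_dataset, MultiWozDataset, make_slot_meta,
# TransformerDST, model_evaluation, save, load, OP_SET, domain2id, download_ckpt)
# are assumptions:
#   import os, sys, json, time, random
#   import numpy as np
#   import torch
#   import torch.nn as nn
#   from torch.utils.data import DataLoader, RandomSampler
#   from pytorch_transformers import BertTokenizer, BertConfig, AdamW, WarmupLinearSchedule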
def main(args):

    assert args.use_one_optim is True

    if args.use_cls_only:
        args.no_dial = True

    print("### use_cls_only: {:}".format(args.use_cls_only))
    print("### no_dial: {:}".format(args.no_dial))

    if args.recover_e > 0:
        raise NotImplementedError("This option is from my oldest code version. "
                                  "I have not checked it for this code version.")

    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)
        print("### mkdir {:}".format(args.save_dir))

    def worker_init_fn(worker_id):
        np.random.seed(args.random_seed + worker_id)

    n_gpu = 0
    if torch.cuda.is_available() and (not args.use_cpu):
        n_gpu = torch.cuda.device_count()
        device = torch.device('cuda')
        print("### Device: {:}".format(device))
    else:
        print("### Use CPU (Debugging)")
        device = torch.device("cpu")

    if args.random_seed < 0:
        print("### Pick a random seed")
        args.random_seed = random.sample(list(range(0, 100000)), 1)[0]

    print("### Random Seed: {:}".format(args.random_seed))
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    rng = random.Random(args.random_seed)
    torch.manual_seed(args.random_seed)

    if n_gpu > 0:
        if args.random_seed >= 0:
            torch.cuda.manual_seed(args.random_seed)
            torch.cuda.manual_seed_all(args.random_seed)

        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)

    ontology = json.load(open(args.ontology_data))
    slot_meta, ontology = make_slot_meta(ontology)
    op2id = OP_SET[args.op_code]
    print(op2id)

    tokenizer = BertTokenizer(args.vocab_path, do_lower_case=True)

    train_path = os.path.join(args.data_root, "train.pt")
    dev_path = os.path.join(args.data_root, "dev.pt")
    test_path = os.path.join(args.data_root, "test.pt")

    if not os.path.exists(test_path):
        test_data_raw = prepare_dataset(data_path=args.test_data_path,
                                        tokenizer=tokenizer,
                                        slot_meta=slot_meta,
                                        n_history=args.n_history,
                                        max_seq_length=args.max_seq_length,
                                        op_code=args.op_code)
        torch.save(test_data_raw, test_path)
    else:
        test_data_raw = torch.load(test_path)

    print("# test examples %d" % len(test_data_raw))

    if not os.path.exists(train_path):
        train_data_raw = prepare_dataset(data_path=args.train_data_path,
                                         tokenizer=tokenizer,
                                         slot_meta=slot_meta,
                                         n_history=args.n_history,
                                         max_seq_length=args.max_seq_length,
                                         op_code=args.op_code)

        torch.save(train_data_raw, train_path)
    else:
        train_data_raw = torch.load(train_path)

    train_data = MultiWozDataset(train_data_raw,
                                 tokenizer,
                                 slot_meta,
                                 args.max_seq_length,
                                 rng,
                                 ontology,
                                 args.word_dropout,
                                 args.shuffle_state,
                                 args.shuffle_p, pad_id=tokenizer.convert_tokens_to_ids(['[PAD]'])[0],
                                 slot_id=tokenizer.convert_tokens_to_ids(['[SLOT]'])[0],
                                 decoder_teacher_forcing=args.decoder_teacher_forcing,
                                 use_full_slot=args.use_full_slot,
                                 use_dt_only=args.use_dt_only, no_dial=args.no_dial,
                                 use_cls_only=args.use_cls_only)

    print("# train examples %d" % len(train_data_raw))

    if not os.path.exists(dev_path):
        dev_data_raw = prepare_dataset(data_path=args.dev_data_path,
                                       tokenizer=tokenizer,
                                       slot_meta=slot_meta,
                                       n_history=args.n_history,
                                       max_seq_length=args.max_seq_length,
                                       op_code=args.op_code)
        torch.save(dev_data_raw,  dev_path)
    else:
        dev_data_raw = torch.load(dev_path)

    print("# dev examples %d" % len(dev_data_raw))

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = args.dropout
    model_config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
    model_config.hidden_dropout_prob = args.hidden_dropout_prob

    type_vocab_size = 4
    dec_config = args
    model = TransformerDST(model_config, dec_config, len(op2id), len(domain2id),
                           op2id['update'],
                           tokenizer.convert_tokens_to_ids(['[MASK]'])[0],
                           tokenizer.convert_tokens_to_ids(['[SEP]'])[0],
                           tokenizer.convert_tokens_to_ids(['[PAD]'])[0],
                           tokenizer.convert_tokens_to_ids(['-'])[0],
                           type_vocab_size, args.exclude_domain)

    if not os.path.exists(args.bert_ckpt_path):
        args.bert_ckpt_path = download_ckpt(args.bert_ckpt_path, args.bert_config_path, 'assets')

    state_dict = torch.load(args.bert_ckpt_path, map_location='cpu')
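    # The pretrained BERT checkpoint ships with 2 token-type (segment) embeddings,
    # while this model uses type_vocab_size = 4, so the embedding matrix is expanded
    # and the two new rows are initialized from row 0 before loading the weights.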
    _k = 'embeddings.token_type_embeddings.weight'
    print("config.type_vocab_size != state_dict[bert.embeddings.token_type_embeddings.weight] ({0} != {1})".format(
            type_vocab_size, state_dict[_k].shape[0]))
    state_dict[_k].resize_(
        type_vocab_size, state_dict[_k].shape[1])
    state_dict[_k].data[2, :].copy_(state_dict[_k].data[0, :])
    state_dict[_k].data[3, :].copy_(state_dict[_k].data[0, :])
    model.bert.load_state_dict(state_dict)
    print("\n### Done Load BERT")
    sys.stdout.flush()

    # re-initialize added special tokens ([SLOT], [NULL], [EOS])
    model.bert.embeddings.word_embeddings.weight.data[1].normal_(mean=0.0, std=0.02)
    model.bert.embeddings.word_embeddings.weight.data[2].normal_(mean=0.0, std=0.02)
    model.bert.embeddings.word_embeddings.weight.data[3].normal_(mean=0.0, std=0.02)

    # re-initialize seg-2, seg-3
    model.bert.embeddings.token_type_embeddings.weight.data[2].normal_(mean=0.0, std=0.02)
    model.bert.embeddings.token_type_embeddings.weight.data[3].normal_(mean=0.0, std=0.02)
    model.to(device)

    num_train_steps = int(len(train_data_raw) / args.batch_size * args.n_epochs)

    if args.use_one_optim:
        print("### Use One Optim")
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(
                nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(
                nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.enc_lr)
        scheduler = WarmupLinearSchedule(optimizer, int(num_train_steps * args.enc_warmup),
                                             t_total=num_train_steps)
    else:
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        enc_param_optimizer = list(model.bert.named_parameters())  # TODO: For BERT only
        print('### Optim BERT: {:}'.format(len(enc_param_optimizer)))
        enc_optimizer_grouped_parameters = [
            {'params': [p for n, p in enc_param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in enc_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]

        enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
        enc_scheduler = WarmupLinearSchedule(enc_optimizer, int(num_train_steps * args.enc_warmup),
                                             t_total=num_train_steps)

        dec_param_optimizer = list(model.named_parameters())  # TODO:  For other parameters
        print('### Optim All: {:}'.format(len(dec_param_optimizer)))
        dec_param_optimizer = [p for (n, p) in dec_param_optimizer if 'bert' not in n]
        print('### Optim OTH: {:}'.format(len(dec_param_optimizer)))
        dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
        dec_scheduler = WarmupLinearSchedule(dec_optimizer, int(num_train_steps * args.dec_warmup),
                                             t_total=num_train_steps)

    if args.recover_e > 0:
        model_recover, enc_recover, dec_recover = load(args, str(args.recover_e))
        print("### Recover Model E{:}".format(args.recover_e))
        sys.stdout.flush()
        model.load_state_dict(model_recover)
        print("### Recover Optim E{:}".format(args.recover_e))
        sys.stdout.flush()
        enc_optimizer.load_state_dict(enc_recover)
        dec_optimizer.load_state_dict(dec_recover)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  collate_fn=train_data.collate_fn,
                                  num_workers=args.num_workers,
                                  worker_init_fn=worker_init_fn)

    loss_fnc = nn.CrossEntropyLoss()
    best_score = {'epoch': 0, 'joint_acc': 0, 'op_acc': 0, 'final_slot_f1': 0}

    start_time = time.time()

    for epoch in range(args.n_epochs):
        batch_loss = []
        model.train()
        for step, batch in enumerate(train_dataloader):

            # Move tensors to the device; leave ints, dicts, lists and numpy arrays as they are.
            batch = [b if isinstance(b, (int, dict, list, np.ndarray)) else b.to(device) for b in batch]

            input_ids_p, segment_ids_p, input_mask_p, \
            state_position_ids, op_ids, domain_ids, input_ids_g, segment_ids_g, position_ids_g, input_mask_g, \
            masked_pos, masked_weights, lm_label_ids, id_n_map, gen_max_len, n_total_pred = batch

            domain_scores, state_scores, loss_g = model(input_ids_p, segment_ids_p, input_mask_p, state_position_ids,
                input_ids_g, segment_ids_g, position_ids_g, input_mask_g,
                masked_pos, masked_weights, lm_label_ids, id_n_map, gen_max_len, only_pred_op=args.only_pred_op, n_gpu=n_gpu)

            if n_total_pred > 0:
                loss_g = loss_g.sum() / n_total_pred
            else:
                loss_g = 0

            loss_s = loss_fnc(state_scores.view(-1, len(op2id)), op_ids.view(-1))

            if args.only_pred_op:
                loss = loss_s
            else:
                loss = loss_s + loss_g

            if args.exclude_domain is not True:
                loss_d = loss_fnc(domain_scores.view(-1, len(domain2id)), domain_ids.view(-1))
                loss = loss + loss_d

            batch_loss.append(loss.item())

            loss.backward()

            if args.use_one_optim:
                optimizer.step()
                scheduler.step()
            else:
                enc_optimizer.step()
                enc_scheduler.step()
                dec_optimizer.step()
                dec_scheduler.step()

            model.zero_grad()

            if step % 100 == 0:
                try:
                    loss_g = loss_g.item()
                except AttributeError:
                    loss_g = loss_g

                if args.exclude_domain is not True:
                    print("time %.1f min, [%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f, dom_loss : %.3f" \
                          % ((time.time()-start_time)/60, epoch+1, args.n_epochs, step,
                             len(train_dataloader), np.mean(batch_loss),
                             loss_s.item(), loss_g, loss_d.item()))
                else:
                    print("time %.1f min, [%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f" \
                          % ((time.time()-start_time)/60, epoch+1, args.n_epochs, step,
                             len(train_dataloader), np.mean(batch_loss),
                             loss_s.item(), loss_g))

                sys.stdout.flush()
                batch_loss = []

        if args.use_one_optim:
            save(args, epoch + 1, model, optimizer)
        else:
            save(args, epoch + 1, model, enc_optimizer, dec_optimizer)

        if ((epoch+1) % args.eval_epoch == 0) and (epoch+1 >= 8):
            eval_res = model_evaluation(model, dev_data_raw, tokenizer, slot_meta, epoch+1, args.op_code,
                                        use_full_slot=args.use_full_slot, use_dt_only=args.use_dt_only, no_dial=args.no_dial, use_cls_only=args.use_cls_only, n_gpu=n_gpu)
            print("### Epoch {:} Score : ".format(epoch+1), eval_res)

            if eval_res['joint_acc'] > best_score['joint_acc']:
                best_score = eval_res
                print("### Best Joint Acc: {:} ###".format(best_score['joint_acc']))
                print('\n')

                if epoch+1 >= 8:  # To speed up
                    eval_res_test = model_evaluation(model, test_data_raw, tokenizer, slot_meta, epoch + 1, args.op_code,
                                                     use_full_slot=args.use_full_slot, use_dt_only=args.use_dt_only, no_dial=args.no_dial, use_cls_only=args.use_cls_only, n_gpu=n_gpu)
                    print("### Epoch {:} Test Score : ".format(epoch + 1), eval_res_test)
Example #3
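# Likely imports for this snippet (not shown in the original); CFG, cust_model, db,
# logger, load_csv and the other project-local helpers are assumptions:
#   import argparse, os, random
#   import numpy as np
#   import pandas as pd
#   import torch
#   from torch.utils.data import DataLoader
#   from pytorch_transformers import BertConfig, AdamW, WarmupLinearSchedule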
def main():    
    parser = argparse.ArgumentParser("")
    parser.add_argument("--model", type=str, default='')    
    parser.add_argument("--resume", action='store_true')
    parser.add_argument("--eval", action='store_true')
    parser.add_argument("--batch_size", type=int, default=CFG.batch_size)
    parser.add_argument("--nepochs", type=int, default=CFG.num_train_epochs)    
    parser.add_argument("--wsteps", type=int, default=CFG.warmup_steps)
    parser.add_argument("--nlayers", type=int, default=CFG.num_hidden_layers)
    parser.add_argument("--nahs", type=int, default=CFG.num_attention_heads)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--lr", type=float, default=CFG.learning_rate)
    parser.add_argument("--dropout", type=float, default=CFG.dropout)
    parser.add_argument("--types", nargs='+', type=str, 
                        default=['1JHC', '1JHN', '2JHC', '2JHH', '2JHN', '3JHC', '3JHH', '3JHN'], 
                        help='3JHC,2JHC,1JHC,3JHH,2JHH,3JHN,2JHN,1JHN')
    parser.add_argument("--train_file", default="train_mute_cp")
    parser.add_argument("--test_file", default="test_mute_cp")
    parser.add_argument("--pseudo_path", default="")
    parser.add_argument("--pseudo", action='store_true')
    parser.add_argument("--gen_pseudo", action='store_true')
    parser.add_argument("--use_all", action='store_true')
    parser.add_argument("--structure_file", default="structures_mu")
    parser.add_argument("--contribution_file", default="scalar_coupling_contributions")        
    args = parser.parse_args()
    print(args) 
    
    CFG.batch_size=args.batch_size
    CFG.num_train_epochs=args.nepochs
    CFG.warmup_steps=args.wsteps
    CFG.num_hidden_layers=args.nlayers
    CFG.num_attention_heads=args.nahs
    CFG.learning_rate=args.lr
    CFG.dropout=args.dropout
    CFG.seed =  args.seed
    print(CFG.__dict__)
    
    random.seed(CFG.seed)
    np.random.seed(CFG.seed)
    torch.manual_seed(CFG.seed)
    
    #if not args.eval:    
    if True:
        train_df = load_csv(args.train_file)
        
        structures_df = load_csv(args.structure_file)  
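        # Center each molecule's atom coordinates around the per-molecule mean.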
        structures_df[['x', 'y', 'z']] -= structures_df.groupby('molecule_name')[['x', 'y', 'z']].transform('mean')        
        
        contributions_df = load_csv(args.contribution_file)
        train_df = train_df.merge(contributions_df, how='left')   
        train_df = normalize_cols(train_df, ['scalar_coupling_constant', 'fc', 'sd', 'pso', 'dso'])        
        train_df = add_extra_features(train_df, structures_df)
        train_df = train_df.fillna(1e08)
        n_mols = train_df['molecule_name'].nunique()
        train_df, valid_df = train_test_split(train_df, 5000 )
        
        # keep only molecules that contain at least one coupling of the requested args.types
        print(train_df['molecule_name'].nunique())
        mol_names_with_at = train_df[train_df['type'].isin(args.types)]['molecule_name'].unique()
        train_df = train_df[train_df['molecule_name'].isin(mol_names_with_at)].reset_index(drop=True)
        print(train_df['molecule_name'].nunique())
        
        # Print the first 5 rows of valid_df to verify that it matches the previous experiment.
        print(valid_df.head(5))
        
        if args.pseudo:        
            test_df = load_csv(args.test_file)
            logger.info(f'loading dataset - {args.pseudo_path} ...')
            test_pseudo_df = pd.read_csv(args.pseudo_path)
            #mol_names_jhn = train_df[test_df['type'].isin(['1JHN', '2JHN', '3JHN'])]['molecule_name'].unique()
            #test_df = test_df[test_df['molecule_name'].isin(mol_names_jhn)].reset_index(drop=True)        
            test_df = add_extra_features(test_df, structures_df)
            test_df = test_df.set_index('id')
            test_pseudo_df = test_pseudo_df.set_index('id')
            test_df[['scalar_coupling_constant',  'fc', 'sd', 'pso', 'dso']] = test_pseudo_df[['scalar_coupling_constant',  'fc', 'sd', 'pso', 'dso']]
            test_df = test_df.reset_index()            
            #test_df = normalize_target(test_df)
            test_df = normalize_cols(test_df, ['scalar_coupling_constant', 'fc', 'sd', 'pso', 'dso'])
            #test_df = test_df.assign(fc=1e08, sd=1e08, pso=1e08, dso=1e08)
            train_df['weight'] = 1.0
            valid_df['weight'] = 1.0
            test_df['weight'] = 1.0
            n_mols = test_df['molecule_name'].nunique()            
            train_df = train_df.append(test_df).reset_index(drop=True)
        else:
            train_df['weight'] = 1.0
            valid_df['weight'] = 1.0
        
        if args.use_all:
            train_df = train_df.append(valid_df) 
        
        print(f' n_train:{len(train_df)}, n_valid:{len(valid_df)}')
    
    config = BertConfig(            
            3, # not used
            hidden_size=CFG.hidden_size,
            num_hidden_layers=CFG.num_hidden_layers,
            num_attention_heads=CFG.num_attention_heads,
            intermediate_size=CFG.intermediate_size,
            hidden_dropout_prob=CFG.dropout,
            attention_probs_dropout_prob=CFG.dropout,
        )    
    model = cust_model.SelfAttn(config)
    if args.model != "":
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        CFG.start_epoch = checkpoint['epoch']        
        model.load_state_dict(checkpoint['state_dict'])        
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.model, checkpoint['epoch']))
    model.cuda()
    
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('parameters: ', count_parameters(model))
    
    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    
    # to produce the submission.csv
    if args.eval:
        test_df = load_csv(args.test_file)
        structures_df = load_csv(args.structure_file)
        structures_df[['x', 'y', 'z']] -= structures_df.groupby('molecule_name')[['x', 'y', 'z']].transform('mean')        
        test_df = add_extra_features(test_df, structures_df)
        test_df = test_df.assign(fc=1e08, sd=1e08, pso=1e08, dso=1e08) 
        test_df['scalar_coupling_constant'] = 0
        test_df['weight'] = 1.0
        test_db = db.MolDB(test_df, CFG.max_seq_length)
        test_loader = DataLoader(
            test_db, batch_size=CFG.batch_size, shuffle=False,
            num_workers=CFG.num_workers)
        res_df = validate(test_loader, model, args.types)        
        res_df = unnormalize_cols(res_df, cols=['fc', 'sd', 'pso', 'dso'])
        res_df = unnormalize_target(res_df, 'prediction1')
        if args.gen_pseudo:
            res_df['scalar_coupling_constant'] = res_df['prediction1']
            res_df = res_df[res_df['id']>-1].sort_values('id')
            res_df[['id', 'scalar_coupling_constant', 'fc', 'sd', 'pso', 'dso']].to_csv(f'pseudo_{CFG.seed}.csv', index=False)
            return
        res_df['prediction4']= res_df[['fc', 'sd', 'pso', 'dso']].sum(1)
        res_df['prediction']= res_df[['prediction1','prediction4']].mean(1)        
        res_df['scalar_coupling_constant'] = res_df['prediction']
        res_df = res_df[res_df['id']>-1].sort_values('id')
        os.makedirs('output', exist_ok=True)
        res_df[['id', 'scalar_coupling_constant']].to_csv(f'output/submission_{CFG.seed}.csv', index=False)        
        return
    
    train_db = db.MolDB(train_df, CFG.max_seq_length)    
    print('preloading dataset ...')
    train_db = db.MolDB_FromDB(train_db, 10)    
    valid_db = db.MolDB(valid_df, CFG.max_seq_length)    
    num_train_optimization_steps = int(
        len(train_db) / CFG.batch_size / CFG.gradient_accumulation_steps) * (CFG.num_train_epochs-CFG.start_epoch)
    print('num_train_optimization_steps', num_train_optimization_steps)      

    train_loader = DataLoader(
        train_db, batch_size=CFG.batch_size, shuffle=True,
        num_workers=CFG.num_workers, pin_memory=True)
    val_loader = DataLoader(
        valid_db, batch_size=CFG.batch_size, shuffle=False,
        num_workers=CFG.num_workers)
    
    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    
    optimizer = AdamW(optimizer_grouped_parameters,
                           lr=CFG.learning_rate,
                           weight_decay=CFG.weight_decay,                           
                           )
    scheduler = WarmupLinearSchedule(optimizer, CFG.warmup_steps,
                                        t_total=num_train_optimization_steps
                                     )
    
    def get_lr():
        return scheduler.get_lr()[0]
    
    if args.model != "":
        if args.resume:
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
        #for param_group in optimizer.param_groups:
        #    param_group['lr'] = CFG.learning_rate
        mae_log_df = checkpoint['mae_log']
        del checkpoint
    else:
        mae_log_df = pd.DataFrame(columns=(['EPOCH']+['LR']+args.types + ['OVERALL']) )     
    os.makedirs('log', exist_ok=True)
    
    
    res_df = validate(val_loader, model, args.types)        
    res_df = unnormalize_cols(res_df, cols=['scalar_coupling_constant', 'fc', 'sd', 'pso', 'dso'])
    res_df = unnormalize_target(res_df, 'prediction1')            
    res_df['prediction4']= res_df[['fc', 'sd', 'pso', 'dso']].sum(1)
    res_df['prediction']= res_df[['prediction1','prediction4']].mean(1)
    res_df.to_csv(f'log/valid_df_{"_".join(args.types)}.csv', index=False)
    overall_mae, maes = metric(res_df, args.types)
    print(overall_mae, maes)    
    
    
    curr_lr = get_lr()
    print(f'initial learning rate:{curr_lr}')
    for epoch in range(CFG.start_epoch, CFG.num_train_epochs):
        # train for one epoch
                
        #print(adjust_learning_rate(optimizer, epoch))    
        train(train_loader, model, optimizer, epoch, args.types, scheduler)
       
        if epoch % CFG.test_freq == 0:
            res_df = validate(val_loader, model, args.types)        
            res_df = unnormalize_cols(res_df, cols=['scalar_coupling_constant', 'fc', 'sd', 'pso', 'dso'])
            res_df = unnormalize_target(res_df, 'prediction1')            
            res_df['prediction4']= res_df[['fc', 'sd', 'pso', 'dso']].sum(1)
            res_df['prediction']= res_df[['prediction1','prediction4']].mean(1)
            res_df.to_csv(f'log/valid_df_{"_".join(args.types)}.csv', index=False)
            overall_mae, maes = metric(res_df, args.types)
            
            # write log file
            mae_row = dict([(typ, [mae]) for typ, mae in maes.items() if typ in args.types])
            mae_row.update({'EPOCH':(epoch),'OVERALL':overall_mae, 'LR':curr_lr})
            mae_log_df = mae_log_df.append(pd.DataFrame(mae_row), sort=False)
            print(mae_log_df.tail(20))        
            mae_log_df.to_csv(f'log/{"_".join(args.types)}.csv', index=False)
            
            #scheduler.step(overall_mae)
            curr_lr = get_lr()
            print(f'set the learning_rate: {curr_lr}')
            
            # evaluate on validation set
            batch_size = CFG.batch_size            
            pseudo_path = '' if not args.pseudo else '_' + args.pseudo_path 
            curr_model_name = (f'b{batch_size}_l{config.num_hidden_layers}_'
                               f'mh{config.num_attention_heads}_h{config.hidden_size}_'
                               f'd{CFG.dropout}_'
                               f'ep{epoch}_{"_".join(args.types)}_s{CFG.seed}{pseudo_path}.pt')
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the cust_model itself
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': 'transformer',
                'state_dict': model_to_save.state_dict(),
                'mae_log': mae_log_df,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                },
                FINETUNED_MODEL_PATH, curr_model_name
            )                                                
                                         
    print('done')
Example #4
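# Likely imports for this snippet (not shown in the original); logger, set_seed and
# evaluate are project-local assumptions, and SummaryWriter may come from tensorboardX
# instead of torch.utils.tensorboard depending on the environment:
#   import os, time, shutil
#   import torch
#   from torch.utils.data import DataLoader, RandomSampler
#   from torch.utils.data.distributed import DistributedSampler
#   from torch.utils.tensorboard import SummaryWriter
#   from pytorch_transformers import AdamW, WarmupLinearSchedule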
def train(args, train_dataset, val_dataset, model, tokenizer):
    """ Train the model """
    pretrained_model = model[0]
    adapter_model = model[1]

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size // args.gradient_accumulation_steps)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in adapter_model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in adapter_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        adapter_model, optimizer = amp.initialize(adapter_model, optimizer, opt_level=args.fp16_opt_level)
    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        pretrained_model = torch.nn.DataParallel(pretrained_model)
        adapter_model = torch.nn.DataParallel(adapter_model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        pretrained_model = torch.nn.parallel.DistributedDataParallel(pretrained_model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)
        adapter_model = torch.nn.parallel.DistributedDataParallel(adapter_model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num train examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    logger.info("Try resume from checkpoint")
    if args.restore:
        if os.path.exists(os.path.join(args.output_dir, 'global_step.bin')):
            logger.info("Load last checkpoint data")
            global_step = torch.load(os.path.join(args.output_dir, 'global_step.bin'))
            output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
            logger.info("Load from output_dir {}".format(output_dir))

            optimizer.load_state_dict(torch.load(os.path.join(output_dir, 'optimizer.bin')))
            scheduler.load_state_dict(torch.load(os.path.join(output_dir, 'scheduler.bin')))
            # args = torch.load(os.path.join(output_dir, 'training_args.bin'))
            if hasattr(adapter_model, 'module'):
                adapter_model.module.load_state_dict(torch.load(os.path.join(output_dir, 'pytorch_model.bin')))
            else:  # Take care of distributed/parallel training
                adapter_model.load_state_dict(torch.load(os.path.join(output_dir, 'pytorch_model.bin')))

            global_step += 1
            start_epoch = int(global_step / len(train_dataloader))
            start_step = global_step-start_epoch*len(train_dataloader)-1
            logger.info("Start from global_step={} epoch={} step={}".format(global_step, start_epoch, start_step))
            if args.local_rank in [-1, 0]:
                tb_writer = SummaryWriter(log_dir="runs/" + args.my_model_name, purge_step=global_step)

        else:
            global_step = 0
            start_epoch = 0
            start_step = 0
            if args.local_rank in [-1, 0]:
                tb_writer = SummaryWriter(log_dir="runs/" + args.my_model_name, purge_step=global_step)
            logger.info("Start from scratch")
    else:
        global_step = 0
        start_epoch = 0
        start_step = 0
        if args.local_rank in [-1, 0]:
            tb_writer = SummaryWriter(log_dir="runs/" + args.my_model_name, purge_step=global_step)
        logger.info("Start from scratch")

    tr_loss, logging_loss = 0.0, 0.0
    pretrained_model.zero_grad()
    adapter_model.zero_grad()

    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)

    for epoch in range(start_epoch, int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            start = time.time()
            if args.restore and (step < start_step):
                continue
            # if args.restore and (flag_count < global_step):
            #     flag_count+=1
            #     continue
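            # The pretrained backbone stays in eval mode and is not optimized; only the adapter is trained.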
            pretrained_model.eval()
            adapter_model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids':      batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM and RoBERTa don't use segment_ids
                      'labels':         batch[3]}
            pretrained_model_outputs = pretrained_model(**inputs)
            outputs = adapter_model(pretrained_model_outputs,**inputs)

            loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            # epoch_iterator.set_description("loss {}".format(loss))
            logger.info("Epoch {}/{} - Iter {} / {}, loss = {:.5f}, time used = {:.3f}s".format(epoch, int(args.num_train_epochs),step,
                                                                                             len(train_dataloader),
                                                                                             loss.item(),
                                                                                             time.time() - start))
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(adapter_model.parameters(), args.max_grad_norm)


            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule after the optimizer step
                pretrained_model.zero_grad()
                adapter_model.zero_grad()
                global_step += 1
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)

                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = adapter_model.module if hasattr(adapter_model,
                                                            'module') else adapter_model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)  # save to pytorch_model.bin  model.state_dict()

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.bin'))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.bin'))
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    torch.save(global_step, os.path.join(args.output_dir, 'global_step.bin'))

                    logger.info("Saving model checkpoint, optimizer, global_step to %s", output_dir)
                    if (global_step/args.save_steps) > args.max_save_checkpoints:
                        try:
                            shutil.rmtree(os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step-args.max_save_checkpoints*args.save_steps)))
                        except OSError as e:
                            print(e)
                if args.local_rank == -1 and args.evaluate_during_training and global_step % args.eval_steps == 0:  # Only evaluate when single GPU otherwise metrics may not average well
                    model = (pretrained_model, adapter_model)
                    results = evaluate(args, val_dataset, model, tokenizer)
                    for key, value in results.items():
                        tb_writer.add_scalar('eval_{}'.format(key), value, global_step)

            if args.max_steps > 0 and global_step > args.max_steps:
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #5
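# Likely imports for this snippet (not shown in the original); `stu` and `cfg` are
# project-local assumptions, and the model/tokenizer classes may come from
# pytorch_transformers rather than transformers depending on the version:
#   import math, os, random, time
#   from itertools import chain
#   from pprint import pformat
#   import torch
#   from ignite.engine import Engine, Events
#   from ignite.metrics import Accuracy, Loss, MetricsLambda, RunningAverage
#   from ignite.contrib.handlers import ProgressBar, PiecewiseLinear
#   from ignite.contrib.handlers.tensorboard_logger import OutputHandler, OptimizerParamsHandler
#   from transformers import (AdamW, WEIGHTS_NAME, GPT2DoubleHeadsModel, GPT2Tokenizer,
#                             OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer)
#   from apex import amp  # only needed when opt_level is set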
class SmallTalk:
    def __init__(self, name, model_name, model_type='gpt2', opt_level=None, lr=6.25e-5, lm_coef=1.0, mc_coef=1.0, gradient_accumulation_steps=8, max_norm=1.0, device='cuda:0'):
        self.lr, self.lm_coef, self.mc_coef, self.gradient_accumulation_steps, self.max_norm, self.device = lr, lm_coef, mc_coef, gradient_accumulation_steps, max_norm, device
        self.name, self.model_name, self.model_type, self.opt_level = name, model_name, model_type, opt_level

        self.logger, self.tb_logger, self.checkpoint_handler = stu.setup_training_loggers(self.name)

        self.verbose = False
        self.epoch = 0

        # TODO: Add logger statement here
        model_class, tokenizer_class = (GPT2DoubleHeadsModel, GPT2Tokenizer) if self.model_type == 'gpt2' else (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer)
        self.model, self.tokenizer = model_class.from_pretrained(self.model_name).to(self.device), tokenizer_class.from_pretrained(self.model_name)

        stu.add_special_tokens_(model=self.model, tokenizer=self.tokenizer)

        self.optimizer = AdamW(self.model.parameters(), lr=self.lr, correct_bias=True)

        if self.opt_level:
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=self.opt_level)

        self.trainer = Engine(self.update)
        self.evaluator = Engine(self.inference)

    def update(self, engine, batch):
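        # One training step for the ignite trainer Engine: forward pass, combined
        # LM + multiple-choice loss, backward pass with gradient accumulation.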
        self.model.train()
        batch = tuple(input_tensor.to(self.device) for input_tensor in batch)
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
        (lm_loss), (mc_loss), *_ = self.model(
            input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
            mc_labels=mc_labels, lm_labels=lm_labels
        )
        loss = (lm_loss * self.lm_coef + mc_loss * self.mc_coef) / self.gradient_accumulation_steps

        if self.opt_level:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)

        if engine.state.iteration % self.gradient_accumulation_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
        return loss.item()

    def inference(self, engine, batch):
        self.model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(self.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch

            if self.verbose:
                self.logger.info(self.tokenizer.decode(input_ids[0, -1, :].tolist()))

            # if we don't send labels to the model, it doesn't return losses
            lm_logits, mc_logits, *_ = self.model(
                input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
            )

            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)

            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)

    def train_model(self, n_epochs, train_loader, val_loader, eval_before_start=True):
        # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self.evaluator.run(val_loader))
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self.update_epoch())
        if eval_before_start:
            self.trainer.add_event_handler(Events.STARTED, lambda _: self.evaluator.run(val_loader))

        # Linearly decrease the learning rate from lr to zero
        scheduler = PiecewiseLinear(self.optimizer, "lr", [(0, self.lr), (n_epochs * len(train_loader), 0.0)])
        self.trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        # Prepare metrics
        RunningAverage(output_transform=lambda x: x).attach(self.trainer, "loss")
        metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
                   "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
        metrics["average_ppl"] = MetricsLambda(math.exp, metrics["nll"])
        for name, metric in metrics.items():
            metric.attach(self.evaluator, name)

        # On the main process: add progress bar, tensorboard, checkpoints and save model
        pbar = ProgressBar(persist=True)
        pbar.attach(self.trainer, metric_names=["loss"])

        if not self.verbose:
            pbar_eval = ProgressBar(persist=False)
            pbar_eval.attach(self.evaluator)

        self.evaluator.add_event_handler(Events.STARTED, lambda _: self.logger.info(f'Beginning validation for epoch {self.epoch}...'))
        self.evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(self.evaluator.state.metrics)))

        self.tb_logger.attach(self.trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        self.tb_logger.attach(self.trainer, log_handler=OptimizerParamsHandler(self.optimizer), event_name=Events.ITERATION_STARTED)
        self.tb_logger.attach(self.evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=self.trainer),
                              event_name=Events.EPOCH_COMPLETED)

        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.checkpoint_handler,
                                       {'mymodel': getattr(self.model, 'module', self.model)})  # "getattr" takes care of distributed encapsulation

        # Run the training
        self.trainer.run(train_loader, max_epochs=n_epochs)

        # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
        if n_epochs > 0:
            os.rename(self.checkpoint_handler._saved[-1][1][-1], os.path.join(cfg.checkpoint_log_folder, self.name, WEIGHTS_NAME))
            self.tb_logger.close()

    def save(self, path, inference_only=False):
        """ Saves important components of model to be imported later. """
        save_dict = {
            'epoch': self.epoch,
            'model_state_dict': self.model.state_dict(),
            'model_name': self.model_name,
            'model_type': self.model_type,
            'opt_level': self.opt_level
        }

        if not inference_only:
            save_dict['optimizer_state_dict'] = self.optimizer.state_dict()

        torch.save(save_dict, path)

    # TODO: May want to revisit here if we want to do evaluation on a cpu. See https://github.com/NVIDIA/apex/issues/242
    def load(self, path):
        """ Loads important components of model back into memory to pick up where we left off. """
        checkpoint = torch.load(path)
        assert self.model_type == checkpoint['model_type'], f"Model types do not match, current model is {self.model_type} and loaded model is {checkpoint['model_type']}"
        assert self.model_name == checkpoint['model_name'], f"Model names do not match, current model is {self.model_name} and loaded model is {checkpoint['model_name']}"
        assert self.opt_level == checkpoint['opt_level'], f"Model opt_levels do not match, current model is {self.opt_level} and loaded model is {checkpoint['opt_level']}"

        self.model.load_state_dict(checkpoint['model_state_dict'])
        if 'optimizer_state_dict' in checkpoint:
            self.logger.info('Optimizer information saved for continued training. Loading into model.')
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            self.logger.info('Model previously saved for inference only.')

        self.epoch = checkpoint['epoch']

    def load_checkpoint(self, path):
        """ Loads an entire checkpoint and overwrite model """
        self.model.load_state_dict(torch.load(path, map_location=self.device))

    def update_epoch(self):
        self.epoch += 1

    def get_num_params(self, trainable_only=True):
        if trainable_only:
            return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.model.parameters())

    def interact(self, personality=None, max_history=2, max_length=20, min_length=1, temperature=0.7, top_k=0, top_p=0.9, no_sample=False, random_pause=None):
        """
        Interact with bot in python setting
        :param personality: Personality used to condition the model for chat, given as a list of several short sentences; None pulls a random one from the training data set.
        :param max_history: Number of responses per individual to retain for model to generate text with (in addition to the utterance the model is directly responding to).
        :param max_length: Maximum length of output utterances
        :param min_length: Minimum length of output utterances
        :param temperature: Sampling softmax temperature. 1.0 is the standard softmax; lower values sharpen the distribution and reduce output diversity.
        :param top_k: Keep only the top_k most likely tokens before sampling (<= 0 disables the filter).
        :param top_p: Nucleus (top-p) filtering: keep the smallest set of tokens whose cumulative probability exceeds top_p.
        :param no_sample: Whether to greedily pick the most likely token at each step instead of using the sampling methods above.
        :param random_pause: Optional (low, high) tuple of seconds; the bot pauses for a random duration in that range to seem more human.
        """
        personality = self.get_personality(personality=personality)
        self.logger.info(self.tokenizer.decode(list(chain(*personality))))

        self.model.eval()
        history = []

        self.logger.info('You may now begin talking to the bot. Don\'t be shy, say hello!')

        while True:
            raw_text = input('>>> ')
            while not raw_text:
                print('Please enter in a non-empty value.')
                raw_text = input('>>> ')
            history.append(self.tokenizer.encode(raw_text))

            if random_pause:
                assert len(random_pause) == 2, 'random_pause arg should be a tuple of length 2 if passed'
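                # sleep a random amount between random_pause[0] and random_pause[1] seconds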
                time.sleep(random_pause[0] + random.random() * (random_pause[1] - random_pause[0]))

            with torch.no_grad():
                out_ids = stu.sample_sequence(personality=personality, history=history, tokenizer=self.tokenizer, model=self.model, device=self.device,
                                              max_length=max_length, min_length=min_length, temperature=temperature, top_k=top_k, top_p=top_p, no_sample=no_sample)
            history.append(out_ids)
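            # keep only the last 2 * max_history + 1 utterances so the generation context stays bounded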
            history = history[-(2 * max_history + 1):]
            out_text = self.tokenizer.decode(out_ids, skip_special_tokens=True)
            print(out_text)

    def get_reply(self, conversation_history, personality, max_history=2, max_length=20, min_length=1, temperature=0.7, top_k=0, top_p=0.9, no_sample=False, random_pause=None):
        """
        Based heavily on self.interact. See above documentation for detail on parameters.
        Alternate version of interact for use with the chatbot: uses a ConversationHistory object to assemble the history and returns one reply at a time,
        rather than managing an entire conversation.
        """
        self.model.eval()

        # Build history object from ConversationHistory class
        history = conversation_history.get_list_of_conversation_latest_n_exchanges(n=max_history)
        history = [self.tokenizer.encode(msg) for msg in history]

        # Get ids from model
        with torch.no_grad():
            out_ids = stu.sample_sequence(personality=personality, history=history, tokenizer=self.tokenizer, model=self.model, device=self.device,
                                          max_length=max_length, min_length=min_length, temperature=temperature, top_k=top_k, top_p=top_p, no_sample=no_sample)

        return self.tokenizer.decode(out_ids, skip_special_tokens=True)

    def get_personality(self, personality=None):
        """
        Retrieves a random personality if personality is None, otherwise converts personality raw text to a format the model understands.
        :param personality: List of 4-5 sentences in raw text string form
        """
        if personality is None:
            return stu.get_random_personality(self)
        else:
            return [self.tokenizer.encode(sentence) for sentence in personality]

    def print_personality(self, personality):
        print(self.tokenizer.decode(list(chain(*personality))))  # decode expects a list of token ids


def main():
    my_parser = argparse.ArgumentParser()

    # Required parameters
    my_parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    my_parser.add_argument("--src_file",
                           default=None,
                           type=str,
                           help="The input data file name.")
    my_parser.add_argument("--model_type",
                           default=None,
                           type=str,
                           required=True,
                           help="Model type selected in the list: " +
                           ", ".join(MODEL_CLASSES.keys()))
    my_parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    my_parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    my_parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        help="The output directory where the log will be written.")
    my_parser.add_argument("--model_recover_path",
                           default=None,
                           type=str,
                           help="The file of fine-tuned pretraining model.")
    my_parser.add_argument("--optim_recover_path",
                           default=None,
                           type=str,
                           help="The file of pretraining optimizer.")
    my_parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    my_parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")

    # Other parameters
    my_parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    my_parser.add_argument('--max_position_embeddings',
                           type=int,
                           default=None,
                           help="max position embeddings")
    my_parser.add_argument("--do_train",
                           action='store_true',
                           help="Whether to run training.")
    my_parser.add_argument("--do_eval",
                           action='store_true',
                           help="Whether to run eval on the dev set.")
    my_parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    my_parser.add_argument("--train_batch_size",
                           default=32,
                           type=int,
                           help="Total batch size for training.")
    my_parser.add_argument("--eval_batch_size",
                           default=64,
                           type=int,
                           help="Total batch size for eval.")
    my_parser.add_argument("--learning_rate",
                           default=5e-5,
                           type=float,
                           help="The initial learning rate for Adam.")
    my_parser.add_argument("--label_smoothing",
                           default=0.1,
                           type=float,
                           help="The initial learning rate for Adam.")
    my_parser.add_argument("--weight_decay",
                           default=0.01,
                           type=float,
                           help="The weight decay rate for Adam.")
    my_parser.add_argument("--adam_epsilon",
                           default=1e-8,
                           type=float,
                           help="Epsilon for Adam optimizer.")
    my_parser.add_argument("--max_grad_norm",
                           default=1.0,
                           type=float,
                           help="Max gradient norm.")
    my_parser.add_argument("--num_train_epochs",
                           default=3.0,
                           type=float,
                           help="Total number of training epochs to perform.")
    my_parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    my_parser.add_argument("--hidden_dropout_prob",
                           default=0.1,
                           type=float,
                           help="Dropout rate for hidden states.")
    my_parser.add_argument("--attention_probs_dropout_prob",
                           default=0.1,
                           type=float,
                           help="Dropout rate for attention probabilities.")
    my_parser.add_argument("--no_cuda",
                           action='store_true',
                           help="Whether not to use CUDA when available")
    my_parser.add_argument("--local_rank",
                           type=int,
                           default=-1,
                           help="local_rank for distributed training on gpus")
    my_parser.add_argument('--seed',
                           type=int,
                           default=42,
                           help="random seed for initialization")
    my_parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    my_parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    my_parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    my_parser.add_argument('--tokenized_input',
                           action='store_true',
                           help="Whether the input is tokenized.")
    my_parser.add_argument(
        '--max_len_a',
        type=int,
        default=0,
        help="Truncate_config: maximum length of segment A.")
    my_parser.add_argument(
        '--max_len_b',
        type=int,
        default=0,
        help="Truncate_config: maximum length of segment B.")
    my_parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    my_parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    my_parser.add_argument(
        "--mask_prob",
        default=0.20,
        type=float,
        help=
        "Masked-token probability; the number of predictions can be less than max_pred when the sequence is short."
    )
    my_parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Masked-token probability for EOS; the number of predictions can be less than max_pred when the sequence is short."
    )
    my_parser.add_argument('--max_pred',
                           type=int,
                           default=69,
                           help="Max tokens of prediction.")
    my_parser.add_argument("--num_workers",
                           default=0,
                           type=int,
                           help="Number of workers for the data loader.")

    my_parser.add_argument('--mask_source_words',
                           action='store_true',
                           help="Whether to mask source words for training")
    my_parser.add_argument('--skipgram_prb',
                           type=float,
                           default=0.0,
                           help="Probability of applying an n-gram (skip-gram) mask.")
    my_parser.add_argument('--skipgram_size',
                           type=int,
                           default=1,
                           help="Maximum size of the n-gram mask.")
    my_parser.add_argument('--mask_whole_word',
                           action='store_true',
                           help="Whether masking a whole word.")

    args = my_parser.parse_args()

    if not (args.model_recover_path
            and Path(args.model_recover_path).exists()):
        args.model_recover_path = None

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(args.__dict__, opt_file, sort_keys=True, indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    my_logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case)
    # fall back to simple whitespace splitting when the input is already tokenized
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [
            utils_seq2seq.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                mask_source_words=False,  # note: the --mask_source_words flag is not wired in here
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                tokenizer=data_tokenizer)
        ]

        file = os.path.join(args.data_dir,
                            args.src_file if args.src_file else 'train.tgt')
        train_dataset = utils_seq2seq.Seq2SeqDataset(
            file,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=utils_seq2seq.batch_list_to_batch_tensors,
            pin_memory=False)
        print("Loading dev dataset")
        dev_file = os.path.join(args.data_dir, 'dev_data.json')
        dev_dataset = utils_seq2seq.Seq2SeqDataset(
            dev_file,
            args.eval_batch_size,
            data_tokenizer,
            args.max_seq_length,
            bi_uni_pipeline=bi_uni_pipeline)
        dev_dataloader = torch.utils.data.DataLoader(
            dev_dataset,
            batch_size=args.eval_batch_size,
            collate_fn=utils_seq2seq.batch_list_to_batch_tensors,
            pin_memory=False,
            num_workers=args.num_workers)

    # note: args.train_batch_size was already divided by args.gradient_accumulation_steps above,
    # so t_total below is the total number of optimizer updates (assumes --do_train, i.e. train_dataloader exists)
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)

    # Prepare model
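    # _get_max_epoch_model presumably returns the latest saved epoch number in output_dir (None if no checkpoint is found)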
    recover_step = _get_max_epoch_model(args.output_dir)
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    global_step = 0
    if (recover_step is None) and (args.model_recover_path is None):
        model_recover = None
    else:
        if recover_step:
            my_logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            my_logger.info("***** Recover model: %s *****",
                           args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
    model = model_class.from_pretrained(args.model_name_or_path,
                                        state_dict=model_recover,
                                        config=config)
    if args.local_rank == 0:
        dist.barrier()

    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
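    # biases and LayerNorm weights are excluded from weight decay; note the 0.01 below is hard-coded (--weight_decay is not applied here)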
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
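    # linear warmup for warmup_proportion of t_total steps, then linear decay to zero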
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(args.warmup_proportion * t_total),
        num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if recover_step:
        my_logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)

        if os.path.exists(
                os.path.join(args.output_dir,
                             "amp.{0}.bin".format(recover_step))):
            my_logger.info("***** Recover amp: %d *****", recover_step)
            amp_recover = torch.load(os.path.join(
                args.output_dir, "amp.{0}.bin".format(recover_step)),
                                     map_location='cpu')
            amp.load_state_dict(amp_recover)

        my_logger.info("***** Recover scheduler: %d *****", recover_step)
        scheduler_recover = torch.load(os.path.join(
            args.output_dir, "sched.{0}.bin".format(recover_step)),
                                       map_location='cpu')
        scheduler.load_state_dict(scheduler_recover)

    my_logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        my_logger.info("***** Running training *****")
        my_logger.info("  Batch size = %d", args.train_batch_size)
        my_logger.info("  Num steps = %d", t_total)

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            final_loss = 0
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                input_ids, segment_ids, answer_tag, input_mask, lm_label_ids, masked_pos, masked_weights, _ = batch
                if answer_tag is None:
                    print("answer_tag is None")
                masked_lm_loss = model(input_ids,
                                       segment_ids,
                                       answer_tag,
                                       input_mask,
                                       lm_label_ids,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights)
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    # loss = loss.mean()
                    masked_lm_loss = masked_lm_loss.mean()
                loss = masked_lm_loss
                final_loss = loss.item()

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

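                # step the optimizer (and scheduler) only once every gradient_accumulation_steps mini-batches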
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    optimizer.zero_grad()
                    global_step += 1
            # Save a trained model
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                my_logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(
                    args.output_dir, "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)
                if args.fp16:
                    output_amp_file = os.path.join(
                        args.output_dir, "amp.{0}.bin".format(i_epoch))
                    torch.save(amp.state_dict(), output_amp_file)
                output_sched_file = os.path.join(
                    args.output_dir, "sched.{0}.bin".format(i_epoch))
                torch.save(scheduler.state_dict(), output_sched_file)

                my_logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()

            if args.do_eval:
                # do_eval
                iter_dev = tqdm(dev_dataloader,
                                desc='Iter (loss=X.XXX)',
                                disable=args.local_rank not in (-1, 0))
                val_losses = []
                for step, batch in enumerate(iter_dev):
                    with torch.no_grad():
                        batch = [
                            t.to(device) if t is not None else None
                            for t in batch
                        ]
                        input_ids, segment_ids, answer_tag, input_mask, lm_label_ids, masked_pos, masked_weights, _ = batch
                        masked_dev_loss = model(input_ids,
                                                segment_ids,
                                                answer_tag,
                                                input_mask,
                                                lm_label_ids,
                                                masked_pos=masked_pos,
                                                masked_weights=masked_weights)
                        val_losses.append(masked_dev_loss.item())
                val_loss = np.mean(val_losses)
                print(
                    "Epoch {} - final loss: {:.4f} - val loss: {:.4f}".format(
                        i_epoch, final_loss, val_loss))