Example #1
def main():
    parser = argparse.ArgumentParser(
        description='Train the individual Transformer model')
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='zara1')
    parser.add_argument('--obs', type=int, default=8)
    parser.add_argument('--preds', type=int, default=12)
    parser.add_argument('--emb_size', type=int, default=1024)
    parser.add_argument('--heads', type=int, default=8)
    parser.add_argument('--layers', type=int, default=6)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--output_folder', type=str, default='Output')
    parser.add_argument('--val_size', type=int, default=50)
    parser.add_argument('--gpu_device', type=str, default="0")
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--max_epoch', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--validation_epoch_start', type=int, default=30)
    parser.add_argument('--resume_train', action='store_true')
    parser.add_argument('--delim', type=str, default='\t')
    parser.add_argument('--name', type=str, default="zara1")

    args = parser.parse_args()
    model_name = args.name

    # Create the model and output directories (parents included) if they don't exist yet
    os.makedirs(f'models/BERT/{args.name}', exist_ok=True)
    os.makedirs(f'output/BERT/{args.name}', exist_ok=True)

    log = SummaryWriter('logs/BERT_%s' % model_name)

    log.add_scalar('eval/mad', 0, 0)
    log.add_scalar('eval/fad', 0, 0)

    os.makedirs(args.name, exist_ok=True)

    device = torch.device("cuda")
    if args.cpu or not torch.cuda.is_available():
        device = torch.device("cpu")

    args.verbose = True

    ## creation of the dataloaders for train and validation
    train_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                    args.dataset_name,
                                                    0,
                                                    args.obs,
                                                    args.preds,
                                                    delim=args.delim,
                                                    train=True,
                                                    verbose=args.verbose)
    val_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                  args.dataset_name,
                                                  0,
                                                  args.obs,
                                                  args.preds,
                                                  delim=args.delim,
                                                  train=False,
                                                  verbose=args.verbose)
    test_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                   args.dataset_name,
                                                   0,
                                                   args.obs,
                                                   args.preds,
                                                   delim=args.delim,
                                                   train=False,
                                                   eval=True,
                                                   verbose=args.verbose)

    from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, AdamW

    config = BertConfig(vocab_size=30522,
                        hidden_size=768,
                        num_hidden_layers=12,
                        num_attention_heads=12,
                        intermediate_size=3072,
                        hidden_act='relu',
                        hidden_dropout_prob=0.1,
                        attention_probs_dropout_prob=0.1,
                        max_position_embeddings=512,
                        type_vocab_size=2,
                        initializer_range=0.02,
                        layer_norm_eps=1e-12)
    model = BertModel(config).to(device)

    from individual_TF import LinearEmbedding as NewEmbed, Generator as GeneratorTS
    a = NewEmbed(3, 768).to(device)
    model.set_input_embeddings(a)
    generator = GeneratorTS(768, 2).to(device)
    #model.set_output_embeddings(GeneratorTS(1024,2))
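    # The stock BERT token embeddings are replaced by a linear projection of the
    # 3-channel trajectory input (NewEmbed(3, 768)), and GeneratorTS(768, 2) maps the
    # 768-d hidden states back to 2-d (x, y) displacement predictions.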

    tr_dl = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=0)
    val_dl = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=0)
    test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=0)

    #optim = SGD(list(a.parameters())+list(model.parameters())+list(generator.parameters()),lr=0.01)
    #sched=torch.optim.lr_scheduler.StepLR(optim,0.0005)
    optim = NoamOpt(
        768, 0.1, len(tr_dl),
        torch.optim.Adam(list(a.parameters()) + list(model.parameters()) +
                         list(generator.parameters()),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
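    # NoamOpt wraps Adam with a warmup-then-decay learning-rate schedule (presumably the
    # one from "Attention Is All You Need"): model size 768, factor 0.1, warmup = len(tr_dl) steps.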
    #optim=Adagrad(list(a.parameters())+list(model.parameters())+list(generator.parameters()),lr=0.01,lr_decay=0.001)
    epoch = 0

    mean = train_dataset[:]['src'][:, :, 2:4].mean((0, 1)) * 0
    std = train_dataset[:]['src'][:, :, 2:4].std((0, 1)) * 0 + 1
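    # Multiplying by 0 (and adding 1) makes mean all zeros and std all ones, so the
    # (x - mean) / std normalization below is effectively a no-op.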

    while epoch < args.max_epoch:
        epoch_loss = 0
        model.train()

        for id_b, batch in enumerate(tr_dl):

            optim.optimizer.zero_grad()
            r = 0
            rot_mat = np.array([[np.cos(r), np.sin(r)],
                                [-np.sin(r), np.cos(r)]])

            inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
            inp = torch.matmul(inp,
                               torch.from_numpy(rot_mat).float().to(device))
            trg_masked = torch.zeros((inp.shape[0], args.preds, 2)).to(device)
            inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
            trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1],
                                  1).to(device)
            inp_cat = torch.cat((inp, trg_masked), 1)
            cls_cat = torch.cat((inp_cls, trg_cls), 1)
            net_input = torch.cat((inp_cat, cls_cat), 2)
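            # net_input concatenates the observed displacements (src columns 2:4), zero-filled
            # placeholders for the args.preds future steps, and a 1/0 "observed vs. masked"
            # flag channel, giving the 3 input features expected by the linear embedding.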

            position = torch.arange(0, net_input.shape[1]).repeat(
                inp.shape[0], 1).long().to(device)
            token = torch.zeros(
                (inp.shape[0], net_input.shape[1])).long().to(device)
            attention_mask = torch.ones(
                (inp.shape[0], net_input.shape[1])).long().to(device)

            out = model(input_ids=net_input,
                        position_ids=position,
                        token_type_ids=token,
                        attention_mask=attention_mask)

            pred = generator(out[0])

            loss = F.pairwise_distance(
                pred[:, :].contiguous().view(-1, 2),
                torch.matmul(
                    torch.cat(
                        (batch['src'][:, :, 2:4], batch['trg'][:, :, 2:4]),
                        1).contiguous().view(-1, 2).to(device),
                    torch.from_numpy(rot_mat).float().to(device))).mean()
            loss.backward()
            optim.step()
            print("epoch %03i/%03i  frame %04i / %04i loss: %7.4f" %
                  (epoch, args.max_epoch, id_b, len(tr_dl), loss.item()))
            epoch_loss += loss.item()
        #sched.step()
        log.add_scalar('Loss/train', epoch_loss / len(tr_dl), epoch)
        with torch.no_grad():
            model.eval()

            gt = []
            pr = []
            val_loss = 0
            for batch in val_dl:
                inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
                trg_masked = torch.zeros(
                    (inp.shape[0], args.preds, 2)).to(device)
                inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
                trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1],
                                      1).to(device)
                inp_cat = torch.cat((inp, trg_masked), 1)
                cls_cat = torch.cat((inp_cls, trg_cls), 1)
                net_input = torch.cat((inp_cat, cls_cat), 2)

                position = torch.arange(0, net_input.shape[1]).repeat(
                    inp.shape[0], 1).long().to(device)
                token = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)
                attention_mask = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)

                out = model(input_ids=net_input,
                            position_ids=position,
                            token_type_ids=token,
                            attention_mask=attention_mask)

                pred = generator(out[0])

                loss = F.pairwise_distance(
                    pred[:, :].contiguous().view(-1, 2),
                    torch.cat(
                        (batch['src'][:, :, 2:4], batch['trg'][:, :, 2:4]),
                        1).contiguous().view(-1, 2).to(device)).mean()
                val_loss += loss.item()

                gt_b = batch['trg'][:, :, 0:2]
                preds_tr_b = pred[:, args.obs:].cumsum(1).to(
                    'cpu').detach() + batch['src'][:, -1:, 0:2]
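                # The cumulative sum of the predicted future displacements, offset by the
                # last observed position, recovers absolute (x, y) coordinates.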
                gt.append(gt_b)
                pr.append(preds_tr_b)

            gt = np.concatenate(gt, 0)
            pr = np.concatenate(pr, 0)
            mad, fad, errs = baselineUtils.distance_metrics(gt, pr)
            log.add_scalar('validation/loss', val_loss / len(val_dl), epoch)
            log.add_scalar('validation/mad', mad, epoch)
            log.add_scalar('validation/fad', fad, epoch)

            model.eval()

            gt = []
            pr = []
            for batch in test_dl:
                inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
                trg_masked = torch.zeros(
                    (inp.shape[0], args.preds, 2)).to(device)
                inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
                trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1],
                                      1).to(device)
                inp_cat = torch.cat((inp, trg_masked), 1)
                cls_cat = torch.cat((inp_cls, trg_cls), 1)
                net_input = torch.cat((inp_cat, cls_cat), 2)

                position = torch.arange(0, net_input.shape[1]).repeat(
                    inp.shape[0], 1).long().to(device)
                token = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)
                attention_mask = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)

                out = model(input_ids=net_input,
                            position_ids=position,
                            token_type_ids=token,
                            attention_mask=attention_mask)

                pred = generator(out[0])

                gt_b = batch['trg'][:, :, 0:2]
                preds_tr_b = pred[:, args.obs:].cumsum(1).to(
                    'cpu').detach() + batch['src'][:, -1:, 0:2]
                gt.append(gt_b)
                pr.append(preds_tr_b)

            gt = np.concatenate(gt, 0)
            pr = np.concatenate(pr, 0)
            mad, fad, errs = baselineUtils.distance_metrics(gt, pr)

            torch.save(model.state_dict(),
                       "models/BERT/%s/ep_%03i.pth" % (args.name, epoch))
            torch.save(generator.state_dict(),
                       "models/BERT/%s/gen_%03i.pth" % (args.name, epoch))
            torch.save(a.state_dict(),
                       "models/BERT/%s/emb_%03i.pth" % (args.name, epoch))

            log.add_scalar('eval/mad', mad, epoch)
            log.add_scalar('eval/fad', fad, epoch)

        epoch += 1

    ab = 1
def train(config, bert_config, train_path, dev_path, rel2id, id2rel,
          tokenizer):
    os.makedirs(config.output_dir, exist_ok=True)
    if os.path.exists('./data/train_file.pkl'):
        train_data = pickle.load(open("./data/train_file.pkl", mode='rb'))
    else:
        train_data = data.load_data(train_path, tokenizer, rel2id, num_rels)
        pickle.dump(train_data, open("./data/train_file.pkl", mode='wb'))
    dev_data = json.load(open(dev_path))
    for sent in dev_data:
        data.to_tuple(sent)
    data_manager = data.SPO(train_data)
    train_sampler = RandomSampler(data_manager)
    train_data_loader = DataLoader(data_manager,
                                   sampler=train_sampler,
                                   batch_size=config.batch_size,
                                   drop_last=True)
    num_train_steps = int(
        len(data_manager) / config.batch_size) * config.max_epoch

    if config.bert_pretrained_model is not None:
        logger.info('load bert weight')
        Bert_model = BertModel.from_pretrained(config.bert_pretrained_model,
                                               config=bert_config)
    else:
        logger.info('random initialize bert model')
        Bert_model = BertModel(config=bert_config)  # weights are randomly initialized by the constructor
    Bert_model.to(device)
    submodel = sub_model(config).to(device)
    objmodel = obj_model(config).to(device)

    loss_fuc = nn.BCELoss(reduction='none')
    params = list(Bert_model.parameters()) + list(
        submodel.parameters()) + list(objmodel.parameters())
    optimizer = AdamW(params, lr=config.lr)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(data_manager))
    logger.info("  Num Epochs = %d", config.max_epoch)
    logger.info("  Total train batch size = %d", config.batch_size)
    logger.info("  Total optimization steps = %d", num_train_steps)
    logger.info("  Logging steps = %d", config.print_freq)
    logger.info("  Save steps = %d", config.save_freq)

    global_step = 0
    Bert_model.train()
    submodel.train()
    objmodel.train()

    for _ in range(config.max_epoch):
        optimizer.zero_grad()
        epoch_iterator = tqdm(train_data_loader, disable=None)
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(device) for t in batch)
            input_ids, segment_ids, input_masks, sub_positions, sub_heads, sub_tails, obj_heads, obj_tails = batch

            bert_output = Bert_model(input_ids, input_masks, segment_ids)[0]
            pred_sub_heads, pred_sub_tails = submodel(
                bert_output)  # [batch_size, seq_len, 1]
            pred_obj_heads, pred_obj_tails = objmodel(bert_output,
                                                      sub_positions)
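            # The subject model tags subject head/tail positions over the BERT output; the
            # object model then predicts object head/tail tags for each relation, conditioned
            # on the given subject positions.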

            # compute the losses
            mask = input_masks.view(-1)
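            # Since BCELoss is created with reduction='none', each term below is multiplied by
            # the flattened attention mask and averaged over the real (non-padding) tokens only.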

            # loss1
            sub_heads = sub_heads.unsqueeze(-1)  # [batch_size, seq_len, 1]
            sub_tails = sub_tails.unsqueeze(-1)

            loss1_head = loss_fuc(pred_sub_heads, sub_heads).view(-1)
            loss1_head = torch.sum(loss1_head * mask) / torch.sum(mask)

            loss1_tail = loss_fuc(pred_sub_tails, sub_tails).view(-1)
            loss1_tail = torch.sum(loss1_tail * mask) / torch.sum(mask)

            loss1 = loss1_head + loss1_tail

            # loss2
            loss2_head = loss_fuc(pred_obj_heads,
                                  obj_heads).view(-1, obj_heads.shape[-1])
            loss2_head = torch.sum(
                loss2_head * mask.unsqueeze(-1)) / torch.sum(mask)

            loss2_tail = loss_fuc(pred_obj_tails,
                                  obj_tails).view(-1, obj_tails.shape[-1])
            loss2_tail = torch.sum(
                loss2_tail * mask.unsqueeze(-1)) / torch.sum(mask)

            loss2 = loss2_head + loss2_tail

            # optimize
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            if (global_step + 1) % config.print_freq == 0:
                logger.info(
                    "epoch : {} step: {} #### loss1: {}  loss2: {}".format(
                        _, global_step + 1,
                        loss1.cpu().item(),
                        loss2.cpu().item()))

            if (global_step + 1) % config.eval_freq == 0:
                logger.info("***** Running evaluating *****")
                with torch.no_grad():
                    Bert_model.eval()
                    submodel.eval()
                    objmodel.eval()
                    P, R, F1 = utils.metric(Bert_model, submodel, objmodel,
                                            dev_data, id2rel, tokenizer)
                    logger.info(f'precision:{P}\nrecall:{R}\nF1:{F1}')
                Bert_model.train()
                submodel.train()
                objmodel.train()

            if (global_step + 1) % config.save_freq == 0:
                # Save a trained model
                model_name = "pytorch_model_%d" % (global_step + 1)
                output_model_file = os.path.join(config.output_dir, model_name)
                state = {
                    'bert_state_dict': Bert_model.state_dict(),
                    'subject_state_dict': submodel.state_dict(),
                    'object_state_dict': objmodel.state_dict(),
                }
                torch.save(state, output_model_file)

    model_name = "pytorch_model_last"
    output_model_file = os.path.join(config.output_dir, model_name)
    state = {
        'bert_state_dict': Bert_model.state_dict(),
        'subject_state_dict': submodel.state_dict(),
        'object_state_dict': objmodel.state_dict(),
    }
    torch.save(state, output_model_file)
Example #3
class BertForQuestionAnsweringWithCRF(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.hidden_size = self.bert.config.hidden_size
        self.CRF_fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.hidden_size, config.num_labels + 2, bias=True),
        )

        self.CRF = CRF(target_size=self.bert.config.num_labels,
                       device=torch.device("cuda"))
        self.CrossEntropyLoss = nn.CrossEntropyLoss()
        self.fc2 = nn.Linear(self.hidden_size, 2, bias=True)
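        # Two heads: CRF_fc1 + CRF decode per-token answer labels (the answer span), while
        # fc2 on the [CLS] vector classifies whether the question is answerable (IsQA).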

    def forward(self, tokens_id_l, token_type_ids_l, answer_offset_l,
                answer_seq_label_l, IsQA_l):

        ## token IDs [batch_size, seq_length]
        tokens_x_2d = torch.LongTensor(tokens_id_l).to(self.device)
        token_type_ids_2d = torch.LongTensor(token_type_ids_l).to(self.device)

        # compute seq_length, excluding [CLS]
        batch_size, seq_length = tokens_x_2d[:, 1:].size()

        ## CRF answer label IDs [batch_size, seq_length]
        y_2d = torch.LongTensor(answer_seq_label_l).to(self.device)[:, 1:]
        ## (batch_size,)
        y_IsQA_2d = torch.LongTensor(IsQA_l).to(self.device)

        if self.training:  # self.training is inherited from the nn.Module base class
            self.bert.train()
            output = self.bert(
                input_ids=tokens_x_2d,
                token_type_ids=token_type_ids_2d,
                output_hidden_states=True,
                return_dict=True)  #[batch_size, seq_len, hidden_size]
        else:
            self.bert.eval()
            with torch.no_grad():
                output = self.bert(input_ids=tokens_x_2d,
                                   token_type_ids=token_type_ids_2d,
                                   output_hidden_states=True,
                                   return_dict=True)

        ## [CLS] for IsQA  [batch_size, hidden_size]
        cls_emb = output.last_hidden_state[:, 0, :]

        IsQA_logits = self.fc2(cls_emb)  ## [batch_size, 2]
        IsQA_loss = self.CrossEntropyLoss.forward(IsQA_logits, y_IsQA_2d)

        ## [batch_size, 1]
        IsQA_prediction = IsQA_logits.argmax(dim=-1).unsqueeze(dim=-1)

        # CRF mask
        mask = np.ones(shape=[batch_size, seq_length], dtype=np.uint8)
        mask = torch.ByteTensor(mask).to(
            self.device)  # [batch_size, seq_len]

        # No [CLS]
        crf_logits = self.CRF_fc1(output.last_hidden_state[:, 1:, :])
        crf_loss = self.CRF.neg_log_likelihood_loss(feats=crf_logits,
                                                    mask=mask,
                                                    tags=y_2d)
        _, CRFprediction = self.CRF.forward(feats=crf_logits, mask=mask)

        return IsQA_prediction, CRFprediction, IsQA_loss, crf_loss, y_2d, y_IsQA_2d.unsqueeze(
            dim=-1)  # (batch_size,) -> (batch_size, 1)
Example #4
class DialogBERT(nn.Module):
    '''Hierarchical BERT for dialog (v5) with two features:
    - Masked context-utterance prediction with direct MSE matching of the utterance vectors
    - Energy-based utterance-order prediction: shuffle the context utterances and predict their original order with a distributed order-prediction head'''

    # TODO: 1. Enhance the sorting net
    #       2. Better data loader for permutation (avoid returning perm_id and use max(pos_ids) instead)

    def __init__(self, args, base_model_name='bert-base-uncased'):
        super(DialogBERT, self).__init__()

        if args.language == 'chinese': base_model_name = 'bert-base-chinese'

        self.tokenizer = BertTokenizer.from_pretrained(base_model_name,
                                                       cache_dir='./cache/')
        if args.model_size == 'tiny':
            self.encoder_config = BertConfig(vocab_size=30522,
                                             hidden_size=256,
                                             num_hidden_layers=6,
                                             num_attention_heads=2,
                                             intermediate_size=1024)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        elif args.model_size == 'small':
            self.encoder_config = BertConfig(vocab_size=30522,
                                             hidden_size=512,
                                             num_hidden_layers=8,
                                             num_attention_heads=4,
                                             intermediate_size=2048)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        else:
            self.encoder_config = BertConfig.from_pretrained(
                base_model_name, cache_dir='./cache/')
            self.utt_encoder = BertForPreTraining.from_pretrained(
                base_model_name,
                config=self.encoder_config,
                cache_dir='./cache/')

        self.context_encoder = BertModel(
            self.encoder_config)  # context encoder: encode context to vector

        self.mlm_mode = 'mse'  # 'mdn', 'mse'
        if self.mlm_mode == 'mdn':
            self.context_mlm_trans = MixtureDensityNetwork(
                self.encoder_config.hidden_size,
                self.encoder_config.hidden_size, 3)
        else:
            self.context_mlm_trans = BertPredictionHeadTransform(
                self.encoder_config
            )  # transform context hidden states back to utterance encodings

        self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
        self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
        #       self.context_order_trans = MLP(self.encoder_config.hidden_size, '200-200-200', 1)

        self.decoder_config = deepcopy(self.encoder_config)
        self.decoder_config.is_decoder = True
        self.decoder_config.add_cross_attention = True
        self.decoder = BertLMHeadModel(self.decoder_config)
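        # Overall architecture: utt_encoder (BERT) encodes each utterance to a vector,
        # context_encoder (BERT over utterance vectors) encodes the dialog context,
        # context_mlm_trans reconstructs masked utterance vectors (MSE or MDN),
        # context_order_trans scores utterance order, and decoder (BertLMHeadModel with
        # cross-attention) generates the response conditioned on the context hidden states.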

    def init_weights(self, m):  # Initialize Linear Weight for GAN
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-0.08,
                                   0.08)  #nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0.)

    @classmethod
    def from_pretrained(self, model_dir):
        self.encoder_config = BertConfig.from_pretrained(model_dir)
        self.tokenizer = BertTokenizer.from_pretrained(
            path.join(model_dir, 'tokenizer'),
            do_lower_case=args.do_lower_case)
        self.utt_encoder = BertForPreTraining.from_pretrained(
            path.join(model_dir, 'utt_encoder'))
        self.context_encoder = BertForSequenceClassification.from_pretrained(
            path.join(model_dir, 'context_encoder'))
        self.context_mlm_trans = BertPredictionHeadTransform(
            self.encoder_config)
        self.context_mlm_trans.load_state_dict(
            torch.load(path.join(model_dir, 'context_mlm_trans.pkl')))
        self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
        self.context_order_trans.load_state_dict(
            torch.load(path.join(model_dir, 'context_order_trans.pkl')))
        self.decoder_config = BertConfig.from_pretrained(model_dir)
        self.decoder = BertLMHeadModel.from_pretrained(
            path.join(model_dir, 'decoder'))

    def save_pretrained(self, output_dir):
        def save_module(model, save_path):
            torch.save(model.state_dict(), save_path)

        def make_list_dirs(dir_list):
            for dir_ in dir_list:
                os.makedirs(dir_, exist_ok=True)

        make_list_dirs([
            path.join(output_dir, name) for name in
            ['tokenizer', 'utt_encoder', 'context_encoder', 'decoder']
        ])
        model_to_save = self.module if hasattr(self, 'module') else self
        model_to_save.encoder_config.save_pretrained(
            output_dir)  # Save configuration file
        model_to_save.tokenizer.save_pretrained(
            path.join(output_dir, 'tokenizer'))
        model_to_save.utt_encoder.save_pretrained(
            path.join(output_dir, 'utt_encoder'))
        model_to_save.context_encoder.save_pretrained(
            path.join(output_dir, 'context_encoder'))
        save_module(model_to_save.context_mlm_trans,
                    path.join(output_dir, 'context_mlm_trans.pkl'))
        save_module(model_to_save.context_order_trans,
                    path.join(output_dir, 'context_order_trans.pkl'))
        model_to_save.decoder_config.save_pretrained(
            output_dir)  # Save configuration file
        model_to_save.decoder.save_pretrained(path.join(output_dir, 'decoder'))

    def utt_encoding(self, context, utts_attn_mask):
        batch_size, max_ctx_len, max_utt_len = context.size(
        )  #context: [batch_size x diag_len x max_utt_len]

        utts = context.view(
            -1, max_utt_len)  # [(batch_size*diag_len) x max_utt_len]
        utts_attn_mask = utts_attn_mask.view(-1, max_utt_len)
        _, utts_encodings, *_ = self.utt_encoder.bert(utts, utts_attn_mask)
        utts_encodings = utts_encodings.view(batch_size, max_ctx_len, -1)
        return utts_encodings

    def context_encoding(self, context, utts_attn_mask, ctx_attn_mask):
        #with torch.no_grad():
        utt_encodings = self.utt_encoding(context, utts_attn_mask)
        context_hiddens, pooled_output, *_ = self.context_encoder(
            None, ctx_attn_mask, None, None, None, utt_encodings)
        # context_hiddens:[batch_size x ctx_len x dim]; pooled_output=[batch_size x dim]

        return context_hiddens, pooled_output

    def train_dialog_flow(self, context, context_utts_attn_mask,
                          context_attn_mask, context_lm_targets,
                          context_position_perm_id, context_position_ids,
                          response):
        """
        only train the dialog flow model
        """
        self.context_encoder.train()  # set the module in training mode.
        self.context_mlm_trans.train()

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)
        lm_pred_encodings = self.context_mlm_trans(
            self.dropout(context_hiddens))

        context_lm_targets[context_lm_targets == -100] = 0
        ctx_lm_mask = context_lm_targets.sum(2)
        if (ctx_lm_mask > 0).sum() == 0: ctx_lm_mask[0, 0] = 1
        lm_pred_encodings = lm_pred_encodings[ctx_lm_mask > 0]
        context_lm_targets = context_lm_targets[ctx_lm_mask > 0]
        context_lm_targets_attn_mask = context_utts_attn_mask[ctx_lm_mask > 0]

        with torch.no_grad():
            _, lm_tgt_encodings, *_ = self.utt_encoder.bert(
                context_lm_targets, context_lm_targets_attn_mask)

        loss_ctx_mlm = MSELoss()(lm_pred_encodings,
                                 lm_tgt_encodings)  # [num_selected_utts x dim]

        # context order prediction
        if isinstance(self.context_order_trans, SelfSorting):
            sorting_scores = self.context_order_trans(context_hiddens,
                                                      context_attn_mask)
        else:
            sorting_scores = self.context_order_trans(context_hiddens)
        sorting_pad_mask = context_attn_mask == 0
        sorting_pad_mask[
            context_position_perm_id <
            1] = True  # exclude single-turn and unshuffled dialogs
        loss_ctx_uop = listNet(sorting_scores, context_position_ids,
                               sorting_pad_mask)
        #loss_ctx_uop = listMLE(sorting_scores, context_position_ids, sorting_pad_mask)

        loss = loss_ctx_mlm + loss_ctx_uop

        return {
            'loss': loss,
            'loss_ctx_mlm': loss_ctx_mlm,
            'loss_ctx_uop': loss_ctx_uop
        }

    def train_decoder(self, context, context_utts_attn_mask, context_attn_mask,
                      context_lm_targets, context_position_perm_id,
                      context_position_ids, response):
        """
         only train the decoder
         """
        self.decoder.train()

        with torch.no_grad():
            context_hiddens, context_encoding = self.context_encoding(
                context, context_utts_attn_mask, context_attn_mask)

        ## train decoder
        dec_input, dec_target = response[:, :-1].contiguous(
        ), response[:, 1:].clone()

        dec_output, *_ = self.decoder(
            dec_input,
            dec_input.ne(self.tokenizer.pad_token_id).long(),
            None,
            None,
            None,
            None,
            encoder_hidden_states=context_hiddens,
            encoder_attention_mask=context_attn_mask,
        )

        batch_size, seq_len, vocab_size = dec_output.size()
        dec_target[response[:, 1:] == self.tokenizer.pad_token_id] = -100
        dec_target[context_position_perm_id >
                   1] = -100  # ignore responses whose contexts are shuffled
        loss_decoder = CrossEntropyLoss()(dec_output.view(-1, vocab_size),
                                          dec_target.view(-1))

        results = {'loss': loss_decoder, 'loss_decoder': loss_decoder}

        return results

    def forward(self, context, context_utts_attn_mask, context_attn_mask,
                context_mlm_targets, context_position_perm_id,
                context_position_ids, response):
        self.train()
        batch_size, max_ctx_len, max_utt_len = context.size(
        )  #context: [batch_size x diag_len x max_utt_len]

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)

        ## train dialog flow modeling
        context_mlm_targets[context_mlm_targets == -100] = 0
        ctx_mlm_mask = context_mlm_targets.sum(2)  #[batch_size x num_utts]
        if (ctx_mlm_mask > 0).sum() == 0: ctx_mlm_mask[0, 0] = 1
        ctx_mlm_mask = ctx_mlm_mask > 0
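        # After resetting the -100 padding value to 0, an utterance whose target token IDs
        # sum to a positive value is one that was masked; ctx_mlm_mask selects those rows.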

        with torch.no_grad():
            _, mlm_tgt_encodings, *_ = self.utt_encoder.bert(
                context_mlm_targets[ctx_mlm_mask],
                context_utts_attn_mask[ctx_mlm_mask])

        if self.mlm_mode == 'mdn':  # mixture density network
            mlm_pred_pi, mlm_pred_normal = self.context_mlm_trans(
                self.dropout(context_hiddens[ctx_mlm_mask]))
            loss_ctx_mlm = self.context_mlm_trans.loss(mlm_pred_pi,
                                                       mlm_pred_normal,
                                                       mlm_tgt_encodings)
        else:  # simply mean square loss
            mlm_pred_encodings = self.context_mlm_trans(
                self.dropout(context_hiddens[ctx_mlm_mask]))
            loss_ctx_mlm = MSELoss()(
                mlm_pred_encodings,
                mlm_tgt_encodings)  # [num_selected_utts x dim]

        # context order prediction
        if isinstance(self.context_order_trans, SelfSorting):
            sorting_scores = self.context_order_trans(context_hiddens,
                                                      context_attn_mask)
        else:
            sorting_scores = self.context_order_trans(context_hiddens)
        sorting_pad_mask = context_attn_mask == 0
        sorting_pad_mask[
            context_position_perm_id <
            1] = True  # exclude single-turn and unshuffled dialogs
        loss_ctx_uop = listNet(sorting_scores, context_position_ids,
                               sorting_pad_mask)
        #loss_ctx_uop = listMLE(sorting_scores, context_position_ids, sorting_pad_mask)

        ## train decoder
        dec_input, dec_target = response[:, :-1].contiguous(
        ), response[:, 1:].clone()

        dec_output, *_ = self.decoder(
            dec_input,
            dec_input.ne(self.tokenizer.pad_token_id).long(),
            None,
            None,
            None,
            None,
            encoder_hidden_states=context_hiddens,
            encoder_attention_mask=context_attn_mask,
        )

        batch_size, seq_len, vocab_size = dec_output.size()
        dec_target[response[:, 1:] == self.tokenizer.pad_token_id] = -100
        dec_target[context_position_perm_id >
                   1] = -100  # ignore responses whose context was shuffled
        loss_decoder = CrossEntropyLoss()(dec_output.view(-1, vocab_size),
                                          dec_target.view(-1))

        loss = loss_ctx_mlm + loss_ctx_uop + loss_decoder

        results = {
            'loss': loss,
            'loss_ctx_mlm': loss_ctx_mlm,
            'loss_ctx_uop': loss_ctx_uop,
            'loss_decoder': loss_decoder
        }

        return results

    def validate(self, context, context_utts_attn_mask, context_attn_mask,
                 context_lm_targets, context_position_perm_id,
                 context_position_ids, response):
        results = self.train_decoder(context, context_utts_attn_mask,
                                     context_attn_mask, context_lm_targets,
                                     context_position_perm_id,
                                     context_position_ids, response)
        return results['loss'].item()

    def generate(self, input_batch, max_len=30, num_samples=1, mode='sample'):
        self.eval()
        device = next(self.parameters()).device
        context, context_utts_attn_mask, context_attn_mask = [
            t.to(device) for t in input_batch[:3]
        ]
        ground_truth = input_batch[6].numpy()

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)

        generated = torch.zeros(
            (num_samples, 1), dtype=torch.long,
            device=device).fill_(self.tokenizer.cls_token_id)
        # [batch_sz x 1] (1=seq_len)

        sample_lens = torch.ones((num_samples, 1),
                                 dtype=torch.long,
                                 device=device)
        len_inc = torch.ones((num_samples, 1), dtype=torch.long, device=device)
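        # len_inc marks which samples are still growing; it is zeroed once [SEP] is generated,
        # after which those samples only receive padding tokens.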
        for _ in range(max_len):
            outputs, *_ = self.decoder(
                generated,
                generated.ne(self.tokenizer.pad_token_id).long(),
                None,
                None,
                None,
                None,
                encoder_hidden_states=context_hiddens,
                encoder_attention_mask=context_attn_mask,
            )  # [batch_size x seq_len x vocab_size]
            next_token_logits = outputs[:,
                                        -1, :] / self.decoder_config.temperature

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for i in range(num_samples):
                for _ in set(generated[i].tolist()):
                    next_token_logits[
                        i, _] /= self.decoder_config.repetition_penalty

            filtered_logits = top_k_top_p_filtering(
                next_token_logits,
                top_k=self.decoder_config.top_k,
                top_p=self.decoder_config.top_p)
            if mode == 'greedy':  # greedy sampling:
                next_token = torch.argmax(filtered_logits,
                                          dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(torch.softmax(filtered_logits,
                                                             dim=-1),
                                               num_samples=num_samples)
            next_token[len_inc == 0] = self.tokenizer.pad_token_id
            generated = torch.cat((generated, next_token), dim=1)
            len_inc = len_inc * (
                next_token != self.tokenizer.sep_token_id).long(
                )  # stop increasing length (set the bit to 0) when EOS ([SEP]) is encountered
            if len_inc.sum() < 1: break
            sample_lens = sample_lens + len_inc

        # to numpy
        sample_words = generated.data.cpu().numpy()
        sample_lens = sample_lens.data.cpu().numpy()

        context = context.data.cpu().numpy()
        return sample_words, sample_lens, context, ground_truth  # nparray: [repeat x seq_len]
class BertVisdEmbedding(nn.Module):
    '''
      Layer that generates BERT contextual representations.
      '''
    def __init__(self, config=None, device=t.device("cpu")):
        '''
          Args:
            @config: configuration file of internal Bert layer
          '''
        super(BertVisdEmbedding, self).__init__()
        if config is None:
            self.bert = BertModel.from_pretrained('bert-base-uncased')
        else:
            self.bert = BertModel(config=config)  # transformers correspondence
        self.device = device
        self.bert_hidden_size = self.bert.config.hidden_size
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.CLS = tokenizer.convert_tokens_to_ids(
            ['[CLS]'])[0]  #ID of the Bert [CLS] token
        self.SEP = tokenizer.convert_tokens_to_ids(
            ['[SEP]'])[0]  #ID of the Bert [SEP] token
        self.PAD = tokenizer.convert_tokens_to_ids(
            ['[PAD]'])[0]  #ID of the Bert [PAD] token

    def make_bert_input(self, content_idxs, content_type, seg_ids):
        '''
          Args:
            @content_idxs (tensor): Bert IDs of the content. (batch_size, max_seq_len) Note that the max_seq_len is a fixed number due to padding/clamping policy.
            @content_type (str): whether the content is "question", "history" or "answer".
            @seg_ids (tensor): the initial segment IDs; for "question" and "answer" this should be None; for "history" it should be pre-initialized as [0,..,0,1,...,1].
          Return:
            cmp_idx (tensor): [CLS] content_idxs [SEP]. (batch_size, max_seq_len+2)
            segment_ids (tensor): for "question" and "answer" these are all zeros; for "history" they are "seg_ids[0], seg_ids, seg_ids[-1]". (batch_size, max_seq_len+2)
            input_mask (tensor): attention mask over the real tokens in the content; note that [CLS] and [SEP] count as real tokens. (batch_size, max_seq_len+2)
          '''
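        # A small worked example (hypothetical token IDs; for bert-base-uncased PAD=0, CLS=101, SEP=102):
        #   content_idxs = [[7, 8, 9, 0]]             -> cmp_idx     = [[101, 7, 8, 9, 102, 0]]
        #   seg_ids      = None ("question"/"answer") -> segment_ids = [[0, 0, 0, 0, 0, 0]]
        #                                                input_mask  = [[1, 1, 1, 1, 1, 0]]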
        mask = content_idxs != self.PAD  #get the mask indicating the non-padding tokens in the content
        if content_type == 'question' or content_type == 'answer':  #question/answer type
            seg_ids = t.zeros_like(content_idxs,
                                   dtype=content_idxs.dtype,
                                   device=content_idxs.device)

        seq_len = mask.sum(dim=1)  #(batch_size, ) length of each sequence
        batch_size, _ = content_idxs.size()
        content_idxs = t.cat(
            (content_idxs,
             t.tensor([[self.PAD]] * batch_size, device=content_idxs.device)),
            dim=1)  #(batch_size, max_seq_len+1)
        content_idxs[
            t.arange(0, batch_size),
            seq_len] = self.SEP  #append [SEP] token to obtain "content_idxs [SEP]"
        seg_last = seg_ids[t.arange(0, batch_size), seq_len -
                           1]  #get the last segment id of each sequence
        seg_ids = t.cat(
            (seg_ids, t.tensor([[0]] * batch_size,
                               device=content_idxs.device)),
            dim=1)  #(batch_size, max_seq_len+1)
        seg_ids[t.arange(0, batch_size),
                seq_len] = seg_last  #the segment id of the new appended [SEP]
        content_idxs = t.cat(
            (t.tensor([[self.CLS]] * batch_size,
                      device=content_idxs.device), content_idxs),
            dim=1
        )  #(batch_size, max_seq_len+2)append [CLS] token to obtain "[CLS] content_idxs [SEP]"
        seg_ids = t.cat(
            (seg_ids[:, 0].view(-1, 1), seg_ids), dim=1
        )  #(batch_size, max_seq_len+2) extend the first column of the segment id
        input_mask = (content_idxs !=
                      self.PAD).long()  #(batch_size, max_seq_len+2)

        return content_idxs, seg_ids, input_mask

    def parse_bert_output(self, bert_output, orig_PAD_mask):
        '''
          Args:
            @bert_output (tensor): Bert output with [CLS] and [SEP] embeddings. (batch_size, 1+max_seq_len+1, bert_hidden_size) 
            @orig_PAD_mask (tensor): 1 for PAD token, 0 for non-PAD token. (batch_size, max_seq_len)
          Return:
            bert_enc (tensor): Bert output without [CLS] and [SEP] embeddings, and with zero-embedding for all PAD tokens. (batch_size, max_seq_len, bert_hidden_size)
          '''
        bert_enc = bert_output[:, 1:
                               -1]  #(batch_size, max_seq_len, bert_hidden_size)
        pad_emb = t.zeros(
            self.bert_hidden_size, device=bert_output.device
        )  #manually set the embedding of PAD token to be zero
        #print(bert_enc.size(), orig_PAD_mask.size(), pad_emb.size(), bert_enc.device, orig_PAD_mask.device, pad_emb.device)
        bert_enc = bert_enc.contiguous()
        bert_enc[
            orig_PAD_mask] = pad_emb  #set the PAD token embeddings to be zero.
        return bert_enc

    def forward(self, content_idxs, content_type, seg_ids=None):
        '''
          Args:
            @content_idxs (tensor): Bert IDs of the contents. (batch_size, max_seq_len) Note that the max_seq_len is a fixed number due to padding/clamping policy
            @content_type (str): whether the tensor is "question", "history" or "answer"
          Return:
            bert_ctx_emb (tensor): contextual embedding condition on question. (batch_size, max_seq_len, bert_hidden_size)
          '''
        orig_PAD_mask = content_idxs == self.PAD
        cmp_idxs, segment_ids, bert_att = self.make_bert_input(
            content_idxs, content_type, seg_ids)
        outputs = self.bert(cmp_idxs, segment_ids, bert_att)
        bert_output = outputs[0]
        bert_enc = self.parse_bert_output(bert_output, orig_PAD_mask)
        return bert_enc

    def train(self, mode=True):
        '''
          Specifically set self.bert into training mode
          '''
        self.training = mode
        self.bert.train(mode)
        return self

    def eval(self):
        '''
          Specifically set self.bert into evaluation mode 
          '''
        return self.train(False)

    def to(self, *args, **kwargs):
        '''
          Override to() interface.
          '''
        print("bert emd to() called!")
        self = super().to(*args, **kwargs)
        self.bert = self.bert.to(*args, **kwargs)
        return self
# download vocab
vocab_info = tokenizer
vocab_path = download(vocab_info['url'],
                      vocab_info['fname'],
                      vocab_info['chksum'],
                      cachedir=cachedir)
#################################################################################################
print('Declaring the BERT model')

bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
bertmodel.load_state_dict(torch.load(model_path))

print("GPU 디바이스 세팅")
device = torch.device(ctx)
bertmodel.to(device)
bertmodel.train()
vocab = nlp.vocab.BERTVocab.from_sentencepiece(vocab_path,
                                               padding_token='[PAD]')

#################################################################################################
# parameter settings
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

max_len = 64
batch_size = 64
warmup_ratio = 0.1
max_grad_norm = 1
log_interval = 200
learning_rate = 5e-5
#################################################################################################
Example #7
# Calculate the flat accuracy of predictions vs. labels
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)


#@title The Training Loop
t = []

# Store our loss and accuracy for plotting
train_loss_set = []

# trange is a tqdm wrapper around the normal python range
for _ in trange(epochs, desc="Epoch"):

    # Training

    # Set our model to training mode (as opposed to evaluation mode)
    model.train()
    gc.collect()
    torch.cuda.empty_cache()
    # Tracking variables
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0

    # Train the data for one epoch
    for step, batch in enumerate(train_dataloader):
        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)
        # Unpack the inputs from our dataloader
        b_input_ids, b_input_mask, b_labels = batch

        b_input_ids = torch.tensor(b_input_ids).long()
        # Clear out the gradients (by default they accumulate)