Example #1
    def step(self, x_train: Tensor, y_train: Tensor, x_val: Optional[Tensor],
             y_val: Optional[Tensor], optimizer,
             batch_size: int) -> Tuple[float, Optional[float]]:
        """Performs a full epoch of training and validation.

        :param x_train: Training inputs
        :param y_train: Training truth labels
        :param x_val: Validation inputs
        :param y_val: Validation truth labels
        :param optimizer: Optimizer used to update the model's parameters
        :param batch_size: Batch size for the training and validation loaders
        :return: Mean training loss, and mean validation loss (``None`` when
            no validation data is given)
        """
        device_type = self.get_device_type()

        # Accumulate the mean loss over all training batches
        tr_loss = 0.0
        train_loader = get_data_loader(x_train, y_train, batch_size)
        train_batches = len(train_loader)
        progress_bar = tqdm(desc='TRAIN', total=train_batches)

        for inputs, labels in train_loader:
            progress_bar.update()
            if device_type == 'cuda':
                inputs, labels = inputs.cuda(), labels.cuda()

            optimizer.zero_grad()
            loss = self.get_loss(labels, inputs)
            loss.backward()
            optimizer.step()
            tr_loss += loss.item() / train_batches

        progress_bar.close()
        if x_val is None or y_val is None:
            return tr_loss, None

        va_loss = 0.0
        val_loader = get_data_loader(x_val, y_val, batch_size)
        val_batches = len(val_loader)
        progress_bar = tqdm(desc='  VAL', total=val_batches)

        # Validation pass: no parameter updates, so disable autograd
        with torch.no_grad():
            for inputs, labels in val_loader:
                progress_bar.update()
                if device_type == 'cuda':
                    inputs, labels = inputs.cuda(), labels.cuda()

                loss = self.get_loss(labels, inputs)
                va_loss += loss.item() / val_batches

        progress_bar.close()

        return tr_loss, va_loss
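
A minimal usage sketch for `step`, assuming a hypothetical `Net` class that provides `get_device_type` and `get_loss` on top of `nn.Module`; the tensor shapes and hyperparameters below are placeholders, not values from the original code:

import torch
import torch.optim as optim

model = Net()  # hypothetical nn.Module subclass defining step()/get_loss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Placeholder data; real inputs would come from the task at hand
x_train, y_train = torch.randn(1024, 16), torch.randint(0, 2, (1024,))
x_val, y_val = torch.randn(128, 16), torch.randint(0, 2, (128,))

for epoch in range(5):
    tr_loss, va_loss = model.step(x_train, y_train, x_val, y_val,
                                  optimizer, batch_size=64)
    msg = f'epoch {epoch}: train loss {tr_loss:.4f}'
    if va_loss is not None:
        msg += f', val loss {va_loss:.4f}'
    print(msg)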
Example #2
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm


def train_retriever(retriever,
                    optimizer,
                    dataset,
                    pad_idx=1,
                    batch_size=32,
                    epoch=0,
                    distributed=True):
    """Trains the retriever for one epoch with an in-batch contrastive loss."""
    retriever.train()
    data_loader = get_data_loader(dataset, collate_fn(pad_idx), batch_size,
                                  distributed, epoch)
    # In-batch negatives: the i-th context should match the i-th knowledge
    # entry, so the target for row i is i (supports batch sizes up to 4096)
    label = torch.arange(4096, dtype=torch.long).cuda()
    tk0 = tqdm(data_loader, total=len(data_loader))
    item_num = 0
    acc_report = 0
    loss_report = 0
    for batch in tk0:
        context = batch['context'].cuda()
        knowledge_pool = batch['knowledge_pool'].cuda()
        # unused batch fields: concat_context, response, d_id, k_id

        bc_size = context.size(0)
        item_num += bc_size

        # Encode both sides; the retriever returns a dict with 'pooled'
        # and 'compressed' sentence representations
        context = retriever(context)
        knowledge = retriever(knowledge_pool)

        _, pooled_dim = context['pooled'].size()
        _, compressed_dim = context['compressed'].size()

        # Scaled dot-product similarity between every context and every
        # knowledge entry in the batch, trained as a classification over
        # the in-batch candidates
        pooled_context = context['pooled'] / np.sqrt(pooled_dim)
        pooled_knowledge = knowledge['pooled'] / np.sqrt(pooled_dim)
        attention = torch.mm(pooled_context, pooled_knowledge.t())
        pooled_loss = F.cross_entropy(attention, label[:bc_size])

        compressed_context = context['compressed'] / np.sqrt(compressed_dim)
        compressed_knowledge = knowledge['compressed'] / np.sqrt(
            compressed_dim)
        attention = torch.mm(compressed_context, compressed_knowledge.t())
        compressed_loss = F.cross_entropy(attention, label[:bc_size])

        loss = pooled_loss + compressed_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(retriever.parameters(), 2)
        optimizer.step()
        optimizer.zero_grad()

        acc_report += (attention.argmax(
            dim=-1) == label[:bc_size]).sum().item()
        loss_report += loss.item()

        tk0.set_postfix(loss=round(loss_report / item_num, 4),
                        acc=round(acc_report / item_num, 4))
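
A hedged driver sketch for `train_retriever`; `build_retriever` and `KnowledgeDataset` are placeholders for however the project actually constructs the encoder and data, not names from the original code:

import torch

retriever = build_retriever().cuda()  # hypothetical constructor
optimizer = torch.optim.AdamW(retriever.parameters(), lr=5e-5)
dataset = KnowledgeDataset('train.json')  # hypothetical dataset class

for epoch in range(3):
    train_retriever(retriever, optimizer, dataset, pad_idx=1,
                    batch_size=32, epoch=epoch, distributed=False)
    torch.save(retriever.state_dict(), f'retriever_epoch{epoch}.pt')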
Example #3
import torch
from tqdm import tqdm


def test_generator(generator,
                   dataset,
                   language,
                   tokenizer,
                   pad_idx=1,
                   batch_size=32,
                   epoch=0,
                   word_mask=None):
    """Generates responses with beam search and reports unigram F1."""
    generator.eval()
    data_loader = get_data_loader(dataset, collate_ckgc(pad_idx), batch_size,
                                  False, epoch)
    tk0 = tqdm(data_loader, total=len(data_loader))
    f1_report = []
    outputs_predict = []
    outputs_true = []
    with torch.no_grad():
        for batch in tk0:
            # unused batch fields: context, knowledge_pool, d_id, k_id
            concat_context = batch['concat_context'].cuda()
            response = batch['response'].cuda()

            # mBART-style decoding: generation starts from the token id of
            # the target-language code
            predict = generator.generate(
                input_ids=concat_context,
                decoder_start_token_id=tokenizer.lang_code_to_id[language],
                num_beams=3,
                max_length=128,
                bad_words_ids=word_mask)

            predict_sent = tokenizer.batch_decode(predict,
                                                  skip_special_tokens=True)
            label_sent = tokenizer.batch_decode(response,
                                                skip_special_tokens=True)

            outputs_predict.extend(predict_sent)
            outputs_true.extend(label_sent)
            if language == 'zh':
                # Chinese text is unsegmented, so compute unigram F1 at the
                # character level by space-joining the characters
                f1 = [
                    f1_score(' '.join(pred), [' '.join(label)])
                    for pred, label in zip(predict_sent, label_sent)
                ]
            else:
                f1 = [
                    f1_score(pred, [label])
                    for pred, label in zip(predict_sent, label_sent)
                ]
            f1_report.extend(f1)

            tk0.set_postfix(f1score=round(sum(f1_report) / len(f1_report), 4))

    return outputs_predict, outputs_true
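
A possible evaluation driver, assuming a HuggingFace-style seq2seq `generator` and a tokenizer that exposes `lang_code_to_id` (as the mBART tokenizers do); the loading helpers, dataset class, and file names are placeholders:

generator = load_generator().cuda()  # hypothetical loading helper
tokenizer = load_tokenizer()  # must expose lang_code_to_id, as mBART does
dataset = CKGCDataset('test_zh.json')  # hypothetical dataset class

# 'zh' must be a key of tokenizer.lang_code_to_id for this to work
preds, refs = test_generator(generator, dataset, 'zh', tokenizer,
                             pad_idx=tokenizer.pad_token_id,
                             batch_size=16)
with open('predictions.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(preds))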
Example #4
import torch
import torch.nn.functional as F
from tqdm import tqdm


def train_generator(generator,
                    optimizer,
                    dataset,
                    pad_idx=1,
                    batch_size=32,
                    epoch=0,
                    distributed=True):
    """Trains the generator for one epoch with teacher forcing."""
    generator.train()
    data_loader = get_data_loader(dataset, collate_fn(pad_idx), batch_size,
                                  distributed, epoch)
    tk0 = tqdm(data_loader, total=len(data_loader))
    item_num = 0
    acc_report = []
    ppl_report = 0
    for batch in tk0:
        # unused batch fields: context, knowledge_pool, d_id, k_id
        concat_context = batch['concat_context'].cuda()
        response = batch['response'].cuda()

        predict = generator(concat_context, response)['logits']

        # Shift for next-token prediction: the logits at position t are
        # scored against the token at position t + 1
        bc_size, length, vocab_size = predict.size()
        predict = predict[:, :-1, :].contiguous().view(-1, vocab_size)
        gt = response[:, 1:].contiguous().view(-1)
        length -= 1

        loss = F.cross_entropy(predict,
                               gt,
                               ignore_index=pad_idx,
                               reduction='none')
        # Sum the token losses per example; dividing by the non-pad token
        # count below gives a per-example mean for the perplexity report
        loss = loss.view(-1, length).sum(dim=1)
        tk_num = response.ne(pad_idx).long().sum(dim=-1)

        ppl_report += torch.exp(loss / tk_num).sum().item()
        item_num += bc_size
        acc_report.append(
            ((predict.argmax(dim=-1) == gt) & gt.ne(pad_idx)).sum().item() /
            tk_num.sum().item())

        loss = loss.sum() / tk_num.sum()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(generator.parameters(), 2)

        optimizer.step()
        optimizer.zero_grad()

        tk0.set_postfix(acc=round(sum(acc_report) / len(acc_report), 3),
                        ppl=ppl_report / item_num)
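
The shift-and-flatten bookkeeping in `train_generator` is easy to get wrong, so here is a tiny self-contained sketch of the same loss and perplexity computation on random tensors (all shapes here are made up for illustration):

import torch
import torch.nn.functional as F

pad_idx = 1
bc_size, length, vocab_size = 2, 5, 11
logits = torch.randn(bc_size, length, vocab_size)
response = torch.randint(2, vocab_size, (bc_size, length))
response[:, -1] = pad_idx  # pretend the last position is padding

# Position t predicts token t + 1: drop the last logit row and the
# first target token before flattening
pred = logits[:, :-1, :].reshape(-1, vocab_size)
gt = response[:, 1:].reshape(-1)

tok_loss = F.cross_entropy(pred, gt, ignore_index=pad_idx,
                           reduction='none').view(bc_size, length - 1)
tk_num = response.ne(pad_idx).sum(dim=-1)
ppl = torch.exp(tok_loss.sum(dim=1) / tk_num)  # per-example perplexity
print(ppl)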
Example #5
import numpy as np
import torch
from tqdm import tqdm


def test_retriever(retriever, dataset, pad_idx=1, batch_size=32, epoch=0):
    """Ranks each knowledge pool by similarity and reports recall@1/@5."""
    retriever.eval()
    data_loader = get_data_loader(dataset, collate_ckgc(pad_idx), batch_size,
                                  False, epoch)
    tk0 = tqdm(data_loader, total=len(data_loader))
    k_ranks = []
    rat1 = []
    rat5 = []
    with torch.no_grad():
        for batch in tk0:
            context = batch['context'].cuda()
            knowledge_pool = batch['knowledge_pool'].cuda()
            # unused batch fields: concat_context, response, d_id
            k_ids = batch['k_id']

            context = retriever(context)['pooled']

            # Encode every candidate in each pool in one flattened pass
            bc_size, pool_size, _ = knowledge_pool.size()
            knowledge_pool = knowledge_pool.view(bc_size * pool_size, -1)
            knowledge_pool = retriever(knowledge_pool)['pooled']
            knowledge_pool = knowledge_pool.view(bc_size, pool_size, -1)

            d_model = context.size(-1)
            context = context / np.sqrt(d_model)
            knowledge_pool = knowledge_pool / np.sqrt(d_model)

            attention = torch.bmm(knowledge_pool,
                                  context.unsqueeze(-1)).squeeze(-1)

            # Rank the candidate ids from most to least similar; the gold
            # knowledge id is the first entry of each k_ids list
            rank = [
                remove_duplicates(
                    [kid[ri] for ri in prob.argsort(descending=True)])
                for prob, kid in zip(attention, k_ids)
            ]
            k_ranks.extend(rank)

            rat1.extend(
                [int(true[0] == pred[0]) for true, pred in zip(k_ids, rank)])
            rat5.extend(
                [int(true[0] in pred[:5]) for true, pred in zip(k_ids, rank)])

            tk0.set_postfix(rat1=round(sum(rat1) / len(rat1), 4),
                            rat5=round(sum(rat5) / len(rat5), 4))

    return k_ranks
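
The helper `remove_duplicates` is never shown in these examples; an order-preserving version consistent with how it is called here might look like the sketch below (an assumption, not the project's actual helper):

def remove_duplicates(seq):
    """Drops repeated ids while keeping first-occurrence order."""
    seen = set()
    return [x for x in seq if not (x in seen or seen.add(x))]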