Example #1
    def train_D(self, context, context_lens, utt_lens, floors, response,
                res_lens):
        self.context_encoder.eval()
        self.discriminator.train()

        self.optimizer_D.zero_grad()

        batch_size = context.size(0)

        # Encode the dialogue context, then encode the response without its
        # leading <s> token and sample a latent code from the posterior network.
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        post_z = self.sample_code_post(x, c)
        # WGAN critic update: drive the score down on posterior codes
        # (backward with +1) and up on prior codes (backward with -1).
        errD_post = torch.mean(
            self.discriminator(torch.cat((post_z.detach(), c.detach()), 1)))
        errD_post.backward(self.one)

        prior_z = self.sample_code_prior(c)
        errD_prior = torch.mean(
            self.discriminator(torch.cat((prior_z.detach(), c.detach()), 1)))
        errD_prior.backward(self.minus_one)

        # Gradient penalty (WGAN-GP): score random interpolations between the
        # prior and posterior codes, conditioned on the detached context.
        alpha = gData(torch.rand(batch_size, 1))
        alpha = alpha.expand(prior_z.size())
        interpolates = alpha * prior_z.data + ((1 - alpha) * post_z.data)
        interpolates = Variable(interpolates, requires_grad=True)
        d_input = torch.cat((interpolates, c.detach()), 1)
        disc_interpolates = torch.mean(self.discriminator(d_input))
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=gData(torch.ones(disc_interpolates.size())),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        # Penalize deviations of each sample's gradient norm from 1.
        gradient_penalty = (
            (gradients.contiguous().view(gradients.size(0), -1).norm(2, dim=1)
             - 1)**2).mean() * self.lambda_gp
        gradient_penalty.backward()

        self.optimizer_D.step()
        # Critic loss for logging: negated Wasserstein estimate plus the penalty.
        costD = -(errD_prior - errD_post) + gradient_penalty
        return [('train_loss_D', costD.item())]
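For reference, the gradient-penalty term above can be read in isolation. Below is a minimal, self-contained sketch of the same WGAN-GP computation; critic, real, fake, and gp_weight are illustrative names and not part of this repository:

import torch

def gradient_penalty(critic, real, fake, gp_weight=10.0):
    # Hypothetical stand-alone version of the penalty computed in train_D.
    # Per-sample interpolation coefficients in [0, 1), broadcast over features.
    alpha = torch.rand(real.size(0), 1, device=real.device).expand_as(real)
    interp = (alpha * real.detach() +
              (1 - alpha) * fake.detach()).requires_grad_(True)
    # Gradient of the summed critic scores w.r.t. the interpolated inputs.
    grads = torch.autograd.grad(outputs=critic(interp).sum(),
                                inputs=interp,
                                create_graph=True)[0]
    # Penalize deviations of each sample's gradient norm from 1.
    return gp_weight * (
        (grads.reshape(grads.size(0), -1).norm(2, dim=1) - 1)**2).mean()

Under these assumptions, the penalty section of train_D would reduce to gradient_penalty(lambda z: self.discriminator(torch.cat((z, c.detach()), 1)), post_z, prior_z, self.lambda_gp).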
Example #2
def evaluate(model, metrics, test_loader, ivocab, vocab, repeat, PAD_token=0):
    recall_bleus, prec_bleus, bows_extrema, bows_avg, bows_greedy, intra_dist1s, intra_dist2s, \
    avg_lens, inter_dist1s, inter_dist2s = [], [], [], [], [], [], [], [], [], []
    local_t = 0

    model.eval()
    pbar = tqdm(range(test_loader.num_batch))

    for bat in pbar:
        batch = test_loader.next_batch()
        if bat == test_loader.num_batch: break  # end of epoch

        local_t += 1

        context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
        # remove the sos token in the context and reduce the context length
        context, utt_lens = context[:, :, 1:], utt_lens - 1

        if local_t % 2000 == 0:
            logging.info("Batch %d \n" % (local_t))  # print the context

        # Log (at most) the last 5 utterances of the first context in the batch.
        start = np.maximum(0, context_lens[0] - 5)
        for t_id in range(start, context.shape[1], 1):
            context_str, _ = indexes2sent(context[0, t_id], ivocab,
                                          ivocab["</s>"], PAD_token)
            if local_t % 2000 == 0:
                logging.info("Context %d-%d: %s\n" %
                             (t_id, floors[0, t_id], context_str))
        # print the true outputs
        ref_str, _ = indexes2sent(response[0], ivocab, ivocab["</s>"],
                                  ivocab["<s>"])
        ref_tokens = ref_str.split(' ')

        if local_t % 2000 == 0:
            logging.info("Target >> %s\n" % (ref_str.replace(" ' ", "'")))

        context, context_lens, utt_lens, floors = gVar(context), gVar(
            context_lens), gVar(utt_lens), gData(floors)
        sample_words, sample_lens = model.sample(context, context_lens,
                                                 utt_lens, floors, repeat,
                                                 ivocab["<s>"], ivocab["</s>"])
        # nparray: [repeat x seq_len]
        pred_sents, _ = indexes2sent(sample_words, ivocab, ivocab["</s>"],
                                     PAD_token)
        pred_tokens = [sent.split(' ') for sent in pred_sents]
        for r_id, pred_sent in enumerate(pred_sents):
            if local_t % 2000 == 0:
                logging.info("Sample %d >> %s\n" %
                             (r_id, pred_sent.replace(" ' ", "'")))

        # BLEU against the single reference: the best sampled response serves
        # as recall, the average over samples as precision.
        max_bleu, avg_bleu = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        prec_bleus.append(avg_bleu)

        # Embedding-based similarity (extrema, average, and greedy matching).
        bow_extrema, bow_avg, bow_greedy = metrics.sim_bow(
            sample_words, sample_lens, response[:, 1:], res_lens - 2)
        bows_extrema.append(bow_extrema)
        bows_avg.append(bow_avg)
        bows_greedy.append(bow_greedy)

        # Distinct-1/2 diversity, within each sample set (intra) and across
        # the batch (inter).
        intra_dist1, intra_dist2, inter_dist1, inter_dist2 = metrics.div_distinct(
            sample_words, sample_lens)
        intra_dist1s.append(intra_dist1)
        intra_dist2s.append(intra_dist2)
        avg_lens.append(np.mean(sample_lens))
        inter_dist1s.append(inter_dist1)
        inter_dist2s.append(inter_dist2)

    recall_bleu = float(np.mean(recall_bleus))
    prec_bleu = float(np.mean(prec_bleus))
    # Harmonic mean of BLEU precision and recall; the epsilon avoids division by zero.
    f1 = 2 * (prec_bleu * recall_bleu) / (prec_bleu + recall_bleu + 1e-12)

    bow_extrema = float(np.mean(bows_extrema))
    bow_avg = float(np.mean(bows_avg))
    bow_greedy = float(np.mean(bows_greedy))

    intra_dist1 = float(np.mean(intra_dist1s))
    intra_dist2 = float(np.mean(intra_dist2s))

    avg_len = float(np.mean(avg_lens))

    inter_dist1 = float(np.mean(inter_dist1s))
    inter_dist2 = float(np.mean(inter_dist2s))

    report = "Avg recall BLEU %f, avg precision BLEU %f, F1 %f, \nbow_avg %f, bow_extrema %f, bow_greedy %f, \n" \
             "intra_dist1 %f, intra_dist2 %f, inter_dist1 %f, inter_dist2 %f, \navg_len %f" \
             % (recall_bleu, prec_bleu, f1, bow_avg, bow_extrema, bow_greedy, intra_dist1, intra_dist2,
                inter_dist1, inter_dist2, avg_len)
    print(report)
    logging.info(report + "\n")
    print("Done testing")

    model.train()

    return recall_bleu, prec_bleu, bow_extrema, bow_avg, bow_greedy, intra_dist1, intra_dist2, avg_len, inter_dist1, inter_dist2
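metrics.div_distinct is specific to this repository, but the distinct-1/distinct-2 numbers it reports are the standard n-gram diversity measures. A sketch of the usual corpus-level ("inter") variant, assuming responses arrive as lists of tokens:

from collections import Counter

def distinct_n(token_lists, n):
    # Fraction of unique n-grams among all n-grams across the responses.
    ngrams = Counter()
    for toks in token_lists:
        ngrams.update(tuple(toks[i:i + n]) for i in range(len(toks) - n + 1))
    total = sum(ngrams.values())
    return len(ngrams) / total if total else 0.0

# e.g. distinct_n([['i', 'am', 'fine'], ['i', 'am', 'ok']], 2) -> 3/4 = 0.75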
Example #3
    def __init__(self, config, vocab_size, PAD_token=0):
        super(DialogWAE, self).__init__()
        self.vocab_size = vocab_size
        self.maxlen = config['maxlen']
        self.clip = config['clip']
        self.lambda_gp = config['lambda_gp']
        self.temp = config['temp']

        self.embedder = nn.Embedding(vocab_size,
                                     config['emb_size'],
                                     padding_idx=PAD_token)
        self.utt_encoder = Encoder(self.embedder, config['emb_size'],
                                   config['n_hidden'], True,
                                   config['n_layers'], config['noise_radius'])
        # The context encoder consumes the bidirectional utterance encoding
        # (2 * n_hidden) concatenated with 2 floor (speaker-turn) features.
        self.context_encoder = ContextEncoder(self.utt_encoder,
                                              config['n_hidden'] * 2 + 2,
                                              config['n_hidden'], 1,
                                              config['noise_radius'])
        self.prior_net = Variation(config['n_hidden'],
                                   config['z_size'])  # p(e|c)
        self.post_net = Variation(config['n_hidden'] * 3,
                                  config['z_size'])  # q(e|c,x)

        # MLPs that map the Gaussian samples from the variation networks into
        # the final latent codes.
        self.post_generator = nn.Sequential(
            nn.Linear(config['z_size'], config['z_size']),
            nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
            nn.ReLU(),
            nn.Linear(config['z_size'], config['z_size']),
            nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
            nn.ReLU(),
            nn.Linear(config['z_size'], config['z_size']))
        self.post_generator.apply(self.init_weights)

        self.prior_generator = nn.Sequential(
            nn.Linear(config['z_size'], config['z_size']),
            nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
            nn.ReLU(),
            nn.Linear(config['z_size'], config['z_size']),
            nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
            nn.ReLU(),
            nn.Linear(config['z_size'], config['z_size']))
        self.prior_generator.apply(self.init_weights)

        self.decoder = Decoder(self.embedder,
                               config['emb_size'],
                               config['n_hidden'] + config['z_size'],
                               vocab_size,
                               n_layers=1)

        # WGAN critic over concatenated (latent code, context) vectors;
        # outputs an unbounded scalar score.
        self.discriminator = nn.Sequential(
            nn.Linear(config['n_hidden'] + config['z_size'],
                      config['n_hidden'] * 2),
            nn.BatchNorm1d(config['n_hidden'] * 2, eps=1e-05, momentum=0.1),
            nn.LeakyReLU(0.2),
            nn.Linear(config['n_hidden'] * 2, config['n_hidden'] * 2),
            nn.BatchNorm1d(config['n_hidden'] * 2, eps=1e-05, momentum=0.1),
            nn.LeakyReLU(0.2),
            nn.Linear(config['n_hidden'] * 2, 1),
        )
        self.discriminator.apply(self.init_weights)

        # The context encoder owns the utterance encoder (and thus the
        # embedder), so their parameters are covered here as well.
        self.optimizer_AE = optim.SGD(list(self.context_encoder.parameters()) +
                                      list(self.post_net.parameters()) +
                                      list(self.post_generator.parameters()) +
                                      list(self.decoder.parameters()),
                                      lr=config['lr_ae'])
        self.optimizer_G = optim.RMSprop(
            list(self.post_net.parameters()) +
            list(self.post_generator.parameters()) +
            list(self.prior_net.parameters()) +
            list(self.prior_generator.parameters()),
            lr=config['lr_gan_g'])
        self.optimizer_D = optim.RMSprop(self.discriminator.parameters(),
                                         lr=config['lr_gan_d'])

        self.lr_scheduler_AE = optim.lr_scheduler.StepLR(self.optimizer_AE,
                                                         step_size=10,
                                                         gamma=0.6)

        self.criterion_ce = nn.CrossEntropyLoss()

        # Constant gradient directions for the WGAN backward() calls.
        self.one = gData(torch.FloatTensor([1]))
        self.minus_one = self.one * -1
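Every hyperparameter enters through the config dict; the keys below are exactly those read by this constructor (plus n_iters_d from the training loop in Example #4). The values shown are illustrative assumptions, not prescribed settings:

config = {
    'maxlen': 40,         # maximum decoding length
    'clip': 1.0,          # gradient-clipping threshold
    'lambda_gp': 10,      # gradient-penalty weight
    'temp': 1.0,          # sampling temperature
    'emb_size': 200,      # word-embedding size
    'n_hidden': 300,      # GRU hidden size
    'n_layers': 1,        # utterance-encoder layers
    'noise_radius': 0.2,  # encoder output noise
    'z_size': 200,        # latent-code size
    'lr_ae': 1.0,         # SGD learning rate for the autoencoder
    'lr_gan_g': 5e-5,     # RMSprop learning rate for the generator
    'lr_gan_d': 1e-5,     # RMSprop learning rate for the critic
    'n_iters_d': 5,       # critic updates per batch (Example #4)
}
model = DialogWAE(config, vocab_size=10000)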
Example #4
    n_iters = train_loader.num_batch / max(1, config['n_iters_d'])

    itr = 1
    pbar = tqdm(range(train_loader.num_batch))

    for bat in pbar:
        model.train()
        loss_records = []
        batch = train_loader.next_batch()
        if bat == train_loader.num_batch: break  # end of epoch

        context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
        # remove the sos token in the context and reduce the context length
        context, utt_lens = context[:, :, 1:], utt_lens - 1
        context, context_lens, utt_lens, floors, response, res_lens \
            = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors), gVar(response), gVar(res_lens)

        # (1) Autoencoder step: reconstruction through the posterior path.
        loss_AE = model.train_AE(context, context_lens, utt_lens, floors,
                                 response, res_lens)
        loss_records.extend(loss_AE)

        # (2) Generator step: update the prior/posterior networks against the critic.
        loss_G = model.train_G(context, context_lens, utt_lens, floors,
                               response, res_lens)
        loss_records.extend(loss_G)

        for i in range(config['n_iters_d']):  # (3) train the discriminator/critic
            loss_D = model.train_D(context, context_lens, utt_lens, floors,
                                   response, res_lens)
            if i == 0:
                loss_records.extend(loss_D)
            if i == config['n_iters_d'] - 1: