Example #1
    def forward(self, context):
        batch_size, _ = context.size()
        context = self.fc(context)
        mu = self.context_to_mu(context)
        logsigma = self.context_to_logsigma(context)
        std = torch.exp(0.5 * logsigma)

        epsilon = gVar(torch.randn([batch_size, self.z_size]))
        z = epsilon * std + mu
        return z, mu, logsigma
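This forward pass is the VAE reparameterization trick: z = mu + std * epsilon keeps the sample differentiable with respect to mu and logsigma. A minimal self-contained sketch of the same idea (the class name, layer sizes, and gVar-free setup below are assumptions for illustration, not the original module):

import torch
import torch.nn as nn

class GaussianPosterior(nn.Module):
    """Minimal reparameterization-trick sketch (names are illustrative)."""
    def __init__(self, context_size, z_size):
        super().__init__()
        self.context_to_mu = nn.Linear(context_size, z_size)
        self.context_to_logsigma = nn.Linear(context_size, z_size)

    def forward(self, context):
        mu = self.context_to_mu(context)
        logsigma = self.context_to_logsigma(context)
        std = torch.exp(0.5 * logsigma)
        epsilon = torch.randn_like(std)  # noise on the same device/dtype as std
        z = epsilon * std + mu           # differentiable w.r.t. mu and logsigma
        return z, mu, logsigma

# usage: z, mu, logsigma = GaussianPosterior(256, 64)(torch.randn(8, 256))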
Example #2
    def sampling(self,
                 init_h,
                 enc_hids,
                 context,
                 maxlen,
                 mode='greedy',
                 to_numpy=True):
        """
        A simple greedy sampling
        :param init_h: [batch_sz x hid_sz]
        :param enc_hids: a tuple of (enc_hids, mask) for attention use. [batch_sz x seq_len x hid_sz]
        """
        batch_size = init_h.size(0)
        decoded_words = gVar(torch.zeros(batch_size, maxlen)).long()
        sample_lens = gVar(torch.zeros(batch_size)).long()
        len_inc = gVar(torch.ones(batch_size)).long()

        x = gVar(torch.LongTensor([[SOS_ID] * batch_size]).view(batch_size, 1))  # [batch_sz x 1] (1 = seq_len)
        h = init_h.unsqueeze(0)  # [1 x batch_sz x hid_sz]
        for di in range(maxlen):
            out, h = self.forward(h.squeeze(0), enc_hids, context, x)
            if mode == 'greedy':
                x = out[:, -1].max(1, keepdim=True)[1]  # x: [batch_sz x 1], indexes of predicted words
            elif mode == 'sample':
                x = torch.multinomial(F.softmax(out[:, -1], dim=1), 1)
            decoded_words[:, di] = x.squeeze()
            len_inc = len_inc * (x.squeeze() != EOS_ID).long()  # stop increasing the length (bit set to 0) once EOS is met
            sample_lens = sample_lens + len_inc

        if to_numpy:
            decoded_words = decoded_words.data.cpu().numpy()
            sample_lens = sample_lens.data.cpu().numpy()
        return decoded_words, sample_lens
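The two decoding modes differ only in how the next token is chosen from the last step's logits: argmax for 'greedy', a draw from the softmax distribution for 'sample'. A standalone sketch of that choice (the function name is illustrative, not from the original):

import torch
import torch.nn.functional as F

def pick_next_token(logits, mode='greedy'):
    # logits: [batch_sz x vocab_sz], scores for the last decoding step
    if mode == 'greedy':
        return logits.max(1, keepdim=True)[1]  # [batch_sz x 1] argmax ids
    if mode == 'sample':
        return torch.multinomial(F.softmax(logits, dim=1), 1)  # stochastic draw
    raise ValueError("mode must be 'greedy' or 'sample'")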
Example #3
def evaluate(model, test_loader, vocab, repeat, f_eval):
    # TODO:
    z_vec = []
    local_t = 0
    while True:
        batch = test_loader.next_batch()
        if batch is None:
            break
        local_t += 1
        if local_t != 1033:  # hard-coded: only inspect batch 1033
            continue
        context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
        # remove the SOS token from the context and reduce the context length
        context, utt_lens = context[:, :, 1:], utt_lens - 1
        f_eval.write("Batch %d \n" % (local_t))  # print the context
        start = np.maximum(0, context_lens[0] - 5)
        for t_id in range(start, context.shape[1], 1):
            context_str = indexes2sent(context[0, t_id], vocab, vocab["</s>"],
                                       PAD_token)
            f_eval.write("Context %d-%d: %s\n" %
                         (t_id, floors[0, t_id], context_str))
        # print the true outputs
        ref_str, _ = indexes2sent(response[0], vocab, vocab["</s>"],
                                  vocab["<s>"])
        ref_tokens = ref_str.split(' ')
        f_eval.write("Target >> %s\n" % (ref_str.replace(" ' ", "'")))

        context, context_lens, utt_lens, floors = gVar(context), gVar(
            context_lens), gVar(utt_lens), gData(floors)
        prior_z, c_repeated = model.sample_fix_z(context, context_lens,
                                                 utt_lens, floors, repeat,
                                                 vocab["<s>"], vocab["</s>"])

        # nparray: [repeat x hidden_z]
        z_vec.append(prior_z)

    return z_vec
Example #4
    def forward(self, context):
        batch_size, _ = context.size()
        context = self.fc(context)

        pi = self.pi_net(context)
        pi = F.gumbel_softmax(pi, tau=self.gumbel_temp, hard=True, eps=1e-10)
        pi = pi.unsqueeze(1)

        mus = self.context_to_mu(context)
        logsigmas = self.context_to_logsigma(context)
        stds = torch.exp(0.5 * logsigmas)

        epsilons = gVar(torch.randn([batch_size, self.n_components * self.z_size]))

        zi = (epsilons * stds + mus).view(batch_size, self.n_components, self.z_size)
        z = torch.bmm(pi, zi).squeeze(1)  # [batch_sz x z_sz]
        mu = torch.bmm(pi, mus.view(batch_size, self.n_components, self.z_size))
        logsigma = torch.bmm(pi, logsigmas.view(batch_size, self.n_components, self.z_size))
        return z, mu, logsigma
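With hard=True, F.gumbel_softmax returns a one-hot vector, so each bmm above selects exactly one mixture component per batch element while straight-through gradients still reach all components. A shape-level sketch of that selection (all sizes are assumptions):

import torch
import torch.nn.functional as F

batch_size, n_components, z_size = 4, 3, 8
pi_logits = torch.randn(batch_size, n_components)
pi = F.gumbel_softmax(pi_logits, tau=1.0, hard=True).unsqueeze(1)  # [B x 1 x K], one-hot
zi = torch.randn(batch_size, n_components, z_size)                 # [B x K x Z] candidate z's
z = torch.bmm(pi, zi).squeeze(1)                                   # [B x Z], one component picked per row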
Example #5
    def sampling(self,
                 init_hidden,
                 context,
                 maxlen,
                 SOS_tok,
                 EOS_tok,
                 mode='greedy'):
        batch_size = init_hidden.size(0)
        decoded_words = np.zeros((batch_size, maxlen), dtype=int)
        sample_lens = np.zeros(batch_size, dtype=int)

        decoder_input = gVar(torch.LongTensor([[SOS_tok] * batch_size]).view(batch_size, 1))
        decoder_input = self.embedding(decoder_input) if self.embedding is not None else decoder_input
        decoder_input = torch.cat([decoder_input, context.unsqueeze(1)], 2) if context is not None else decoder_input
        decoder_hidden = init_hidden.unsqueeze(0)
        for di in range(maxlen):
            decoder_output, decoder_hidden = self.rnn(decoder_input,
                                                      decoder_hidden)
            decoder_output = self.out(decoder_output)
            if mode == 'greedy':
                topi = decoder_output[:, -1].max(1, keepdim=True)[1]
            elif mode == 'sample':
                topi = torch.multinomial(F.softmax(decoder_output[:, -1], dim=1), 1)
            decoder_input = self.embedding(topi) if self.embedding is not None else topi
            decoder_input = torch.cat([decoder_input, context.unsqueeze(1)], 2) if context is not None else decoder_input
            ni = topi.squeeze().data.cpu().numpy()
            decoded_words[:, di] = ni

        for i in range(batch_size):
            for word in decoded_words[i]:
                if word == EOS_tok:
                    break
                sample_lens[i] = sample_lens[i] + 1
        return decoded_words, sample_lens
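The closing loop counts, for each row, the tokens that precede the first EOS. A vectorized NumPy equivalent (the function name is illustrative):

import numpy as np

def lengths_before_eos(decoded_words, eos_id):
    # index of the first EOS per row; rows without EOS keep the full length
    is_eos = decoded_words == eos_id
    return np.where(is_eos.any(axis=1), is_eos.argmax(axis=1), decoded_words.shape[1])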
Example #6
    def forward(self, inputs, input_lens=None, noise=False):
        if self.embedding is not None:
            inputs = self.embedding(inputs)  # [batch_sz x seq_len] -> [batch_sz x seq_len x emb_sz]

        batch_size, seq_len, emb_size = inputs.size()
        inputs = F.dropout(inputs, self.dropout, self.training)  # dropout

        if input_lens is not None:  # sort and pack sequence
            input_lens_sorted, indices = input_lens.sort(descending=True)
            inputs_sorted = inputs.index_select(0, indices)
            inputs = pack_padded_sequence(inputs_sorted,
                                          input_lens_sorted.data.tolist(),
                                          batch_first=True)

        #self.rnn.flatten_parameters() # time consuming!!
        hids, h_n = self.rnn(inputs)  # hids: [b x seq x (n_dir*hid_sz)]
        # h_n: [(n_layers*n_dir) x batch_sz x hid_sz] (2=fw&bw)
        if input_lens is not None:  # reorder and pad
            _, inv_indices = indices.sort()
            hids, lens = pad_packed_sequence(hids, batch_first=True)
            hids = hids.index_select(0, inv_indices)
            hids = (hids, sequence_mask(input_lens))  # attach mask for attention use
            h_n = h_n.index_select(1, inv_indices)
        h_n = h_n.view(self.n_layers, (1 + self.bidirectional), batch_size,
                       self.hidden_size)  # [n_layers x n_dirs x batch_sz x hid_sz]
        h_n = h_n[-1]  # last layer: [n_dirs x batch_sz x hid_sz]
        enc = h_n.transpose(1, 0).contiguous().view(batch_size, -1)  # [batch_sz x (n_dirs*hid_sz)]
        if noise and self.noise_radius > 0:
            gauss_noise = gVar(torch.normal(mean=torch.zeros(enc.size()),
                                            std=self.noise_radius))
            enc = enc + gauss_noise

        return enc, hids
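The length handling above follows the classic sort, pack, run, unpack, unsort pattern for variable-length batches. A self-contained sketch of just that pattern (the GRU and all sizes are assumptions):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

rnn = torch.nn.GRU(16, 32, batch_first=True)
inputs = torch.randn(5, 10, 16)                       # [batch x seq x emb]
lens = torch.tensor([10, 3, 7, 1, 5])

lens_sorted, idx = lens.sort(descending=True)         # packing requires sorted lengths
packed = pack_padded_sequence(inputs.index_select(0, idx),
                              lens_sorted.tolist(), batch_first=True)
hids, h_n = rnn(packed)
hids, _ = pad_packed_sequence(hids, batch_first=True)
_, inv_idx = idx.sort()                               # invert the sorting permutation
hids = hids.index_select(0, inv_idx)                  # restore original batch order
h_n = h_n.index_select(1, inv_idx)                    # h_n's batch dim is dim 1

In newer PyTorch, pack_padded_sequence(..., enforce_sorted=False) performs this sort and unsort internally, which makes the manual index bookkeeping unnecessary.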
Example #7
                            shuffle=True)

    n_iters = train_loader.num_batch / max(1, config['n_iters_d'])

    itr = 1
    while True:  # loop through all batches in training data
        model.train()
        loss_records = []
        batch = train_loader.next_batch()
        if batch is None:  # end of epoch
            break
        context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
        # remove the SOS token from the context and reduce the context length
        context, utt_lens = context[:, :, 1:], utt_lens - 1
        context, context_lens, utt_lens, floors, response, res_lens\
                = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors), gVar(response), gVar(res_lens)
        #for i in range(config['n_iters_d']):
        loss_AE = model.train_AE(context, context_lens, utt_lens, floors,
                                 response, res_lens)
        loss_records.extend(loss_AE)

        loss_G = model.train_G(context, context_lens, utt_lens, floors,
                               response, res_lens)
        loss_records.extend(loss_G)
        '''
        for i in range(config['n_iters_d']):  # train discriminator/critic
            loss_D = model.train_D(context, context_lens, utt_lens, floors, response, res_lens)
            if i == 0:
                loss_records.extend(loss_D)
            if i == config['n_iters_d'] - 1:
                break
        '''
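Examples #7 and #11 follow the usual adversarial alternation: one autoencoder step, one generator step, then n_iters_d discriminator/critic steps per batch. A schematic sketch of that ordering (the method names come from the snippets; the wrapper function is an assumption):

def train_one_batch(model, batch_tensors, n_iters_d):
    # per-batch update order: reconstruction, then generator, then critic steps
    records = []
    records.extend(model.train_AE(*batch_tensors))  # autoencoder step
    records.extend(model.train_G(*batch_tensors))   # generator step
    for i in range(n_iters_d):                      # discriminator/critic steps
        loss_D = model.train_D(*batch_tensors)
        if i == 0:
            records.extend(loss_D)                  # only record the first critic loss
    return records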
Example #8
def evaluate(model, metrics, test_loader, vocab, ivocab, f_eval, repeat):

    recall_bleus, prec_bleus, bows_extrema, bows_avg, bows_greedy, intra_dist1s, intra_dist2s, avg_lens, inter_dist1s, inter_dist2s\
        = [], [], [], [], [], [], [], [], [], []
    local_t = 0
    while True:
        batch = test_loader.next_batch()
        if batch is None:
            break
        local_t += 1
        context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
        # remove the SOS token from the context and reduce the context length
        context, utt_lens = context[:, :, 1:], utt_lens - 1
        f_eval.write("Batch %d \n" % (local_t))  # print the context
        start = np.maximum(0, context_lens[0] - 5)
        for t_id in range(start, context.shape[1], 1):
            context_str = indexes2sent(context[0, t_id], vocab, vocab["</s>"],
                                       PAD_token)
            f_eval.write("Context %d-%d: %s\n" %
                         (t_id, floors[0, t_id], context_str))
        # print the true outputs
        ref_str, _ = indexes2sent(response[0], vocab, vocab["</s>"],
                                  vocab["<s>"])
        ref_tokens = ref_str.split(' ')
        f_eval.write("Target >> %s\n" % (ref_str.replace(" ' ", "'")))

        context, context_lens, utt_lens, floors = gVar(context), gVar(
            context_lens), gVar(utt_lens), gData(floors)
        sample_words, sample_lens = model.sample(context, context_lens,
                                                 utt_lens, floors, repeat,
                                                 vocab["<s>"], vocab["</s>"])
        # nparray: [repeat x seq_len]
        pred_sents, _ = indexes2sent(sample_words, vocab, vocab["</s>"],
                                     PAD_token)
        pred_tokens = [sent.split(' ') for sent in pred_sents]
        for r_id, pred_sent in enumerate(pred_sents):
            f_eval.write("Sample %d >> %s\n" %
                         (r_id, pred_sent.replace(" ' ", "'")))
        max_bleu, avg_bleu = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        prec_bleus.append(avg_bleu)

        bow_extrema, bow_avg, bow_greedy = metrics.sim_bow(
            sample_words, sample_lens, response[:, 1:], res_lens - 2)
        bows_extrema.append(bow_extrema)
        bows_avg.append(bow_avg)
        bows_greedy.append(bow_greedy)

        intra_dist1, intra_dist2, inter_dist1, inter_dist2 = metrics.div_distinct(
            sample_words, sample_lens)
        intra_dist1s.append(intra_dist1)
        intra_dist2s.append(intra_dist2)
        avg_lens.append(np.mean(sample_lens))
        inter_dist1s.append(inter_dist1)
        inter_dist2s.append(inter_dist2)

        f_eval.write("\n")

    recall_bleu = float(np.mean(recall_bleus))
    prec_bleu = float(np.mean(prec_bleus))
    f1 = 2 * (prec_bleu * recall_bleu) / (prec_bleu + recall_bleu + 10e-12)
    bow_extrema = float(np.mean(bows_extrema))
    bow_avg = float(np.mean(bows_avg))
    bow_greedy = float(np.mean(bows_greedy))
    intra_dist1 = float(np.mean(intra_dist1s))
    intra_dist2 = float(np.mean(intra_dist2s))
    avg_len = float(np.mean(avg_lens))
    inter_dist1 = float(np.mean(inter_dist1s))
    inter_dist2 = float(np.mean(inter_dist2s))
    report = "Avg recall BLEU %f, avg precision BLEU %f, F1 %f, bow_extrema %f, bow_avg %f, bow_greedy %f,\
    intra_dist1 %f, intra_dist2 %f, avg_len %f, inter_dist1 %f, inter_dist2 %f (only 1 ref, not final results)" \
    % (recall_bleu, prec_bleu, f1, bow_extrema, bow_avg, bow_greedy, intra_dist1, intra_dist2, avg_len, inter_dist1, inter_dist2)
    print(report)
    f_eval.write(report + "\n")
    print("Done testing")

    return recall_bleu, prec_bleu, bow_extrema, bow_avg, bow_greedy, intra_dist1, intra_dist2, avg_len, inter_dist1, inter_dist2
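The F1 reported here is the harmonic mean of the averaged precision and recall BLEU, with 10e-12 (i.e. 1e-11) as a guard against division by zero. As a standalone function (illustrative, not from the original):

def f1_score(precision, recall, eps=1e-11):
    # harmonic mean with an epsilon guard against zero division
    return 2 * precision * recall / (precision + recall + eps)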
Example #9
def main(args):
    # setup logging
    log = get_logger(args.log)
    log(args)
    timestamp = datetime.now().strftime('%Y%m%d%H%M')
    tb_writer = SummaryWriter("./output/{}/{}/{}/logs/".format(args.model, args.expname, args.dataset)\
                          +timestamp) if args.visual else None

    config = getattr(configs, 'config_' + args.model)()

    # instantiate the dmm
    model = getattr(models, args.model)(config)
    model = model.cuda()
    if args.reload_from >= 0:
        load_model(model, args.reload_from)

    train_set = PolyphonicDataset(args.data_path + 'train.pkl')
    valid_set = PolyphonicDataset(args.data_path + 'valid.pkl')
    test_set = PolyphonicDataset(args.data_path + 'test.pkl')

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(config['epochs']):

        train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=1)
        train_data_iter = iter(train_loader)
        n_iters = len(train_data_iter)

        epoch_nll = 0.0  # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        i_batch = 1
        n_slices = 0
        loss_records = {}
        while True:
            try:
                x, x_rev, x_lens = next(train_data_iter)
            except StopIteration:
                break  # end of epoch
            x, x_rev, x_lens = gVar(x), gVar(x_rev), gVar(x_lens)

            if config['anneal_epochs'] > 0 and epoch < config['anneal_epochs']:
                # compute the KL annealing factor
                min_af = config['min_anneal']
                kl_anneal = min_af + (1.0 - min_af) * (
                    float(i_batch + epoch * n_iters + 1) /
                    float(config['anneal_epochs'] * n_iters))
            else:
                kl_anneal = 1.0  # by default the KL annealing factor is unity

            loss_AE = model.train_AE(x, x_rev, x_lens, kl_anneal)

            epoch_nll += loss_AE['train_loss_AE']
            i_batch = i_batch + 1
            n_slices = n_slices + x_lens.sum().item()

        loss_records.update(loss_AE)
        loss_records.update({'epo_nll': epoch_nll / n_slices})
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        log("[Epoch %04d]\t\t(dt = %.3f sec)" % (epoch, epoch_time))
        log(loss_records)
        if args.visual:
            for k, v in loss_records.items():
                tb_writer.add_scalar(k, v, epoch)
        # do evaluation on test and validation data and report results
        if (epoch + 1) % args.test_freq == 0:
            save_model(model, epoch)
            test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=config['batch_size'],
                shuffle=False,
                num_workers=1)
            for x, x_rev, x_lens in test_loader:
                x, x_rev, x_lens = gVar(x), gVar(x_rev), gVar(x_lens)
                test_nll = model.valid(x, x_rev, x_lens) / x_lens.sum()
            log("[val/test epoch %08d]  %.8f" % (epoch, test_nll))
Example #10
    # shuffle (re-define) the data between epochs
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=config['batch_size'],
                                               shuffle=True,
                                               num_workers=1)
    train_data_iter = iter(train_loader)
    n_iters = len(train_data_iter)

    itr = 1
    while True:  # loop through all batches in training data
        model.train()
        try:
            descs, apiseqs, desc_lens, api_lens = next(train_data_iter)
        except StopIteration:  # end of epoch
            break
        descs, apiseqs, desc_lens, api_lens = gVar(descs), gVar(apiseqs), gVar(
            desc_lens), gVar(api_lens)
        loss_AE = model.train_AE(descs, desc_lens, apiseqs, api_lens)

        if itr % args.log_every == 0:
            elapsed = time.time() - itr_start_time
            log = '%s-%s|@gpu%d epo:[%d/%d] iter:[%d/%d] step_time:%ds elapsed:%s \n                      '\
            %(args.model, args.expname, args.gpu_id, epoch, config['epochs'],
                     itr, n_iters, elapsed, timeSince(epoch_start_time,itr/n_iters))
            for loss_name, loss_value in loss_AE.items():
                log = log + loss_name + ':%.4f ' % (loss_value)
                if args.visual:
                    tb_writer.add_scalar(loss_name, loss_value, itr_global)
            logger.info(log)

            itr_start_time = time.time()
Example #11
                            shuffle=True)

    n_iters = train_loader.num_batch / max(1, config['n_iters_d'])

    itr = 1
    while True:  # loop through all batches in training data
        model.train()
        loss_records = []
        batch = train_loader.next_batch()
        if batch is None:  # end of epoch
            break
        context, context_lens, utt_lens, floors, _, _, _, response, res_lens, dialog_act = batch
        # remove the SOS token from the context and reduce the context length
        context, utt_lens = context[:, :, 1:], utt_lens - 1
        context, context_lens, utt_lens, floors, response, res_lens, dialog_act\
                = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors), gVar(response), gVar(res_lens), gVar(dialog_act)

        loss_AE = model.train_AE(context, context_lens, utt_lens, floors,
                                 response, res_lens, dialog_act)
        loss_records.extend(loss_AE)

        loss_G = model.train_G(context, context_lens, utt_lens, floors,
                               response, res_lens)
        loss_records.extend(loss_G)

        for i in range(config['n_iters_d']):  # train discriminator/critic
            loss_D = model.train_D(context, context_lens, utt_lens, floors,
                                   response, res_lens)
            if i == 0:
                loss_records.extend(loss_D)
            if i == config['n_iters_d'] - 1:
                break
Example #12
    def beam_decode(self,
                    init_h,
                    enc_hids,
                    context,
                    beam_width,
                    max_unroll,
                    topk=1):
        '''
        https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py
        :param init_h: input tensor of shape [B, H] for start of the decoding
        :param enc_hids: if you are using attention mechanism you can pass encoder outputs, [B, T, H] where T is the maximum length of input sentence
        :param topk: how many sentences you want to generate
        :return: decoded_batch
        '''
        batch_size = init_h.size(0)
        decoded_words = np.zeros((batch_size, topk, max_unroll), dtype=int)
        sample_lens = np.zeros((batch_size, topk), dtype=int)
        scores = np.zeros((batch_size, topk))

        for idx in range(batch_size):  # decoding goes sentence by sentence
            if isinstance(init_h, tuple):  # LSTM case
                h = (init_h[0][idx, :].view(1, 1, -1),
                     init_h[1][idx, :].view(1, 1, -1))
            else:
                h = init_h[idx, :].view(1, 1, -1)
            if enc_hids is not None:
                enc_outs, enc_outs_mask = enc_hids
                enc_outs = enc_outs[idx, :, :].unsqueeze(0)
                enc_outs_mask = enc_outs_mask[idx, :].unsqueeze(0)
                enc_outs = (enc_outs, enc_outs_mask)

            # Start with the start of the sentence token
            x = gVar(torch.LongTensor([[SOS_ID]]))

            # Number of sentences to generate
            endnodes = []
            number_required = min((topk + 1), topk - len(endnodes))

            # starting node -  hidden vector, previous node, word id, logp, length
            node = BeamSearchNode(h, None, x, 0, 1)
            nodes = PriorityQueue()

            # start the queue
            nodes.put((-node.eval(), node))
            qsize = 1

            # start beam search
            while True:
                if qsize > 2000: break  # give up when decoding takes too long

                score, n = nodes.get()  # fetch the best node
                x = n.wordid
                h = n.h
                qsize -= 1

                if n.wordid.item() == EOS_ID and n.prevNode is not None:
                    endnodes.append((score, n))
                    # if we reached maximum # of sentences required
                    if len(endnodes) >= number_required:
                        break
                    else:
                        continue

                # decode for one step using decoder
                out, h = self.forward(h.squeeze(0), enc_outs, None,
                                      x)  # out [1 x 1 x vocab_size]
                out = out.squeeze(1)  # [1 x vocab_size]
                out = F.log_softmax(out, 1)

                # PUT HERE REAL BEAM SEARCH OF TOP
                log_prob, indexes = torch.topk(out,
                                               beam_width)  # [1 x beam_width]

                for new_k in range(beam_width):
                    decoded_t = indexes[0][new_k].view(1, -1)
                    log_p = log_prob[0][new_k].item()
                    node = BeamSearchNode(h, n, decoded_t, n.logp + log_p,
                                          n.len + 1)
                    score = -node.eval()
                    nodes.put((score, node))  # put them into queue
                    qsize += 1  # increase qsize

            # choose nbest paths, back trace them
            if len(endnodes) == 0:
                endnodes = [nodes.get() for _ in range(topk)]

            uid = 0
            for score, n in sorted(endnodes, key=operator.itemgetter(0)):
                utterance, length = [], n.len
                utterance.append(n.wordid)
                # back trace
                while n.prevNode is not None:
                    n = n.prevNode
                    utterance.append(n.wordid)
                utterance = utterance[::-1]  # reverse to reading order
                utterance, length = utterance[1:], length - 1  # remove <sos>
                decoded_words[idx, uid, :min(length, max_unroll)] = \
                    utterance[:min(length, max_unroll)]
                sample_lens[idx, uid] = min(length, max_unroll)
                scores[idx, uid] = score
                uid = uid + 1

        return decoded_words, sample_lens, scores
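One Python 3 caveat with the PriorityQueue usage above: when two scores tie, the tuple comparison falls through to the BeamSearchNode objects and raises TypeError unless the node class defines ordering. A common workaround (a sketch, not from the original code) is to insert a monotonically increasing counter as a tie-breaker so the payload is never compared:

import itertools
from queue import PriorityQueue

counter = itertools.count()
nodes = PriorityQueue()
# push (score, tie_breaker, payload); ties on score are resolved by the counter
nodes.put((-0.5, next(counter), "node_a"))
nodes.put((-0.5, next(counter), "node_b"))
score, _, node = nodes.get()  # no comparison of the payload objects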