def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    """Greedily translate ``sentence`` with a trained encoder/decoder pair.

    The sentence is turned into ids with ``sentence2sequence(input_lang, ...)``,
    encoded as a batch of one, and decoded one token at a time; each step's
    arg-max token is fed back as the next decoder input.

    Args:
        encoder: called as ``encoder(ids, lengths)``; returns ``(outputs, hidden)``.
        decoder: called as ``decoder(token, hidden, encoder_outputs, lengths)``;
            returns ``(output, hidden)``.
        sentence: raw source sentence accepted by ``sentence2sequence``.
        max_length: maximum number of decoding steps.

    Returns:
        List of output words (looked up in ``output_lang.index2word``), ending
        with ``'<EOS>'`` if the end token was produced within ``max_length`` steps.
    """
    with torch.no_grad():
        source_ids = sentence2sequence(input_lang, sentence)
        source_len = len(source_ids)

        batch = LongTensor(source_ids).view(1, -1)  # a batch of one sentence
        enc_outputs, hidden = encoder(batch, LongTensor([source_len]))

        # One start token per batch row, shape (batch, 1).
        token = LongTensor([SOS_token]).repeat(batch.shape[0], 1)

        words = []
        for _ in range(max_length):
            output, hidden = decoder(token, hidden, enc_outputs, LongTensor([source_len]))
            _, best = output.topk(1)
            best_id = best.item()
            if best_id == EOS_token:
                words.append('<EOS>')
                break
            words.append(output_lang.index2word[best_id])
            token = best.detach().view(1, 1)

        return words
Exemplo n.º 2
0
    def decode_result(self,
                      decoder_inputs,
                      init_states,
                      memories,
                      target2index,
                      index2target,
                      max_length=50):
        """Greedily decode one sequence, re-feeding the whole decoded prefix.

        Decoding starts from ``'<s>'``. At each step ``self.decode`` receives
        the full prefix decoded so far, and the arg-max of the log-softmaxed
        scores is appended. Decoding stops on ``'</s>'``, on a repeated
        ``'<s>'`` after the first step, or after ``max_length`` steps.

        Args:
            decoder_inputs: not used by this method; kept for interface compatibility.
            init_states: passed through to ``self.decode``.
            memories: passed through to ``self.decode``.
            target2index: token -> id mapping; must contain ``'<s>'``.
            index2target: id -> token mapping.
            max_length: maximum number of decoding steps.

        Returns:
            Tensor with the arg-max id of every step's log-probabilities.
        """
        start_id = target2index['<s>']
        # Shape (1, 1): a batch of one sequence holding only '<s>'.
        embedded = Variable(LongTensor([[start_id]]))
        prefix = [start_id]
        decodes = []

        for t in range(max_length):
            _, hidden = self.decode(embedded, init_states, memories)
            # hidden is 2-D (batch, vocab): it is reduced with max(1) and the
            # first row taken below. Pass dim explicitly; implicit dim is deprecated.
            softmaxed = F.log_softmax(hidden, dim=1)
            decodes.append(softmaxed)
            next_id = softmaxed.max(1)[1].data.tolist()[0]
            prefix.append(next_id)
            # The decoder is re-fed the whole prefix, shape (1, len(prefix)).
            embedded = Variable(LongTensor([prefix]))
            token = index2target[next_id]
            if token == '</s>' or (t != 0 and token == '<s>'):
                break

        return torch.cat(decodes).max(1)[1]
Exemplo n.º 3
0
    def decode(self, h, mask):  # Viterbi decoding
        """Viterbi-decode the best tag sequence for every instance in the batch.

        Args:
            h: emission scores indexed as ``h[:, t]`` and added to a
                (batch, num_tags) score; presumably (batch, seq_len, num_tags)
                — TODO confirm.
            mask: per-position mask used arithmetically (``score * (1 - mask_t)``);
                presumably a float 0/1 tensor of shape (batch, seq_len) — TODO confirm.

        Returns:
            List with one tag-id list per instance, each ``int(mask[b].sum())`` long.
        """
        # Use the actual batch size, not the global BATCH_SIZE: the final
        # batch of an epoch is usually smaller.
        batch_size = h.size(0)

        # initialize backpointers and viterbi variables in log space
        bptr = LongTensor()
        score = Tensor(batch_size, self.num_tags).fill_(-10000.)
        score[:, SOS_IDX] = 0.  # every path starts from SOS

        for t in range(h.size(1)):  # recursion through the sequence
            mask_t = mask[:, t].unsqueeze(1)
            score_t = score.unsqueeze(1) + self.trans  # [B, 1, C] -> [B, C, C]
            score_t, bptr_t = score_t.max(2)  # best previous scores and tags
            score_t += h[:, t]  # plus emission scores
            bptr = torch.cat((bptr, bptr_t.unsqueeze(1)), 1)
            # Instances already past their end (mask 0) keep their old scores.
            score = score_t * mask_t + score * (1 - mask_t)
        score += self.trans[EOS_IDX]
        best_score, best_tag = torch.max(score, 1)

        # back-tracking
        bptr = bptr.tolist()
        best_path = [[i] for i in best_tag.tolist()]
        for b in range(batch_size):
            x = best_path[b][0]  # best final tag, as a Python int
            y = int(mask[b].sum().item())  # real length of instance b
            for bptr_t in reversed(bptr[b][:y]):
                x = bptr_t[x]
                best_path[b].append(x)
            # The last tag appended precedes position 0 (presumably SOS); drop it.
            best_path[b].pop()
            best_path[b].reverse()

        return best_path
def batch_generator(*arrays, batch_size=32, should_shuffle=False):
    """Yield padded (input, target) batches sorted by descending input length.

    Args:
        *arrays: exactly two sequences ``(inputs, targets)``. Each element is a
            tuple of token ids; tuples are required because padding uses tuple
            concatenation.
        batch_size: maximum number of instances per batch.
        should_shuffle: shuffle inputs and targets jointly with
            ``sklearn.utils.shuffle`` before batching.

    Yields:
        ``(inputs, input_lengths, targets, target_lengths)`` LongTensors,
        right-padded with ``PAD_IDX``. All four are reordered by descending
        input length (presumably for ``pack_padded_sequence``).
    """
    inputs, targets = arrays
    if should_shuffle:
        from sklearn.utils import shuffle
        inputs, targets = shuffle(inputs, targets)

    num_instances = len(inputs)
    batch_count = int(numpy.ceil(num_instances / batch_size))
    input_lengths = numpy.array([len(seq) for seq in inputs], dtype=numpy.int32)
    target_lengths = numpy.array([len(seq) for seq in targets], dtype=numpy.int32)

    progress = tqdm.tqdm(total=num_instances)
    try:
        for idx in range(batch_count):
            start = idx * batch_size
            end = min(start + batch_size, num_instances)

            batch_input_lengths = input_lengths[start:end]
            batch_target_lengths = target_lengths[start:end]
            input_maxlength = batch_input_lengths.max()
            target_maxlength = batch_target_lengths.max()
            # Descending order; the copy avoids the negative stride torch rejects.
            order = numpy.argsort(batch_input_lengths)[::-1].copy()

            batch_input = LongTensor([seq + (PAD_IDX,) * (input_maxlength - len(seq))
                                      for seq in inputs[start:end]])
            batch_target = LongTensor([seq + (PAD_IDX,) * (target_maxlength - len(seq))
                                       for seq in targets[start:end]])

            progress.update(len(batch_input_lengths))
            yield (batch_input[order], LongTensor(batch_input_lengths)[order],
                   batch_target[order], LongTensor(batch_target_lengths)[order])
    finally:
        # Also runs when the consumer abandons the generator early.
        progress.close()
Exemplo n.º 5
0
def batch_generator(*arrays, batch_size=32):
    """Yield shuffled word/char/tag batches sorted by descending sentence length.

    Args:
        *arrays: ``(word_id_lists, tag_id_lists, char_id_lists,
            seq_length_in_words, word_lengths)``.
            - Word and tag entries are tuples of ids per sentence.
            - Char entries are tuples (per sentence) of tuples (per word).
            - ``seq_length_in_words`` must support slicing, ``.max()`` and
              ``numpy.argsort`` (e.g. a numpy array).
            - ``word_lengths`` is shuffled along with the others but not otherwise used.
        batch_size: maximum number of sentences per batch.

    Yields:
        ``(words, chars, tags)``, all in the same descending-length order:
        - ``words`` and ``tags`` are ``(batch, max_sentence_len)`` LongTensors
          padded with ``PAD_IDX``.
        - ``chars`` is a ``(batch * max_sentence_len, max_word_len)``
          LongTensor padded with 0.
    """
    word_id_lists, tag_id_lists, char_id_lists, seq_length_in_words, word_lengths = arrays
    word_id_lists, tag_id_lists, char_id_lists, seq_length_in_words, word_lengths = \
        shuffle(word_id_lists, tag_id_lists, char_id_lists, seq_length_in_words, word_lengths)

    num_instances = len(word_id_lists)
    batch_count = int(numpy.ceil(num_instances / batch_size))
    from tqdm import tqdm
    prog = tqdm(total=num_instances)
    try:
        for idx in range(batch_count):
            start = idx * batch_size
            end = min(start + batch_size, num_instances)
            batch_lengths = seq_length_in_words[start:end]
            batch_maxlen = batch_lengths.max()
            # Descending order; the copy avoids the negative stride torch rejects.
            argsort = numpy.argsort(batch_lengths)[::-1].copy()

            # Reorder in Python. numpy.array() over ragged nested tuples raises
            # ValueError on NumPy >= 1.24, and mis-handles non-ragged batches.
            char_batch = [char_id_lists[start + i] for i in argsort]

            # make each sentence in batch contain same number of words
            char_batch = [
                sentence + ((0, ), ) * (batch_maxlen - len(sentence))
                for sentence in char_batch
            ]

            # make each word in batch contain same number of chars
            max_word_length = max(
                len(word) for sentence in char_batch for word in sentence)
            chars = LongTensor([
                word + (0, ) * (max_word_length - len(word))
                for sentence in char_batch for word in sentence
            ])

            words = LongTensor([
                word_ids + (PAD_IDX, ) * (batch_maxlen - len(word_ids))
                for word_ids in word_id_lists[start:end]
            ])
            tags = LongTensor([
                tag_ids + (PAD_IDX, ) * (batch_maxlen - len(tag_ids))
                for tag_ids in tag_id_lists[start:end]
            ])

            prog.update(len(batch_lengths))
            # chars was already built in argsort order above.
            yield words[argsort], chars, tags[argsort]
    finally:
        # Also runs when the consumer abandons the generator early.
        prog.close()
Exemplo n.º 6
0
 def detail_forward(self, incoming):
     """Embed the whole post and the current turn's wiki with ``self.embLayer``.

     The turn index is read from ``incoming.state.num``. The method fills:
     - ``incoming.post.embedding`` from all of ``incoming.data.post``;
     - ``incoming.wiki.embedding`` from column ``num`` of ``incoming.data.wiki``;
     - ``incoming.resp.embLayer``, set to the embedding layer itself.
     """
     embed = self.embLayer
     turn = incoming.state.num

     incoming.post = Storage()
     incoming.post.embedding = embed(LongTensor(incoming.data.post))
     incoming.resp = Storage()
     incoming.wiki = Storage()
     wiki_ids = incoming.data.wiki[:, turn]
     incoming.wiki.embedding = embed(wiki_ids)
     incoming.resp.embLayer = embed
Exemplo n.º 7
0
    def decode(self, h, mask):  # Viterbi decoding
        """Viterbi-decode the highest-scoring tag sequence for each instance.

        Args:
            h: emission scores indexed as ``h[:, t]`` and added to a
                (batch, num_tags) score; presumably (batch, seq_len, num_tags)
                — TODO confirm.
            mask: per-position mask compared with ``1.``/``0.`` and multiplied
                into the scores; presumably a float 0/1 tensor of shape
                (batch, seq_len) — TODO confirm.

        Returns:
            A list with one tag-id list per instance; instance ``idx`` gets
            ``int(scalar(mask[idx].sum()))`` tags.
        """
        # initialize backpointers and viterbi variables in log space
        backpointers = LongTensor()
        batch_size = h.shape[0]
        # All paths are forced to start at START_TAG_IDX: every other tag gets
        # a -10000 initial score.
        delta = Tensor(batch_size, self.num_tags).fill_(-10000.)
        delta[:, START_TAG_IDX] = 0.

        # TODO: is adding stop tag within loop needed at all???
        # pro argument: yes, backpointers needed at every step - to be checked
        for t in range(h.size(1)):  # iterate through the sequence
            # backpointers and viterbi variables at this timestep
            mask_t = mask[:, t].unsqueeze(1)
            # TODO: maybe unsqueeze transition explicitly for 0 dim for clarity
            # next_tag_var[b, i, j] = delta[b, j] + transition[i, j], i.e. the
            # transition matrix is indexed [current tag, previous tag] here.
            next_tag_var = delta.unsqueeze(1) + self.transition  # B x 1 x S + S x S
            # Best previous tag (dim 2) for every current tag.
            delta_t, backpointers_t = next_tag_var.max(2)
            backpointers = torch.cat((backpointers, backpointers_t.unsqueeze(1)), 1)
            delta_next = delta_t + h[:, t]  # plus emission scores
            # Instances whose mask is 0 at t keep their previous scores.
            delta = mask_t * delta_next + (1 - mask_t) * delta  # TODO: check correctness
            # for those that end here add score for transitioning to stop tag
            if t + 1 < h.size(1):
                # mask_next = mask[:, t + 1].unsqueeze(1)
                # ending = mask_next.eq(0.).float().expand(batch_size, self.num_tags)
                # delta += ending * self.transition[STOP_TAG_IDX].unsqueeze(0)
                # or
                # 1 for instances whose last real position is t (mask 1 at t, 0 at t + 1).
                ending_here = (mask[:, t].eq(1.) * mask[:, t + 1].eq(0.)).view(1, -1).float()
                # (B x 1) * (S,) -> B x S: adds the to-STOP row only for instances ending at t.
                delta += ending_here.transpose(0, 1).mul(self.transition[STOP_TAG_IDX])  # add outer product of two vecs
                # TODO: check equality of these two again

        # TODO: should we add transition values for getting in stop state only for those that end here?
        # TODO: or to all?
        # Instances still running at the final position get their to-STOP
        # transition here; those that ended earlier got it inside the loop.
        delta += mask[:, -1].view(1, -1).float().transpose(0, 1).mul(self.transition[STOP_TAG_IDX])
        best_score, best_tag = torch.max(delta, 1)

        # back-tracking
        backpointers = backpointers.tolist()
        best_path = [[i] for i in best_tag.tolist()]
        for idx in range(batch_size):
            prev_best_tag = best_tag[idx]  # best tag id for single instance
            # `scalar` is a helper defined elsewhere; presumably converts a
            # 0-d tensor to a Python number — TODO confirm.
            length = int(scalar(mask[idx].sum()))  # length of instance
            for backpointers_t in reversed(backpointers[idx][:length]):
                prev_best_tag = backpointers_t[prev_best_tag]
                best_path[idx].append(prev_best_tag)
            best_path[idx].pop()  # remove start tag
            best_path[idx].reverse()

        return best_path
        def nextStep(x, flag=None, regroup=None):
            """Run ``step`` on a flattened ``batch_size * top_k`` beam batch.

            NOTE(review): as laid out in this file, this nested def follows
            ``return best_path`` of the enclosing ``decode`` method, so it is
            unreachable. ``decode`` also defines no ``step`` or ``top_k``, and
            Python rejects ``nonlocal`` names without an enclosing binding at
            compile time. It looks pasted from a beam-search routine; confirm
            where it belongs.

            Args:
                x: reshaped to ``(batch_size * top_k, -1)`` before calling ``step``.
                flag: not used in this body.
                regroup: per-instance beam indices, presumably shaped
                    (batch_size, top_k) — TODO confirm. It must be supplied
                    despite the ``None`` default (``None + tensor`` raises
                    ``TypeError``).

            Returns:
                ``step``'s output reshaped to ``(batch_size, top_k, -1)``.
            """
            nonlocal step, batch_size, top_k
            # regroup: batch * top_k
            # Offset row b's indices by b * top_k so they address rows of the
            # flattened (batch_size * top_k) tensor.
            regroup = regroup + LongTensor(list(
                range(batch_size))).unsqueeze(1) * top_k
            regroup = regroup.reshape(-1)
            x = x.reshape(batch_size * top_k, -1)
            x = step(x, regroup=regroup)
            x = x.reshape(batch_size, top_k, -1)

            return x
def positional_embeddings(seqlen, first_emb, reverse_emb=None, length=None):
    """Build forward (and optionally reverse) positional embeddings.

    Args:
        seqlen: number of positions to produce.
        first_emb: ``(max_length, emb)`` table; row ``j`` embeds position ``j``.
        reverse_emb: ``(max_length, emb_rev)`` table indexed by distance from
            the end of each sequence. Must be given exactly when ``length`` is.
        length: optional sequence of per-instance lengths (each ``<= seqlen``).

    Returns:
        If ``length`` is None: ``(1, seqlen, emb)``, the first ``seqlen`` rows
        of ``first_emb``.
        Otherwise: ``(batch, seqlen, emb + emb_rev)``. The second half at
        position ``j < l`` is ``reverse_emb[l - 1 - j]``; padded positions
        get ``reverse_emb[0]``.

    Raises:
        IndexError: if ``seqlen`` exceeds ``first_emb.shape[0]``.
    """
    # Row j of first_emb for every position j (what the former expand+gather
    # computed); advanced indexing still fails loudly if seqlen is too large.
    positions = torch.arange(seqlen, device=first_emb.device)
    encodings = first_emb[positions]  # [seqlen, emb]

    if length is None:
        assert reverse_emb is None
        return encodings.unsqueeze(0)

    batch_size = len(length)
    # Integer ids from the start (np.zeros would default to float64).
    reversed_id = np.zeros((batch_size, seqlen), dtype=np.int64)
    for i, l in enumerate(length):
        reversed_id[i, :l] = np.arange(l - 1, -1, -1)
    reversed_id = LongTensor(reversed_id)  # [batch, seqlen]
    encodings_reversed = reverse_emb[reversed_id]  # [batch, seqlen, emb_rev]
    return torch.cat([
        encodings.unsqueeze(0).expand(batch_size, -1, -1),
        encodings_reversed,
    ], dim=2)
Exemplo n.º 10
0
    def test(self, key):
        """Evaluate the model on split ``key`` and write a result file.

        Two passes run over ``self.get_batches(dm, key)``:
        - a teacher-forcing pass feeds ``log_softmax(incoming.gen.w_pro)`` to
          ``dm.get_teacher_forcing_metric()``;
        - a free-run pass feeds the generated ids ``incoming.gen.w_o`` to
          ``dm.get_inference_metric()``.

        Float/str results are logged and written to
        ``<args.out_dir>/<args.name>_<key>.txt``, followed by post/resp/gen lines.

        Returns:
            The merged results restricted to ``str``, ``int`` and ``float`` values.
        """
        args = self.param.args
        dm = self.param.volatile.dm

        # Pass 1: teacher forcing (sampling_proba = 1.).
        metric1 = dm.get_teacher_forcing_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval teacher-forcing")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            incoming.args.sampling_proba = 1.
            with torch.no_grad():
                self.net.forward(incoming)
                gen_log_prob = nn.functional.log_softmax(
                    incoming.gen.w_pro, -1)
            data = incoming.data
            data.resp_allvocabs = LongTensor(data.resp_allvocabs)
            data.gen_log_prob = gen_log_prob.transpose(1, 0)
            metric1.forward(data)
        res = metric1.close()

        # Pass 2: free-running generation.
        metric2 = dm.get_inference_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval free-run")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            data = incoming.data
            data.gen = incoming.gen.w_o.detach().cpu().numpy().transpose(1, 0)
            metric2.forward(data)
        res.update(metric2.close())

        os.makedirs(args.out_dir, exist_ok=True)
        filename = os.path.join(args.out_dir, "%s_%s.txt" % (args.name, key))

        with codecs.open(filename, 'w', encoding='utf8') as f:
            logging.info("%s Test Result:", key)
            # A distinct loop name: reusing `key` would clobber the argument.
            for metric_name, value in res.items():
                if isinstance(value, (float, str)):
                    logging.info("\t{}:\t{}".format(metric_name, value))
                    f.write("{}:\t{}\n".format(metric_name, value))
            for i, post in enumerate(res['post']):
                f.write("post:\t%s\n" % " ".join(post))
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n" % " ".join(res['gen'][i]))
        logging.info("result output to %s.", filename)
        return {
            name: val
            for name, val in res.items() if isinstance(val, (str, int, float))
        }
Exemplo n.º 11
0
def sample_image(args, generator, n_row, batches_done):
    """Generate an ``n_row`` x ``n_row`` grid of images and save it to disk.

    Latent vectors are drawn from N(0, 1) with ``args.latent_dim`` dimensions.
    The conditioning labels are ``0 .. n_row - 1`` repeated ``n_row`` times,
    so each grid row (``nrow=n_row`` images wide) spans labels 0..n_row-1.
    The grid is written to ``images/<batches_done>.png``.
    """
    n_samples = n_row ** 2
    noise = np.random.normal(0, 1, (n_samples, args.latent_dim))
    z = Variable(FloatTensor(noise))
    # [0, 1, ..., n_row - 1] tiled once per row.
    labels = Variable(LongTensor(np.tile(np.arange(n_row), n_row)))
    images = generator(z, labels)
    save_image(images.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
Exemplo n.º 12
0
def train(
        encoder_input,
        input_lengths,
        target_tensor,
        target_lengths,
        encoder,
        decoder,
        encoder_optimizer,
        decoder_optimizer,
        criterion):
    """Run one optimisation step of the seq2seq model on a batch.

    Args:
        encoder_input: padded source ids; ``shape[0]`` is the batch size.
        input_lengths: source lengths, passed to encoder and decoder.
        target_tensor: padded target ids indexed as ``target_tensor[:, step]``.
        target_lengths: 1-D tensor of target lengths per instance.
        encoder, decoder: the models; the decoder is called as
            ``decoder(input, hidden, encoder_outputs, input_lengths)`` and
            returns ``(output, hidden)``.
        encoder_optimizer, decoder_optimizer: optimisers stepped once.
        criterion: loss applied to the still-running rows of each step.

    Returns:
        The summed per-step loss divided by the total number of target tokens.
    """
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    loss = 0

    encoder_outputs, encoder_hidden = encoder(encoder_input, input_lengths)
    decoder_hidden = encoder_hidden
    # One start token per instance, shape (batch, 1).
    decoder_input = LongTensor([SOS_token]).repeat(encoder_input.shape[0], 1)

    # use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    use_teacher_forcing = False

    max_target_length = target_lengths.max().item()
    remaining = target_lengths.clone()
    for idx in range(max_target_length):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs, input_lengths)
        # Rows whose target still has a token at step idx. `remaining` counts
        # tokens, not token ids, so the threshold is 0 (not PAD_IDX).
        mask = remaining > 0
        remaining -= 1
        masked_output = decoder_output * mask.unsqueeze(1).float()
        step_target = target_tensor[:, idx]
        loss += criterion(masked_output[mask], step_target[mask])
        if use_teacher_forcing:
            # Feed the gold token, keeping the (batch, 1) input shape.
            decoder_input = step_target.unsqueeze(1)
        else:
            # Feed back the model's own prediction, detached from the graph.
            # topk(1) already yields the (batch, 1) shape of the SOS input.
            _, topi = masked_output.topk(1)
            decoder_input = topi.detach()

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_lengths.sum().item()
Exemplo n.º 13
0
 def forward(self, incoming):
     """Embed the current turn's post, response and wiki, applying dropout.

     The turn index is read from ``incoming.state.num``. Column ``num`` of
     ``incoming.data.post``/``resp``/``wiki`` goes through ``self.embLayer``
     then ``self.drop``, filling ``incoming.post.embedding``,
     ``incoming.resp.embedding`` and ``incoming.wiki.embedding``. The
     embedding layer itself is exposed as ``incoming.resp.embLayer``.
     """
     turn = incoming.state.num
     embed = self.embLayer
     drop = self.drop
     data = incoming.data

     incoming.post = Storage()
     incoming.post.embedding = drop(embed(LongTensor(data.post[:, turn])))
     incoming.resp = Storage()
     incoming.resp.embedding = drop(embed(data.resp[:, turn]))
     incoming.wiki = Storage()
     incoming.wiki.embedding = drop(embed(data.wiki[:, turn]))
     incoming.resp.embLayer = embed
Exemplo n.º 14
0
    def __predict_sentence(self, src_batch):
        """
        Translate one tokenised source sentence into a target string.

        :param src_batch: list of source tokens; '<s>' and '</s>' are added
            around it before lookup in ``self.data_model.source2index``.
        :return: the predicted target tokens concatenated without separator,
            with every '<s>' and '</s>' substring removed.
        """
        model = self.data_model
        inputs = prepare_sequence(['<s>'] + src_batch + ['</s>'], model.source2index).view(1, -1)
        seq_len = inputs.size(1)
        start_decode = Variable(LongTensor([[model.target2index['<s>']] * seq_len]))
        preds = self.qrnn(inputs, [seq_len], start_decode)
        outputs = torch.max(preds, dim=1)[1].view(len(inputs), -1)
        tokens = [model.index2target[idx] for row in outputs.data.tolist() for idx in row]
        return ''.join(tokens).replace('<s>', '').replace('</s>', '')
Exemplo n.º 15
0
 def score(self, h, y, mask):  # calculate the score of a given sequence
     """Score given tag sequences under the CRF (emission + transition scores).

     Args:
         h: emission scores indexed as ``h[b, t, tag]``; presumably
             (batch, seq_len, num_tags) — TODO confirm.
         y: gold tag ids with at least ``h.size(1)`` columns. A
             START_TAG_IDX column is prepended before scoring.
         mask: per-position weights indexed as ``mask[:, t]``; presumably a
             0/1 float tensor of shape (batch, seq_len) — TODO confirm.

     Returns:
         (batch,) tensor equal to:
         - the sum over t of ``(h[b, t, y_t] + transition[y_t, y_{t-1}]) * mask[b, t]``,
         - plus ``transition[STOP_TAG_IDX, last real tag]``.
     """
     batch_size = h.shape[0]
     score = Tensor(batch_size).fill_(0.)
     # Prepend the start tag: y[:, t + 1] is the tag at position t, y[:, t] the one before.
     start = LongTensor([START_TAG_IDX]).view(1, -1).expand(batch_size, 1)
     y = torch.cat([start, y], 1)
     rows = torch.arange(batch_size, device=y.device)
     # Vectorised over the batch (one indexing op per step instead of a Python
     # loop over instances); looping over t keeps the summation order.
     for t in range(h.size(1)):  # iterate through the sequence
         prev_tags, cur_tags = y[:, t], y[:, t + 1]
         emission = h[rows, t, cur_tags]
         transition_t = self.transition[cur_tags, prev_tags]
         score += (emission + transition_t) * mask[:, t]
     # Transition from each instance's last real tag to the stop tag. Because
     # of the prepended start column, y[:, length] is that last real tag.
     lengths = mask.sum(1).long()
     last_tags = y.gather(1, lengths.unsqueeze(1)).squeeze(1)
     score += self.transition[STOP_TAG_IDX, last_tags]
     return score
Exemplo n.º 16
0
    def detail_forward(self, incoming):
        """Encode the batch rows whose post length is non-zero with ``self.postGRU``.

        On ``incoming.state`` it sets:
        - ``valid_sen``: indices of the non-zero entries of the post lengths.
        - ``reverse_valid_sen``: running count of valid rows minus one,
          clamped at 0. For a valid row this is its position in ``valid_sen``.
        - ``valid_num``: number of valid rows.

        It also fills ``incoming.hidden.h``/``h_n`` from the GRU and
        ``incoming.hidden.length`` (numpy lengths of the valid rows).
        """
        hidden = Storage()
        incoming.hidden = hidden
        state = incoming.state
        # Per the original notes: embedding is batch * sen_num * length * vec_dim
        # and post_length batch * sen_num (later treated as 1-D) — TODO confirm.
        raw_post = incoming.post.embedding
        raw_post_length = LongTensor(incoming.data.post_length)

        state.valid_sen = torch.nonzero(raw_post_length).sum(dim=1)
        running = (raw_post_length > 0).cumsum(0) - 1
        state.reverse_valid_sen = running.clamp(min=0)
        valid_sen = state.valid_sen
        state.valid_num = valid_sen.shape[0]

        # [length, valid_num, vec_dim]
        post = raw_post.index_select(0, valid_sen).transpose(0, 1)
        # [valid_num]
        post_length = raw_post_length.index_select(0, valid_sen).cpu().numpy()

        hidden.h, hidden.h_n = self.postGRU.forward(post, post_length, need_h=True)
        hidden.length = post_length
Exemplo n.º 17
0
    def freerun(self, inp, gen, mode='max'):
        """Generate sentences token by token without teacher forcing.

        Generation starts from the embedding of ``dm.go_id`` and runs
        ``self.GRULayer`` for up to ``self.args.max_sen_length`` steps. Each
        next token is chosen among ids ``>= self.start_generate_id``: arg-max
        for ``mode='max'``, ``gumbel_max`` for ``mode='gumbel'``. Generation
        stops early once every instance has emitted ``dm.eos_id``.

        Fields set on ``gen``:
        - ``w_pro``: list of per-step logits.
        - ``w_o`` and ``emb``: stacked over steps; entries after each
          instance's first EOS are zeroed.
        - ``length``: numpy array of kept steps per instance, EOS included.

        Raises:
            ValueError: if ``mode`` is neither ``'max'`` nor ``'gumbel'``.
        """
        if mode not in ("max", "gumbel"):
            raise ValueError("unknown freerun mode: %r" % (mode,))
        batch_size = inp.batch_size
        dm = self.param.volatile.dm

        first_emb = inp.embLayer(LongTensor([dm.go_id])).repeat(batch_size, 1)
        gen.w_pro = []
        gen.w_o = []
        gen.emb = []
        flag = zeros(batch_size).byte()  # 1 once an instance has emitted EOS
        EOSmet = []  # flag before each step: marks steps after the first EOS

        next_emb = first_emb
        gru_h = self.GRULayer.getInitialParameter(batch_size)[0]
        for _ in range(self.args.max_sen_length):
            now = next_emb
            gru_h = self.GRULayer.cell_forward(now, gru_h)
            w = self.wLinearLayer(gru_h)
            gen.w_pro.append(w)
            if mode == "max":
                w_o = torch.argmax(w[:, self.start_generate_id:],
                                   dim=1) + self.start_generate_id
                next_emb = inp.embLayer(w_o)
            else:  # mode == "gumbel"
                w_onehot, w_o = gumbel_max(w[:, self.start_generate_id:], 1, 1)
                w_o = w_o + self.start_generate_id
                # NOTE(review): slices the embedding table from 2, not from
                # self.start_generate_id; presumably they coincide — confirm.
                next_emb = torch.sum(
                    torch.unsqueeze(w_onehot, -1) * inp.embLayer.weight[2:], 1)
            gen.w_o.append(w_o)
            gen.emb.append(next_emb)

            EOSmet.append(flag)
            # Keep flag uint8 explicitly so `1 - torch.stack(EOSmet)` below is valid.
            flag = flag | (w_o == dm.eos_id).byte()
            if torch.sum(flag).item() == batch_size:
                break

        EOSmet = 1 - torch.stack(EOSmet)
        gen.w_o = torch.stack(gen.w_o) * EOSmet.long()
        gen.emb = torch.stack(gen.emb) * EOSmet.float().unsqueeze(-1)
        gen.length = torch.sum(EOSmet, 0).detach().cpu().numpy()
Exemplo n.º 18
0
def _adam_for(module, args):
    """Return an Adam optimiser over ``module``'s parameters (``args.lr``, ``b1``, ``b2``)."""
    return torch.optim.Adam(module.parameters(),
                            lr=args.lr,
                            betas=(args.b1, args.b2))


def train(generator, discriminator, dataloader, args, cuda, adversarial_loss,
          auxiliary_loss):
    """Train an auxiliary-classifier GAN for ``args.n_epochs`` epochs.

    Each batch runs:
    - one generator step: fool the discriminator's validity head (weight 0.5)
      and make its class head recover the sampled labels;
    - one discriminator step: the mean of the real and fake losses, each
      itself the mean of the validity and class losses.

    Progress is printed every batch, and ``sample_image`` saves a 10x10 grid
    every ``args.sample_interval`` batches.

    Args:
        generator: called as ``generator(z, labels)``.
        discriminator: returns ``(validity, class_scores)`` for a batch.
        dataloader: yields ``(images, labels)``.
        args: provides ``lr, b1, b2, n_epochs, latent_dim, n_classes, sample_interval``.
        cuda: not used by this function.
        adversarial_loss, auxiliary_loss: loss callables ``loss(pred, target)``.
    """
    optimizer_G = _adam_for(generator, args)
    optimizer_D = _adam_for(discriminator, args)
    num_batches = len(dataloader)

    for epoch in range(args.n_epochs):
        for i, (imgs, labels) in enumerate(dataloader):
            n = imgs.shape[0]

            # Adversarial ground truths
            valid = Variable(FloatTensor(n, 1).fill_(1.0), requires_grad=False)
            fake = Variable(FloatTensor(n, 1).fill_(0.0), requires_grad=False)

            real_imgs = Variable(imgs.type(FloatTensor))
            labels = Variable(labels.type(LongTensor))

            # ---- generator step ----
            optimizer_G.zero_grad()
            noise = np.random.normal(0, 1, (n, args.latent_dim))
            z = Variable(FloatTensor(noise))
            gen_labels = Variable(LongTensor(np.random.randint(0, args.n_classes, n)))
            gen_imgs = generator(z, gen_labels)

            validity, pred_label = discriminator(gen_imgs)
            g_loss = 0.5 * adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels)
            g_loss.backward()
            optimizer_G.step()

            # ---- discriminator step ----
            optimizer_D.zero_grad()
            real_pred, real_aux = discriminator(real_imgs)
            d_real_loss = (adversarial_loss(real_pred, valid) +
                           auxiliary_loss(real_aux, labels)) / 2
            fake_pred, fake_aux = discriminator(gen_imgs.detach())
            d_fake_loss = (adversarial_loss(fake_pred, fake) +
                           auxiliary_loss(fake_aux, gen_labels)) / 2
            d_loss = (d_real_loss + d_fake_loss) / 2

            # Class-head accuracy over real + generated images (computed, not reported).
            all_scores = np.concatenate(
                [real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
            all_labels = np.concatenate(
                [labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
            d_acc = np.mean(np.argmax(all_scores, axis=1) == all_labels)

            d_loss.backward()
            optimizer_D.step()

            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                  (epoch, args.n_epochs, i, num_batches, d_loss.item(),
                   g_loss.item()))

            batches_done = epoch * num_batches + i
            if batches_done % args.sample_interval == 0:
                sample_image(args, generator, n_row=10, batches_done=batches_done)
Exemplo n.º 19
0
    def freerun(self, inp, gen, mode='max'):
        """Generate responses token by token, mixing in a copy distribution over wiki tokens.

        Each step:
        - feeds ``[previous token embedding, inp.wiki_cv]`` to ``self.GRULayer``;
        - adds ``exp`` of the clamped (max 5) vocabulary logits to ``exp`` of
          the clamped copy scores, scattered onto the vocabulary ids of
          ``inp.wiki_sen``;
        - normalises the sum into ``p``;
        - picks the next token among ids ``>= self.start_generate_id``
          (arg-max for ``mode='max'``, ``gumbel_max`` for ``mode='gumbel'``).

        Generation stops after ``self.args.max_sent_length`` steps or once
        every instance has emitted ``dm.eos_id``.

        Side effects: ``inp.wiki_sen`` is truncated to
        ``inp.wiki_hidden.shape[1]`` columns. Fields set on ``gen``:
        - ``p``: list of per-step distributions.
        - ``w_o`` and ``emb``: stacked over steps, zeroed after each
          instance's first EOS.
        - ``length``: numpy array of kept steps per instance.
        - ``h_n``: final GRU state.
        - ``w_pro``: reset to an empty list and not filled here.

        NOTE(review): any other ``mode`` leaves ``w_o`` unbound, so the first
        ``gen.w_o.append(w_o)`` raises UnboundLocalError.
        """
        batch_size = inp.batch_size
        dm = self.param.volatile.dm

        first_emb = inp.embLayer(LongTensor([dm.go_id])).repeat(batch_size, 1)
        gen.w_pro = []
        gen.w_o = []
        gen.emb = []
        flag = zeros(batch_size).byte()  # 1 once an instance has emitted EOS
        EOSmet = []  # flag before each step: marks steps after the first EOS

        inp.wiki_sen = inp.wiki_sen[:, :inp.wiki_hidden.shape[1]]
        # One-hot of each wiki token over the vocabulary:
        # [1, wiki_sen.shape[0], wiki_len, vocab_size].
        copyHead = zeros(1, inp.wiki_sen.shape[0], inp.wiki_hidden.shape[1],
                         self.param.volatile.dm.vocab_size).scatter_(
                             3,
                             torch.unsqueeze(torch.unsqueeze(inp.wiki_sen, 0),
                                             3), 1)
        # Projected wiki states with dims 0 and 1 swapped; presumably
        # [wiki_len, batch, hidden] — TODO confirm wiki_hidden's layout.
        wikiState = torch.transpose(
            torch.tanh(self.wCopyLinear(inp.wiki_hidden)), 0, 1)

        next_emb = first_emb
        gru_h = inp.init_h
        gen.p = []

        wiki_cv = inp.wiki_cv  # valid_num * (2 * eh_size)

        for _ in range(self.args.max_sent_length):
            now = torch.cat([next_emb, wiki_cv], dim=-1)

            gru_h = self.GRULayer.cell_forward(now, gru_h)
            w = self.wLinearLayer(gru_h)
            w = torch.clamp(w, max=5.0)  # bound the logits before exp
            vocab_p = torch.exp(w)
            # Copy scores: dot product of gru_h with each wiki state, clamped and exp'd.
            copyW = torch.exp(
                torch.clamp(torch.unsqueeze(
                    (torch.sum(torch.unsqueeze(gru_h, 0) * wikiState,
                               -1).transpose_(0, 1)), 1),
                            max=5.0))  # batch * 1 * wiki_len
            # Spread copy scores onto vocabulary ids via the one-hot copyHead.
            copy_p = torch.matmul(copyW, copyHead).squeeze()

            # Normalise over the vocabulary and keep p away from exact 0.
            p = vocab_p + copy_p + 1e-10
            p = p / torch.unsqueeze(torch.sum(p, 1), 1)
            p = torch.clamp(p, 1e-10, 1.0)
            gen.p.append(p)

            if mode == "max":
                w_o = torch.argmax(p[:, self.start_generate_id:],
                                   dim=1) + self.start_generate_id
                next_emb = inp.embLayer(w_o)
            elif mode == "gumbel":
                # NOTE(review): gumbel_max receives probabilities p here, not
                # logits — confirm that is what it expects. The embedding table
                # is sliced from 2, presumably == self.start_generate_id.
                w_onehot, w_o = gumbel_max(p[:, self.start_generate_id:], 1, 1)
                w_o = w_o + self.start_generate_id
                next_emb = torch.sum(
                    torch.unsqueeze(w_onehot, -1) * inp.embLayer.weight[2:], 1)
            gen.w_o.append(w_o)
            gen.emb.append(next_emb)

            EOSmet.append(flag)
            flag = flag | (w_o == dm.eos_id).byte()
            if torch.sum(flag).detach().cpu().numpy() == batch_size:
                break

        # 1 for steps up to and including each instance's first EOS, else 0.
        EOSmet = 1 - torch.stack(EOSmet)
        gen.w_o = torch.stack(gen.w_o) * EOSmet.long()
        gen.emb = torch.stack(gen.emb) * EOSmet.float().unsqueeze(-1)
        gen.length = torch.sum(EOSmet, 0).detach().cpu().numpy()
        gen.h_n = gru_h
Exemplo n.º 20
0
    def detail_forward_disentangle(self, incoming):
        incoming.conn = conn = Storage()
        index = incoming.state.num
        valid_sen = incoming.state.valid_sen
        valid_wiki_h_n1 = torch.index_select(
            incoming.wiki_hidden.h_n1, 1,
            valid_sen)  # [wiki_sen_num, valid_num, 2 * eh_size]
        valid_wiki_sen = torch.index_select(incoming.wiki_sen, 0, valid_sen)
        valid_wiki_h1 = torch.index_select(incoming.wiki_hidden.h1, 1,
                                           valid_sen)
        atten_label = torch.index_select(incoming.data.atten[:, index], 0,
                                         valid_sen)  # valid_num
        valid_wiki_num = torch.index_select(
            LongTensor(incoming.data.wiki_num[:, index]), 0,
            valid_sen)  # valid_num

        reverse_valid_sen = incoming.state.reverse_valid_sen
        self.beta = torch.sum(valid_wiki_h_n1 * incoming.hidden.h_n, dim=2)
        self.beta = torch.t(self.beta)  # [valid_num, wiki_len]

        mask = torch.arange(
            self.beta.shape[1], device=self.beta.device).long().expand(
                self.beta.shape[0],
                self.beta.shape[1]).transpose(0,
                                              1)  # [wiki_sen_num, valid_num]
        expand_wiki_num = valid_wiki_num.unsqueeze(0).expand_as(
            mask)  # [wiki_sen_num, valid_num]
        reverse_mask = (expand_wiki_num <=
                        mask).float()  # [wiki_sen_num, valid_num]

        if index > 0:
            wiki_hidden = incoming.wiki_hidden
            wiki_num = incoming.data.wiki_num[:, index]  # [batch], numpy array
            wiki_hidden.h2, wiki_hidden.h_n2 = self.compareGRU.forward(
                wiki_hidden.h_n1, wiki_num, need_h=True)
            valid_wiki_h2 = torch.index_select(
                wiki_hidden.h2, 1,
                valid_sen)  # wiki_len * valid_num * (2 * eh_size)

            tilde_wiki_list = []
            for i in range(self.last_wiki.size(-1)):
                last_wiki = torch.index_select(self.last_wiki[:, :, i], 0,
                                               valid_sen).unsqueeze(
                                                   0)  # 1, valid_num, (2 * eh)
                tilde_wiki = torch.tanh(
                    self.tilde_linear(
                        torch.cat([
                            last_wiki - valid_wiki_h2,
                            last_wiki * valid_wiki_h2
                        ],
                                  dim=-1)))
                tilde_wiki_list.append(
                    tilde_wiki.unsqueeze(-1) * self.hist_weights[i])
            tilde_wiki = torch.cat(tilde_wiki_list, dim=-1).sum(
                dim=-1)  # wiki_len * valid_num * (2 * eh_size)

            query = self.attn_query(tilde_wiki)  # [1, valid_num, hidden]
            key = self.attn_key(
                torch.cat([valid_wiki_h2, tilde_wiki],
                          dim=-1))  # [wiki_sen_num, valid_num, hidden]
            atten_sum = self.attn_v(torch.tanh(query + key)).squeeze(
                -1)  # [wiki_sen_num, valid_num]

            self.beta = self.beta[:, :atten_sum.shape[0]] + torch.t(
                atten_sum)  #

        if index == 0:
            incoming.result.atten_loss = self.atten_lossCE(
                self.beta,  #self.alpha.t().log(),
                atten_label)
        else:
            incoming.result.atten_loss += self.atten_lossCE(
                self.beta,  #self.alpha.t().log(),
                atten_label)

        self.beta = torch.t(
            self.beta) - 1e10 * reverse_mask[:self.beta.shape[1]]
        self.alpha = self.wiki_atten(self.beta)  # wiki_len * valid_num
        incoming.acc.prob.append(
            torch.index_select(
                self.alpha.t(), 0,
                incoming.state.reverse_valid_sen).cpu().tolist())
        atten_indices = torch.argmax(self.alpha, 0)  # valid_num
        alpha = zeros(self.beta.t().shape).scatter_(1,
                                                    atten_indices.unsqueeze(1),
                                                    1)
        alpha = torch.t(alpha)
        wiki_cv = torch.sum(valid_wiki_h_n1[:alpha.shape[0]] *
                            alpha.unsqueeze(2),
                            dim=0)  # valid_num * (2 * eh_size)
        conn.wiki_cv = wiki_cv
        conn.init_h = self.initLinearLayer(
            torch.cat([incoming.hidden.h_n, wiki_cv], 1))

        if index == 0:
            self.last_wiki = torch.index_select(wiki_cv, 0,
                                                reverse_valid_sen).unsqueeze(
                                                    -1)  # [batch, 2 * eh_size]
        else:
            self.last_wiki = torch.cat([
                torch.index_select(wiki_cv, 0,
                                   reverse_valid_sen).unsqueeze(-1),
                self.last_wiki[:, :, :self.hist_len - 1]
            ],
                                       dim=-1)

        incoming.acc.label.append(
            torch.index_select(atten_label, 0,
                               reverse_valid_sen).cpu().tolist())
        incoming.acc.pred.append(
            torch.index_select(atten_indices, 0,
                               reverse_valid_sen).cpu().tolist())

        atten_indices = atten_indices.unsqueeze(1)
        atten_indices = torch.cat([
            torch.arange(atten_indices.shape[0]).unsqueeze(1),
            atten_indices.cpu()
        ], 1)  # valid_num * 2
        valid_wiki_h1 = torch.transpose(
            valid_wiki_h1, 0,
            1)  # valid_num * wiki_sen_len * wiki_len * (2 * eh_size)
        valid_wiki_h1 = torch.transpose(
            valid_wiki_h1, 1,
            2)  # valid_num * wiki_len * wiki_sen_len * (2 * eh_size)
        conn.selected_wiki_h = valid_wiki_h1[atten_indices.chunk(
            2, 1)].squeeze(1)  # valid_num * wiki_sen_len * (2 * eh_size)
        conn.selected_wiki_sen = valid_wiki_sen[atten_indices.chunk(
            2, 1)].squeeze(1)  # valid_num * wiki_sen_len
Exemplo n.º 21
0
    def forward(self, incoming):
        """Attend over wiki sentences and build the knowledge context from the gold label.

        ``index`` (``incoming.state.num``) is the current step and
        ``incoming.state.valid_sen`` selects the samples processed at this step.
        The method:

        * scores every wiki sentence of each selected sample (``beta``) from
          ``incoming.hidden.h_n`` and, for ``index > 0``, comparison features
          against the history in ``self.last_wiki``;
        * adds ``self.atten_lossCE(beta, atten_label)`` to
          ``incoming.result.atten_loss`` (initializing it at step 0);
        * builds ``conn.wiki_cv`` from the *gold* sentence given by
          ``incoming.data.atten`` (not from the argmax of ``beta``), and sets
          ``conn.init_h``, ``conn.selected_wiki_h`` and
          ``conn.selected_wiki_sen``;
        * pushes ``wiki_cv`` (reordered by ``incoming.state.reverse_valid_sen``)
          to the front of ``self.last_wiki``, keeping at most ``self.hist_len``
          entries along its last dimension.

        Args:
            incoming: container (presumably a ``Storage``) carrying ``state``,
                ``data``, ``hidden``, ``wiki_hidden``, ``wiki_sen`` and
                ``result``; ``incoming.conn`` is (re)created here.
        """
        incoming.conn = conn = Storage()
        index = incoming.state.num
        valid_sen = incoming.state.valid_sen
        valid_wiki_h_n1 = torch.index_select(
            incoming.wiki_hidden.h_n1, 1,
            valid_sen)  # [wiki_sen_num, valid_num, 2 * eh_size]
        valid_wiki_sen = torch.index_select(
            incoming.wiki_sen, 0,
            valid_sen)  # [valid_num, wiki_sen_num, wiki_sen_len]
        valid_wiki_h1 = torch.index_select(
            incoming.wiki_hidden.h1, 1,
            valid_sen)  # [wiki_sen_len, valid_num, wiki_sen_num, 2 * eh_size]
        atten_label = torch.index_select(incoming.data.atten[:, index], 0,
                                         valid_sen)  # valid_num

        if index == 0:
            # No history yet: the comparison features are all zeros, broadcast
            # to [wiki_sen_num, valid_num, 2 * eh_size].
            tilde_wiki = zeros(1, 1, 2 * self.args.eh_size) * ones(
                valid_wiki_h_n1.shape[0], valid_wiki_h_n1.shape[1], 1)
        else:
            wiki_hidden = incoming.wiki_hidden
            wiki_num = incoming.data.wiki_num[:, index]  # [batch], numpy array
            wiki_hidden.h2, wiki_hidden.h_n2 = self.compareGRU.forward(
                wiki_hidden.h_n1, wiki_num, need_h=True)
            valid_wiki_h2 = torch.index_select(
                wiki_hidden.h2, 1,
                valid_sen)  # wiki_len * valid_num * (2 * eh_size)

            # Compare each stored context vector with every wiki sentence and
            # combine the comparisons with the learned history weights.
            tilde_wiki_list = []
            for i in range(self.last_wiki.size(-1)):
                last_wiki = torch.index_select(self.last_wiki[:, :, i], 0,
                                               valid_sen).unsqueeze(
                                                   0)  # 1, valid_num, (2 * eh)
                tilde_wiki = torch.tanh(
                    self.tilde_linear(
                        torch.cat([
                            last_wiki - valid_wiki_h2,
                            last_wiki * valid_wiki_h2
                        ],
                                  dim=-1)))
                tilde_wiki_list.append(
                    tilde_wiki.unsqueeze(-1) * self.hist_weights[i])
            tilde_wiki = torch.cat(tilde_wiki_list, dim=-1).sum(dim=-1)

        query = self.attn_query(incoming.hidden.h_n)  # [valid_num, hidden]
        key = self.attn_key(
            torch.cat([valid_wiki_h_n1[:tilde_wiki.shape[0]], tilde_wiki],
                      dim=-1))  # [wiki_sen_num, valid_num, hidden]
        atten_sum = self.attn_v(torch.tanh(query + key)).squeeze(
            -1)  # [wiki_sen_num, valid_num]
        beta = atten_sum.t()  # [valid_num, wiki_len]

        # NOTE(review): the scores are not masked for padded wiki sentences
        # (beyond wiki_num) before the loss — confirm this is intended.
        if index == 0:
            incoming.result.atten_loss = self.atten_lossCE(beta, atten_label)
        else:
            incoming.result.atten_loss += self.atten_lossCE(beta, atten_label)

        # Teacher forcing: one-hot attention on the gold sentence.
        golden_alpha = zeros(beta.shape).scatter_(1, atten_label.unsqueeze(1),
                                                  1)
        golden_alpha = torch.t(golden_alpha).unsqueeze(2)
        wiki_cv = torch.sum(valid_wiki_h_n1[:golden_alpha.shape[0]] *
                            golden_alpha,
                            dim=0)  # valid_num * (2 * eh_size)
        conn.wiki_cv = wiki_cv
        conn.init_h = self.initLinearLayer(
            torch.cat([incoming.hidden.h_n, wiki_cv], 1))

        reverse_valid_sen = incoming.state.reverse_valid_sen
        if index == 0:
            self.last_wiki = torch.index_select(wiki_cv, 0,
                                                reverse_valid_sen).unsqueeze(
                                                    -1)  # [batch, 2 * eh_size]
        else:
            # Newest context first; drop the oldest beyond hist_len.
            self.last_wiki = torch.cat([
                torch.index_select(wiki_cv, 0,
                                   reverse_valid_sen).unsqueeze(-1),
                self.last_wiki[:, :, :self.hist_len - 1]
            ],
                                       dim=-1)

        # (row, gold sentence) index pairs used for advanced indexing below.
        atten_indices = atten_label.unsqueeze(1)  # valid_num * 1
        atten_indices = torch.cat([
            torch.arange(atten_indices.shape[0]).unsqueeze(1),
            atten_indices.cpu()
        ], 1)  # valid_num * 2
        valid_wiki_h1 = torch.transpose(
            valid_wiki_h1, 0,
            1)  # valid_num * wiki_sen_len * wiki_len * (2 * eh_size)
        valid_wiki_h1 = torch.transpose(
            valid_wiki_h1, 1,
            2)  # valid_num * wiki_len * wiki_sen_len * (2 * eh_size)
        conn.selected_wiki_h = valid_wiki_h1[atten_indices.chunk(2,
                                                                 1)].squeeze(1)
        conn.selected_wiki_sen = valid_wiki_sen[atten_indices.chunk(
            2, 1)].squeeze(1)
def main():
    """Load trained QRNN seq2seq checkpoints and print one random prediction.

    Reads the file given by ``--train_data``, rebuilds a ``QRNNModel`` with the
    hyperparameters from the command line, replaces its encoder, decoder and
    projection layer with the modules stored in the checkpoint files, and
    prints the source, reference and predicted sentences for one randomly
    chosen pair.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--train_data",
        metavar="train_data",
        type=str,
        default='../data/processed/source_replay_twitter_data.txt',
        dest="train_data",
        help="set the training data ")
    parser.add_argument("-e",
                        "--embedding_size",
                        metavar="embedding_size",
                        type=int,
                        default=50,
                        dest="embedding_size",
                        help="set the embedding size ")
    parser.add_argument("-H",
                        "--hidden_size",
                        metavar="hidden_size",
                        type=int,
                        default=512,
                        dest="hidden_size",
                        help="set the hidden size ")
    # Accepted for CLI compatibility; not read by this script.
    parser.add_argument("-f",
                        "--fine_tune_model_name",
                        metavar="fine_tune_model_name",
                        type=str,
                        default='../models/glove_model_40.pth',
                        dest="fine_tune_model_name",
                        help="set the fine tune model name ")
    parser.add_argument("-n",
                        "--num_layers",
                        metavar="num_layers",
                        type=int,
                        default=2,
                        dest="num_layers",
                        help="set the layer number")
    parser.add_argument("-k",
                        "--kernel_size",
                        metavar="kernel_size",
                        type=int,
                        default=2,
                        dest="kernel_size",
                        help="set the kernel_size")
    # Checkpoint locations (defaults are the previously hard-coded paths).
    parser.add_argument("--encoder_model_name",
                        metavar="encoder_model_name",
                        type=str,
                        default='../models/qrnn_encoder_model_285.pth',
                        dest="encoder_model_name",
                        help="set the encoder model path")
    parser.add_argument("--decoder_model_name",
                        metavar="decoder_model_name",
                        type=str,
                        default='../models/qrnn_decoder_model_285.pth',
                        dest="decoder_model_name",
                        help="set the decoder model path")
    parser.add_argument("--proj_linear_model_name",
                        metavar="proj_linear_model_name",
                        type=str,
                        default='../models/qrnn_proj_linear_model_285.pth',
                        dest="proj_linear_model_name",
                        help="set the projection layer model path")
    args = parser.parse_args()
    test_data_loader_attention = DataLoaderAttention(file_name=args.train_data)
    source2index, index2source, target2index, index2target, train_data = \
        test_data_loader_attention.load_data()

    HIDDEN_SIZE = args.hidden_size
    NUM_LAYERS = args.num_layers
    KERNEL_SIZE = args.kernel_size
    EMBEDDING_SIZE = args.embedding_size
    SOURCE_VOCAB_SIZE = len(source2index)
    TARGET_VOCAB_SIZE = len(target2index)
    ZONE_OUT = 0.0
    TRAINING = False
    DROPOUT = 0.0

    qrnn = QRNNModel(QRNNLayer, NUM_LAYERS, KERNEL_SIZE, HIDDEN_SIZE,
                     EMBEDDING_SIZE, SOURCE_VOCAB_SIZE, TARGET_VOCAB_SIZE,
                     ZONE_OUT, TRAINING, DROPOUT)

    # torch.load unpickles whole modules: only load trusted checkpoint files.
    qrnn.encoder = torch.load(args.encoder_model_name)
    qrnn.decoder = torch.load(args.decoder_model_name)
    qrnn.proj_linear = torch.load(args.proj_linear_model_name)

    test = random.choice(train_data)
    inputs = test[0]
    truth = test[1]
    print(inputs)
    print(truth)

    # One '<s>' id per target position, shape [1, target_len].
    start_decode = Variable(LongTensor([[target2index['<s>']] * truth.size(1)
                                        ]))
    with torch.no_grad():  # inference only: no autograd graph is needed
        show_preds = qrnn(inputs, [inputs.size(1)], start_decode)
    outputs = torch.max(show_preds, dim=1)[1].view(len(inputs), -1)
    show_sentence(truth, inputs, outputs.data.tolist(), index2source,
                  index2target)
Exemplo n.º 23
0
    def train_attention(self,
                        train_data: list=[],
                        source2index: list=[],
                        target2index: list=[],
                        index2source: list=[],
                        index2target: list=[],
                        encoder_model: object=None,
                        decoder_model: object=None):
        """Train an attention-based encoder/decoder pair.

        Weights are initialized with ``init_weight`` and then passed through
        ``self.__fine_tune_weight``.  Both models are optimized with Adam (the
        decoder at ``self.lr * self.decoder_learning_rate``) on a cross-entropy
        loss that ignores index 0, with gradients clipped to norm 50.  Every
        200 batches one random training pair is decoded and printed, the mean
        loss is reported, and both models are saved under ``./../models``.
        Once, when ``epoch == self.epoch // 2``, the learning rate is divided
        by 100 and the optimizers are rebuilt.

        Args:
            train_data: training pairs; ``pair[0]`` is the source tensor and
                ``pair[1]`` the target tensor (inferred from their use below).
            source2index, target2index: token -> id vocabularies (presumably
                dicts despite the annotation; ``target2index['<s>']`` is read).
            index2source, index2target: id -> token vocabularies for printing.
            encoder_model, decoder_model: the modules to train (must be given).
        """
        def padding_mask(tokens):
            # [batch, seq_len] mask marking token id 0 (the id the loss ignores).
            return torch.cat([Variable(ByteTensor(
                tuple(map(lambda s: s == 0, t.data))))
                for t in tokens]).view(tokens.size(0), -1)

        def start_tokens(batch_size):
            # [batch_size, 1] column of '<s>' ids: the first decoder input.
            return Variable(LongTensor([[target2index['<s>']] * batch_size])).transpose(0, 1)

        encoder_model.init_weight()
        decoder_model.init_weight()

        encoder_model, decoder_model = self.__fine_tune_weight(encoder_model=encoder_model,
                                                               decoder_model=decoder_model)

        if USE_CUDA:
            encoder_model = encoder_model.cuda()
            decoder_model = decoder_model.cuda()

        loss_function = nn.CrossEntropyLoss(ignore_index=0)
        encoder_optimizer = optim.Adam(encoder_model.parameters(), lr=self.lr)
        decoder_optimizer = optim.Adam(decoder_model.parameters(),
                                       lr=self.lr * self.decoder_learning_rate)
        for epoch in range(self.epoch):
            losses = []
            for i, batch in enumerate(get_batch(self.batch_size, train_data)):
                inputs, targets, input_lengths, target_lengths = \
                    pad_to_batch(batch, source2index, target2index)
                input_mask = padding_mask(inputs)
                start_decode = start_tokens(targets.size(0))
                encoder_model.zero_grad()
                decoder_model.zero_grad()
                output, hidden_c = encoder_model(inputs, input_lengths)
                preds = decoder_model(start_decode, hidden_c, targets.size(1),
                                      output, input_mask, True)
                loss = loss_function(preds, targets.view(-1))
                # .item() works for 0-dim losses; indexing .tolist() does not.
                losses.append(loss.item())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(encoder_model.parameters(), 50.0)
                torch.nn.utils.clip_grad_norm_(decoder_model.parameters(), 50.0)
                encoder_optimizer.step()
                decoder_optimizer.step()

                if i % 200 == 0:
                    # Decode one random pair end to end so the printed source,
                    # reference and prediction all refer to the same example.
                    test = random.choice(train_data)
                    show_inputs, show_targets = test[0], test[1]
                    with torch.no_grad():
                        show_output, show_hidden = encoder_model(show_inputs, [show_inputs.size(1)])
                        show_preds = decoder_model(start_tokens(show_targets.size(0)),
                                                   show_hidden, show_targets.size(1),
                                                   show_output, padding_mask(show_inputs), True)
                    outputs = torch.max(show_preds, dim=1)[1].view(len(show_inputs), -1)
                    show_sentence(show_targets, show_inputs, outputs.data.tolist(), index2source, index2target)
                    print("[%02d/%d] [%03d/%d] mean_loss : %0.2f" %(epoch,
                                                                    self.epoch,
                                                                    i,
                                                                    len(train_data) // self.batch_size,
                                                                    np.mean(losses)))
                    self.__save_model_info(show_inputs, epoch, losses)
                    torch.save(encoder_model, './../models/encoder_model_{0}.pth'.format(epoch))
                    torch.save(decoder_model, './../models/decoder_model_{0}.pth'.format(epoch))
                    losses = []
                if self.rescheduled is False and epoch == self.epoch // 2:
                    # One-time learning-rate decay halfway through training.
                    self.lr = self.lr * 0.01
                    encoder_optimizer = optim.Adam(encoder_model.parameters(), lr=self.lr)
                    decoder_optimizer = optim.Adam(decoder_model.parameters(), lr=self.lr * self.decoder_learning_rate)
                    self.rescheduled = True
        self.writer.export_scalars_to_json("./all_scalars.json")
        self.writer.close()
Exemplo n.º 24
0
    def train_qrnn(self,
                   train_data: list=[],
                   source2index: list=[],
                   target2index: list=[],
                   index2source: list=[],
                   index2target: list=[],
                   qrnn_model: object=None):
        """Train ``qrnn_model`` (encoder + decoder) with Adam.

        Each batch is padded with ``pad_to_batch``, decoded from a row of
        ``'<s>'`` ids, and scored with a cross-entropy loss that ignores index
        0; gradients are clipped to norm 50.  Every 200 batches one random
        training pair is decoded and printed, the mean loss is reported, and
        the encoder, decoder and ``proj_linear`` modules are saved under
        ``./../models``.  Once, when ``epoch == self.epoch // 2``, the learning
        rate is divided by 100 and the optimizers are rebuilt.

        Args:
            train_data: training pairs; ``pair[0]`` is the source tensor and
                ``pair[1]`` the target tensor (inferred from their use below).
            source2index, target2index: token -> id vocabularies (presumably
                dicts; ``target2index['<s>']`` is read).
            index2source, index2target: id -> token vocabularies for printing.
            qrnn_model: model exposing ``encoder``, ``decoder`` and
                ``proj_linear`` (must be given).
        """
        if USE_CUDA:
            qrnn_model = qrnn_model.cuda()
        # Bound outside the USE_CUDA branch so CPU training works as well;
        # Module.cuda() moves sub-modules in place, so on GPU these already
        # live on the device.
        encoder_model = qrnn_model.encoder
        decoder_model = qrnn_model.decoder

        loss_function = nn.CrossEntropyLoss(ignore_index=0)
        # NOTE(review): only encoder and decoder parameters are optimized; if
        # qrnn_model.proj_linear is not part of the decoder it is never
        # updated — confirm.
        # NOTE(review): the decoder starts at self.lr but is rescheduled to
        # self.lr * self.decoder_learning_rate (train_attention uses the scaled
        # rate from the start) — confirm which is intended.
        encoder_optimizer = optim.Adam(encoder_model.parameters(), lr=self.lr)
        decoder_optimizer = optim.Adam(decoder_model.parameters(), lr=self.lr)
        for epoch in range(self.epoch):
            losses = []
            for i, batch in enumerate(get_batch(self.batch_size, train_data)):
                inputs, targets, input_lengths, target_lengths = \
                    pad_to_batch(batch, source2index, target2index)
                qrnn_model.zero_grad()
                start_decode = Variable(LongTensor([[target2index['<s>']] * targets.size(1)]))
                preds = qrnn_model(inputs, input_lengths, start_decode)
                loss = loss_function(preds, targets.view(-1))
                # .item() works for 0-dim losses; indexing .tolist() does not.
                losses.append(loss.item())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(qrnn_model.parameters(), 50.0)
                encoder_optimizer.step()
                decoder_optimizer.step()

                if i % 200 == 0:
                    # Decode the sampled pair itself (same call pattern as
                    # main()) so source, reference and prediction match.
                    test = random.choice(train_data)
                    show_inputs = test[0]
                    show_targets = test[1]
                    show_start_decode = Variable(LongTensor([[target2index['<s>']] * show_targets.size(1)]))
                    with torch.no_grad():
                        show_preds = qrnn_model(show_inputs, [show_inputs.size(1)], show_start_decode)
                    outputs = torch.max(show_preds, dim=1)[1].view(len(show_inputs), -1)
                    show_sentence(show_targets, show_inputs, outputs.data.tolist(), index2source, index2target)
                    print("[%02d/%d] [%03d/%d] mean_loss : %0.2f" %(epoch,
                                                                    self.epoch,
                                                                    i,
                                                                    len(train_data) // self.batch_size,
                                                                    np.mean(losses)))
                    self.__save_model_info(inputs, epoch, losses)
                    torch.save(qrnn_model.encoder, './../models/test_qrnn_encoder_model_{0}.pth'.format(epoch))
                    torch.save(qrnn_model.decoder, './../models/test_qrnn_decoder_model_{0}.pth'.format(epoch))
                    torch.save(qrnn_model.proj_linear, './../models/test_qrnn_proj_linear_model_{0}.pth'.format(epoch))
                    losses = []
                if self.rescheduled is False and epoch == self.epoch // 2:
                    # One-time learning-rate decay halfway through training.
                    self.lr = self.lr * 0.01
                    encoder_optimizer = optim.Adam(encoder_model.parameters(), lr=self.lr)
                    decoder_optimizer = optim.Adam(decoder_model.parameters(), lr=self.lr * self.decoder_learning_rate)
                    self.rescheduled = True
        self.writer.export_scalars_to_json("./all_scalars.json")
        self.writer.close()
    def forward(self, query, key, value, mask=None, tau=1):
        """Dot-product attention with optional relative positions and masks.

        Args:
            query: tensor [batch, query_len, d].
            key: tensor [batch, key_len, d].
            value: tensor [batch, key_len, d_value].
            mask: optional tensor broadcastable to [batch, query_len, key_len];
                positions where it is 0 get 1e9 subtracted from their score.
            tau: Gumbel-softmax temperature, used only when
                ``self.gumbel_attend`` is set and the module is training.

        Returns:
            [batch, query_len, d_value] attended values.  Query rows whose keys
            are all masked (score <= -5e8) get all-zero attention weights.

        Structural masks (query position i, key position j):
            attend_mode "only_attend_front": keys j <= i are visible.
            attend_mode "only_attend_back":  keys j >= i are visible.
            attend_mode "not_attend_self":   every key except j == i.
            window > 0:                      keys with |i - j| <= window.
        """
        dot_products = (query.unsqueeze(2) * key.unsqueeze(1)).sum(
            -1)  # batch x query_len x key_len

        if self.relative_clip:
            dot_relative = torch.einsum(
                "ijk,tk->ijt", query,
                self.key_relative.weight)  # batch * query_len * relative_size

            batch_size, query_len, key_len = dot_products.shape

            # Cached table: diag_id[i, j] = clip(i - j + relative_clip, 0,
            # 2 * relative_clip), the relative-embedding index for (i, j).
            diag_dim = max(query_len, key_len)
            if self.diag_id.shape[0] < diag_dim:
                self.diag_id = np.zeros((diag_dim, diag_dim))
                for i in range(diag_dim):
                    for j in range(diag_dim):
                        if i <= j - self.relative_clip:
                            self.diag_id[i, j] = 0
                        elif i >= j + self.relative_clip:
                            self.diag_id[i, j] = self.relative_clip * 2
                        else:
                            self.diag_id[i, j] = i - j + self.relative_clip
            diag_id = LongTensor(self.diag_id[:query_len, :key_len])

            dot_relative = reshape(
                dot_relative, "bld", "bl_d",
                key_len).gather(-1,
                                reshape(diag_id, "lm", "_lm_", batch_size,
                                        -1))[:, :, :,
                                             0]  # batch * query_len * key_len
            dot_products = dot_products + dot_relative

        if self.attend_mode == "only_attend_front":
            assert query.shape[1] == key.shape[1]
            # Mask keys after the query (strict upper triangle).
            tri = cuda(torch.ones(key.shape[1], key.shape[1]).triu(1),
                       device=query) * 1e9
            dot_products = dot_products - tri.unsqueeze(0)
        elif self.attend_mode == "only_attend_back":
            assert query.shape[1] == key.shape[1]
            # Mask keys before the query (strict lower triangle), mirroring
            # the triu(1) used for "only_attend_front".
            tri = cuda(torch.ones(key.shape[1], key.shape[1]).tril(-1),
                       device=query) * 1e9
            dot_products = dot_products - tri.unsqueeze(0)
        elif self.attend_mode == "not_attend_self":
            assert query.shape[1] == key.shape[1]
            eye = cuda(torch.eye(key.shape[1]), device=query) * 1e9
            dot_products = dot_products - eye.unsqueeze(0)

        if self.window > 0:
            assert query.shape[1] == key.shape[1]
            window_mask = cuda(torch.ones(key.shape[1], key.shape[1]),
                               device=query)
            # Mask keys more than `window` positions away on either side:
            # triu(w + 1) covers j - i > w and tril(-(w + 1)) covers i - j > w.
            window_mask = (window_mask.triu(self.window + 1) +
                           window_mask.tril(-(self.window + 1))) * 1e9
            dot_products = dot_products - window_mask.unsqueeze(0)

        if mask is not None:
            dot_products -= (1 - mask) * 1e9

        logits = dot_products / self.scale
        if self.gumbel_attend and self.training:
            probs = gumbel_softmax(logits, tau, dim=-1)
        else:
            probs = torch.softmax(logits, dim=-1)

        # Zero the weights of query rows where every key is masked.
        probs = probs * (
            (dot_products <= -5e8).sum(-1, keepdim=True) <
            dot_products.shape[-1]).float()  # batch_size * query_len * key_len
        probs = self.dropout(probs)

        res = torch.matmul(probs, value)  # batch_size * query_len * d_value

        if self.relative_clip:
            # Cached table: recover_id[i, r] = i + r - relative_clip, the key
            # position paired with relative index r for query position i.
            if self.recover_id.shape[0] < query_len:
                self.recover_id = np.zeros((query_len, self.relative_size))
                for i in range(query_len):
                    for j in range(self.relative_size):
                        self.recover_id[i, j] = i + j - self.relative_clip
            # Rows are query positions, so take exactly query_len of them.
            recover_id = LongTensor(self.recover_id[:query_len])
            # Out-of-range key positions point at the zero column appended below.
            recover_id[recover_id < 0] = key_len
            recover_id[recover_id >= key_len] = key_len

            probs = torch.cat([probs, zeros(batch_size, query_len, 1)], -1)
            relative_probs = probs.gather(
                -1,
                reshape(recover_id, "qr", "_qr",
                        batch_size))  # batch_size * query_len * relative_size
            res = res + torch.einsum(
                "bqr,rd->bqd", relative_probs,
                self.value_relative.weight)  # batch_size * query_len * d_value

        return res
    def forward(self,
                inp,
                wLinearLayerCallback,
                h_init=None,
                mode='max',
                input_callback=None,
                no_unk=True,
                top_k=10):
        """Run the attention decoder over ``inp.embedding`` with scheduled sampling.

        inp contains: batch_size, dm, embLayer, embedding, sampling_proba,
        max_sent_length, post, post_length, resp_length [init_h]

        At every step a uniform random draw below ``inp.sampling_proba``
        feeds back the embedding of the previous prediction; otherwise the
        gold embedding ``inp.embedding[i]`` is used (teacher forcing).  The
        hidden state of a sequence is zeroed once step ``i`` reaches
        ``resp_length - 1`` for that sequence.

        Args:
            inp: see above.  ``inp.post`` is attended over with
                ``maskedSoftmax`` using ``inp.post_length``; ``inp.init_h``
                (if present) is the initial recurrent state.
            wLinearLayerCallback: maps ``cat([h_now, context], -1)`` to word
                scores for the step.
            h_init: not read by this method (the initial state comes from
                ``inp.init_h``).
            mode: how the fed-back token is chosen: ``"max"`` (argmax),
                ``"gumbel"`` / ``"sample"`` (Gumbel-max sample) or
                ``"samplek"`` (Gumbel-max restricted to the ``top_k`` scores).
            input_callback: optional transform applied once to the whole
                ``inp.embedding`` tensor (which is overwritten) and to every
                fed-back embedding.
            no_unk: if true, ids below ``inp.dm.go_id`` are never generated.
            top_k: number of candidates for ``mode="samplek"``.

        Returns:
            Storage whose ``w_pro`` holds the per-step word scores stacked
            along dim 0.

        Raises:
            AttributeError: for an unrecognized ``mode`` (raised after the
                first step's scores have been computed).
        """
        # The step function returned first by init_forward_all is not needed here.
        _, h_now, context = self.init_forward_all(inp.batch_size,
                                                  inp.post,
                                                  inp.post_length,
                                                  h_init=inp.get(
                                                      "init_h", None))

        gen = Storage()
        gen.w_pro = []
        batch_size = inp.embedding.shape[1]
        seqlen = inp.embedding.shape[0]
        length = inp.resp_length - 1
        start_id = inp.dm.go_id if no_unk else 0

        first_emb = inp.embLayer(LongTensor([inp.dm.go_id
                                             ])).repeat(inp.batch_size, 1)
        next_emb = first_emb

        if input_callback:
            inp.embedding = input_callback(inp.embedding)

        for i in range(seqlen):
            proba = random()

            # Sampling
            if proba < inp.sampling_proba:
                now = next_emb
                if input_callback:
                    now = input_callback(now)
            # Teacher Forcing
            else:
                now = inp.embedding[i]

            if self.gru_input_attn:
                # The attention context of the previous step (initially from
                # init_forward_all) is concatenated to the input embedding.
                cell_input = torch.cat([now, context], dim=-1)
            else:
                cell_input = now
            # 1.0 for sequences still running at step i, 0.0 once finished.
            alive = Tensor((length > np.ones(batch_size) * i).astype(float)).unsqueeze(-1)
            h_now = self.cell_forward(cell_input, h_now) * alive

            query = self.attn_query(h_now)
            attn_weight = maskedSoftmax(
                (query.unsqueeze(0) * inp.post).sum(-1), inp.post_length)
            context = (attn_weight.unsqueeze(-1) * inp.post).sum(0)

            gru_h = torch.cat([h_now, context], dim=-1)

            w = wLinearLayerCallback(gru_h)
            gen.w_pro.append(w)

            # Decoding: choose the token whose embedding is fed back next step.
            if mode == "max":
                w = torch.argmax(w[:, start_id:], dim=1) + start_id
                next_emb = inp.embLayer(w)
            elif mode in ("gumbel", "sample"):
                w_onehot = gumbel_max(w[:, start_id:])
                w = torch.argmax(w_onehot, dim=1) + start_id
                next_emb = torch.sum(
                    torch.unsqueeze(w_onehot, -1) *
                    inp.embLayer.weight[start_id:], 1)
            elif mode == "samplek":
                _, index = w[:,
                             start_id:].topk(top_k,
                                             dim=-1,
                                             largest=True,
                                             sorted=True)  # batch_size, top_k
                mask = torch.zeros_like(w[:,
                                          start_id:]).scatter_(-1, index, 1.0)
                w_onehot = gumbel_max_with_mask(w[:, start_id:], mask)
                w = torch.argmax(w_onehot, dim=1) + start_id
                next_emb = torch.sum(
                    torch.unsqueeze(w_onehot, -1) *
                    inp.embLayer.weight[start_id:], 1)
            else:
                raise AttributeError(
                    "The given mode {} is not recognized.".format(mode))

        gen.w_pro = torch.stack(gen.w_pro, dim=0)

        return gen