Example #1
    def forward(self, input,
                enc_init_h=None, enc_init_c=None,
                dec_init_input=None, dec_kwargs=None):
        """
        Args:
            input: [batch_size, seq_len]
            enc_init_h: [batch_size, n_layers, hidden_size]
            enc_init_c: [batch_size, n_layers, hidden_size]
            dec_init_input: [batch_size]
            dec_kwargs: dict

        Returns:
            Output of StackedLSTMDecoder's forward()
        """
        if (enc_init_h is None) and (enc_init_c is None):
            batch_size = input.size(0)
            enc_init_h, enc_init_c = self.rnn.state0(batch_size)
            enc_init_h, enc_init_c = move_to_cuda(enc_init_h), move_to_cuda(enc_init_c)

        hiddens, cells, outputs = self.encoder(input, enc_init_h, enc_init_c)

        # Get states and input for decoder
        last_hidden, last_cell, last_logits = hiddens[-1], cells[-1], outputs[-1]
        if dec_init_input is None:
            last_probs = logits_to_prob(last_logits, method='softmax')  # [batch, vocab]
            _, dec_init_input = prob_to_vocab_id(last_probs, 'greedy')  # [batch]

        dec_kwargs = dec_kwargs or {}
        probs, ids, texts, extra = self.decoder(last_hidden, last_cell, dec_init_input, **dec_kwargs)
        return probs, ids, texts, extra
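A minimal, self-contained sketch (toy shapes, plain PyTorch) of the greedy-init step above, with softmax/argmax standing in for the repo's logits_to_prob / prob_to_vocab_id helpers:

import torch
import torch.nn.functional as F

last_logits = torch.randn(4, 100)            # [batch, vocab], dummy values
last_probs = F.softmax(last_logits, dim=-1)  # stands in for logits_to_prob(..., method='softmax')
dec_init_input = last_probs.argmax(dim=-1)   # stands in for prob_to_vocab_id(..., 'greedy')
print(dec_init_input.shape)                  # torch.Size([4])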
Example #2
    def forward(self,
                init_hidden,
                init_cell,
                init_input,
                targets=None,
                seq_len=None,
                eos_id=EOS_ID,
                non_pad_prob_val=0,
                softmax_method='softmax',
                sample_method='sample',
                tau=1.0,
                eps=1e-10,
                gumbel_hard=False,
                encoder_hiddens=None,
                encoder_inputs=None,
                attend_to_embs=None,
                subwordenc=None,
                return_last_state=False,
                k=1):
        """
        Decode. If targets is given, then use teacher forcing.

        Notes:
            This is also used by beam search by setting seq_len=1 and k=beam_size.
            The comments talk about [batch * k^(t+1)], but in practice this should only ever
            be called with seq_len=1 (and hence t=0). The results from one beam step are then pruned,
            before beam search repeats the step.

        Args:
            init_hidden: [batch_size, n_layers, hidden_size]
            init_cell: [batch_size, n_layers, hidden_size]
            init_input: [batch_size] (e.g. <EDOC> ids)

            # For teacher forcing
            targets: [batch_size, trg_seq_len]

            # For non-teacher forcing
            seq_len: int (length to generate)
            eos_id: int (generate until every sequence in batch has generated eos, or until seq_len)
            non_pad_prob_val: float
                When replacing tokens after eos_id, set the probability of non-pad tokens to this value.
                A small epsilon may be used if the log of the probs will be computed for an NLLLoss, in
                order to prevent NaNs.

            # Sampling, temperature, etc.
            softmax_method: str (which version of softmax to get probabilities; 'gumbel' or 'softmax')
            sample_method: str (how to sample words given probabilities; 'greedy', 'sample')
            tau: float (temperature for softmax)
            eps: float (controls sampling from Gumbel)
            gumbel_hard: boolean (whether to produce one hot encodings for Gumbel Softmax)
            subwordenc: SubwordTokenizer
                (returns text if given)

            # Additional inputs
            encoder_hiddens: [batch_size, seq_len, hidden_size]
                Hiddens at each time step. Would be used for attention
            encoder_inputs: [batch_size, seq_len]
                Would be used for a pointer network
            attend_to_embs: [batch_size, n_docs, n_layers, hidden_size]
                maybe just [batch, *, hidden]?
                Embs to attend to. Could be last hidden states (i.e. document representations)

            # Beam search
            return_last_state: bool
                (states used for beam search)
            k: int (i.e. beam width)

        Returns:
            decoded_probs: [batch * k^(gen_len), gen_len, vocab]
            decoded_ids: [batch * k^(gen_len), gen_len]
            decoded_texts: list of str's if subwordenc is given
            extra: dict of additional outputs
        """
        batch_size = init_input.size(0)
        output_len = seq_len if seq_len is not None else targets.size(1)
        vocab_size = self.rnn.h2o.out_features

        decoded_probs = move_to_cuda(
            torch.zeros(batch_size * k, output_len, vocab_size))
        decoded_ids = move_to_cuda(
            torch.zeros(batch_size * k, output_len).long())
        extra = {}

        rows_with_eos = move_to_cuda(torch.zeros(
            batch_size *
            k).long())  # track which sequences have generated eos_id
        pad_ids = move_to_cuda(
            torch.Tensor(batch_size * k).fill_(PAD_ID).long())
        pad_prob = move_to_cuda(torch.zeros(
            batch_size * k, vocab_size)).fill_(non_pad_prob_val)
        pad_prob[:, PAD_ID] = 1.0

        hidden, cell = init_hidden, init_cell  # [batch, n_layers, hidden]
        input = init_input.long()

        for t in range(output_len):
            if gumbel_hard and t != 0:
                input_emb = torch.matmul(input, self.embed.weight)
            else:
                input_emb = self.embed(input)  # [batch, emb_size]

            if self.use_docs_attn:
                attn_wts = self.attn_lin1(
                    attend_to_embs)  # [batch, n_docs, n_layers, attn_size]
                attn_wts = self.attn_lin2(
                    self.attn_act1(attn_wts))  # [batch, n_docs, n_layers, 1]
                attn_wts = F.softmax(attn_wts,
                                     dim=1)  # [batch, n_docs, n_layers, 1]
                context = attn_wts * attend_to_embs  # [batch, n_docs, n_layers, hidden]
                context = context.sum(dim=1)  # [batch, n_layers, hidden]
                hidden = self.context_alpha * context + (
                    1 - self.context_alpha) * hidden

            hidden, cell, output = self.rnn(input_emb, hidden, cell)
            prob = logits_to_prob(output,
                                  softmax_method,
                                  tau=tau,
                                  eps=eps,
                                  gumbel_hard=gumbel_hard)  # [batch, vocab]
            prob, id = prob_to_vocab_id(prob, sample_method,
                                        k=k)  # [batch * k^(t+1)]

            # If sequence (row) has *previously* produced an EOS,
            # replace prob with one hot (probability one for pad) and id with pad
            prob = torch.where((rows_with_eos == 1).unsqueeze(1), pad_prob,
                               prob)  # unsqueeze to broadcast
            id = torch.where(rows_with_eos == 1, pad_ids, id)
            # Now update rows_with_eos to include this time step
            # This has to go after the above! Otherwise EOS is replaced as well
            rows_with_eos = rows_with_eos | (id == eos_id).long()

            decoded_probs[:, t, :] = prob
            decoded_ids[:, t] = id

            # Get next input
            if targets is not None:  # teacher forcing
                input = targets[:, t]  # [batch]
            else:  # non-teacher forcing
                if gumbel_hard:
                    input = prob
                else:
                    input = id  # [batch * k^(t+1)]

            # Terminate early if not teacher forcing and all sequences have generated an eos
            if targets is None:
                if rows_with_eos.sum().item() == (batch_size * k):
                    break

        # if return_last_state:
        #     extra['last_state'] = states

        decoded_texts = []
        if subwordenc:
            for i in range(batch_size):
                decoded_texts.append(
                    subwordenc.decode(decoded_ids[i].long().tolist()))

        return decoded_probs, decoded_ids, decoded_texts, extra
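A self-contained sketch of the post-EOS masking above, with illustrative PAD/EOS ids and a toy batch (not the repo's constants): once a row has emitted eos_id, its probabilities are overwritten with a one-hot on PAD and its ids with PAD, and rows_with_eos is updated only afterwards so the EOS token itself is kept.

import torch

PAD_ID, EOS_ID, vocab_size = 0, 3, 10        # illustrative values
rows_with_eos = torch.tensor([1, 0])         # row 0 emitted EOS at an earlier step
pad_ids = torch.full((2,), PAD_ID, dtype=torch.long)
pad_prob = torch.zeros(2, vocab_size)
pad_prob[:, PAD_ID] = 1.0

prob = torch.rand(2, vocab_size)             # step-t probabilities
ids = torch.tensor([7, 3])                   # step-t ids; row 1 emits EOS now

prob = torch.where((rows_with_eos == 1).unsqueeze(1), pad_prob, prob)
ids = torch.where(rows_with_eos == 1, pad_ids, ids)
rows_with_eos = rows_with_eos | (ids == EOS_ID).long()  # update after the overwrite
print(ids, rows_with_eos)                    # tensor([0, 3]) tensor([1, 1])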
Example #3
    def prepare_batch(self,
                      texts_batch,
                      ratings_batch,
                      global_prepend_id=None,
                      global_append_id=None,
                      doc_prepend_id=None,
                      doc_append_id=None):
        """
        Prepare batch of texts and labels from DataLoader as input into nn.

        Args:
            texts_batch: list of str's
                - length batch_size
                - each str is a concatenated group of documents
            ratings_batch: list of size-1 LongTensor's
                - length batch_size

            global_prepend_id: int (prepend GO)
            global_append_id: int (append EOS)
            doc_prepend_id: int (prepend DOC before start of each review)
            doc_append_id: int (append /DOC after end of each review)

        Returns: (cuda)
            x: LongTensor of size [batch_size, max_seq_len]
            lengths: LongTensor (length of each text in subtokens)
            labels: LongTensor of size [batch_size]
        """
        # Original ratings go from 1-5

        labels_batch = [rating - 1 for rating in ratings_batch]

        batch = []
        for i, text in enumerate(texts_batch):

            # Split apart by docs and potentially add delimiters
            docs = SummDataset.split_docs(text)  # list of strs

            if doc_prepend_id or doc_append_id:
                docs_ids = [self.subwordenc.encode(doc) for doc in docs]
                if doc_prepend_id:
                    for doc_ids in docs_ids:
                        doc_ids.insert(0, doc_prepend_id)
                if doc_append_id:
                    for doc_ids in docs_ids:
                        doc_ids.append(doc_append_id)

                docs_ids = [id for doc_ids in docs_ids
                            for id in doc_ids]  # flatten
                subtoken_ids = docs_ids
            else:
                subtoken_ids = self.subwordenc.encode(' '.join(docs))

            # Add start and end token for concatenated set of documents
            if global_prepend_id:
                subtoken_ids.insert(0, global_prepend_id)
            if global_append_id:
                subtoken_ids.append(global_append_id)
            seq_len = len(subtoken_ids)
            batch.append((subtoken_ids, seq_len, labels_batch[i]))

        texts_ids, lengths, labels = zip(*batch)

        lengths = torch.LongTensor(lengths)
        labels = torch.stack(labels)

        # Pad each text
        max_seq_len = max(lengths)
        batch_size = len(batch)
        x = np.zeros((batch_size, max_seq_len))
        for i, text_ids in enumerate(texts_ids):
            padded = np.zeros(max_seq_len)
            padded[:len(text_ids)] = text_ids
            x[i, :] = padded
        x = torch.from_numpy(x.astype(int))

        x = move_to_cuda(x)
        lengths = move_to_cuda(lengths)
        labels = move_to_cuda(labels)

        return x, lengths, labels
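A toy sketch of the delimiter handling above (illustrative ids, no real tokenizer): each document gets per-doc prepend/append markers, the documents are flattened into one id list, and global start/end tokens wrap the whole group.

DOC_ID, EDOC_ID, GO_ID, EOS_ID = 2, 3, 1, 4              # illustrative ids
docs_ids = [[10, 11], [12]]                              # two toy encoded documents

docs_ids = [[DOC_ID] + d + [EDOC_ID] for d in docs_ids]  # per-doc delimiters
subtoken_ids = [i for d in docs_ids for i in d]          # flatten
subtoken_ids = [GO_ID] + subtoken_ids + [EOS_ID]         # global start/end tokens
print(subtoken_ids)                                      # [1, 2, 10, 11, 3, 2, 12, 3, 4]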
Example #4
    def run_epoch(self,
                  data_iter,
                  nbatches,
                  epoch,
                  split,
                  optimizer=None,
                  tb_writer=None):
        """

        Args:
            data_iter: Pytorch DataLoader
            nbatches: int (number of batches in data_iter)
            epoch: int
            split: str ('train', 'val')
            optimizer: Wrapped optim (e.g. OptWrapper, NoamOpt)
            tb_writer: Tensorboard SummaryWriter

        Returns:
            1D tensor containing average loss across all items in data_iter
        """

        loss_avg = 0
        n_fwds = 0
        for s_idx, (texts, ratings, metadata) in enumerate(data_iter):

            start = time.time()

            # Add special tokens to texts
            x, lengths, labels = self.dataset.prepare_batch(
                texts, ratings, doc_append_id=EDOC_ID)
            iter = create_lm_data_iter(x, self.hp.lm_seq_len)
            for b_idx, batch_obj in enumerate(iter):
                if optimizer:
                    optimizer.optimizer.zero_grad()

                #
                # Forward pass
                #
                if self.hp.model_type == 'mlstm':
                    # Note: iter creates a sequence of length hp.lm_seq_len + 1, and batch_obj.trg is all but the
                    # last token, while batch_obj.trg_y is all but the first token. They're named as such because
                    # the Batch class was originally designed for the Encoder-Decoder version of the Transformer, and
                    # the trg variables correspond to inputs to the Decoder.
                    batch = move_to_cuda(
                        batch_obj.trg
                    )  # it's trg because doesn't include last token
                    batch_trg = move_to_cuda(batch_obj.trg_y)
                    batch_size, seq_len = batch.size()

                    if b_idx == 0:
                        h_init, c_init = self.model.module.rnn.state0(batch_size) if self.ngpus > 1 \
                            else self.model.rnn.state0(batch_size)
                        h_init = move_to_cuda(h_init)
                        c_init = move_to_cuda(c_init)

                    # Forward steps for lstm
                    result = self.model(batch, h_init, c_init)
                    hiddens, cells, outputs = zip(
                        *result) if self.ngpus > 1 else result

                    # Calculate loss
                    loss = 0
                    batch_trg = batch_trg.transpose(
                        0, 1).contiguous()  # [seq_len, batch]
                    if self.ngpus > 1:
                        for t in range(len(outputs[0])):
                            # length ngpus list of outputs at that time step
                            loss += self.loss_fn(
                                [outputs[i][t] for i in range(len(outputs))],
                                batch_trg[t])
                    else:
                        for t in range(len(outputs)):
                            loss += self.loss_fn(outputs[t], batch_trg[t])
                    loss_value = loss.item() / self.hp.lm_seq_len

                    # We only do bptt until lm_seq_len. Copy the hidden states so that we can continue the sequence
                    if self.ngpus > 1:
                        h_init = torch.cat([
                            copy_state(hiddens[i][-1])
                            for i in range(self.ngpus)
                        ],
                                           dim=0)
                        c_init = torch.cat([
                            copy_state(cells[i][-1]) for i in range(self.ngpus)
                        ],
                                           dim=0)
                    else:
                        h_init = copy_state(hiddens[-1])
                        c_init = copy_state(cells[-1])

                elif self.hp.model_type == 'transformer':
                    # This is the decoder only version now
                    logits = self.model(move_to_cuda(batch_obj.trg),
                                        move_to_cuda(batch_obj.trg_mask))
                    # logits: [batch, seq_len, vocab]
                    loss = self.loss_fn(logits, move_to_cuda(batch_obj.trg_y))
                    loss /= move_to_cuda(batch_obj.ntokens.float(
                    ))  # normalize by number of non-pad tokens
                    loss_value = loss.item()
                    if self.ngpus > 1:
                        # With the custom DataParallel, there is no gather() and the loss is calculated per
                        # minibatch split on each GPU (see DataParallelCriterion's forward() -- the return
                        # value is divided by the number of GPUs). We simply undo that operation here.
                        # Also, note that the KLDivLoss in LabelSmoothing is already normalized by both
                        # batch and seq_len, as we use size_average=False to prevent any normalization followed
                        # by a manual normalization using the batch.ntokens. This oddity is because
                        # KLDivLoss does not support ignore_index=PAD_ID as CrossEntropyLoss does.
                        loss_value *= len(self.opt.gpus.split(','))

                #
                # Backward pass
                #
                gn = -1.0  # dummy for val (norm can't be < 0 anyway)
                if optimizer:
                    loss.backward()
                    gn = calc_grad_norm(
                        self.model
                    )  # not actually using this, just for printing
                    optimizer.step()
                loss_avg = update_moving_avg(loss_avg, loss_value, n_fwds + 1)
                n_fwds += 1

            # Print
            print_str = 'Epoch={}, batch={}/{}, split={}, time={:.4f} --- ' \
                        'loss={:.4f}, loss_avg_so_far={:.4f}, grad_norm={:.4f}'
            if s_idx % self.opt.print_every_nbatches == 0:
                print(
                    print_str.format(epoch, s_idx, nbatches, split,
                                     time.time() - start, loss_value, loss_avg,
                                     gn))
                if tb_writer:
                    # Step for tensorboard: global steps in terms of number of reviews
                    # This accounts for runs with different batch sizes
                    step = (epoch * nbatches *
                            self.hp.batch_size) + (s_idx * self.hp.batch_size)
                    tb_writer.add_scalar('stats/loss', loss_value, step)

            # Save periodically so we don't have to wait for epoch to finish
            save_every = nbatches // 10
            if save_every != 0 and s_idx % save_every == 0:
                save_model(self.save_dir, self.model, self.optimizer, epoch,
                           self.opt, 'intermediate')

        print('Epoch={}, split={}, --- '
              'loss_avg={:.4f}'.format(epoch, split, loss_avg))

        return loss_avg
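A hedged sketch of the hidden-state carry-over above; copy_state is a repo helper that is assumed here to detach and clone, so backpropagation is truncated at lm_seq_len boundaries while the hidden state itself continues across chunks.

import torch

def copy_state_sketch(state):
    # assumption: detach + clone, so the next chunk reuses values but not the graph
    return state.detach().clone()

h_last = torch.randn(4, 1, 8, requires_grad=True)  # pretend last hidden state of a chunk
h_init = copy_state_sketch(h_last * 1.0)           # graph-connected tensor going in
print(h_init.requires_grad)                        # False: next chunk starts a fresh graph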
Example #5
    def forward(self,
                docs_ids,
                labels,
                cycle_tgt_ids=None,
                extract_summ_ids=None,
                tau=None,
                adv_step=None,
                real_ids=None,
                minibatch_idx=None,
                print_every_nbatches=None,
                tb_writer=None,
                tb_step=None,
                wass_loss=None,
                grad_pen_loss=None,
                adv_gen_loss=None,
                clf_loss=None,
                clf_acc=None,
                clf_avg_diff=None):
        """
        Args:
            docs_ids: [batch, max_len (concatenated reviews)] when concat_docs=True
                      [batch, n_docs, max_len] when concat_docs=False
            labels: [batch]
                - ratings for classification
            cycle_tgt_ids: [batch, n_docs, seq_len]
            extract_summ_ids: [batch, max_sum_len]
                - summaries from extractive model
            tau: float
                Passed in instead of using self.hp.tau because tau may be
                obtained from a StepAnnealer if there is a scheduled decay.
            adv_step: str ('discrim' or 'gen')
                - whether to compute discriminator step (with detach to only train Discriminator), or
                just the generator step (pass in generated summaries and return .mean())
            real_ids: [batch, max_rev_len]
                - reviews used for Discriminator

            minibatch_idx: int (index of the minibatch within the current epoch)
            print_every_nbatches: int
            tb_writer: Tensorboard SummaryWriter
            tb_step: int (used for writer)

            The remaining are 0-D float Tensors to handle an edge case where the summary is too short for the
            TextCNN. The current average is passed in.

        Returns:
            stats: dict (str to 0-D tensors)
                - contains losses
            summ_texts: list of strs
        """
        batch_size = docs_ids.size(0)

        ##########################################################
        # ENCODE DOCUMENTS
        ##########################################################
        # Print a review if we're autoencoding or using cycle reconstruction loss so that we can
        # check how well the reconstruction is
        if self.hp.autoenc_docs or (self.hp.cycle_loss == 'rec'):
            if minibatch_idx % print_every_nbatches == 0:
                if docs_ids.get_device() == 0:
                    print('\n', '-' * 100)
                    orig_rev_text = self.dataset.subwordenc.decode(
                        docs_ids[0][0])
                    print('ORIGINAL REVIEW: ', orig_rev_text.encode('utf8'))
                    print('-' * 100)
                    if tb_writer:
                        tb_writer.add_text('auto_or_rec/orig_review',
                                           orig_rev_text, tb_step)

        if not self.hp.concat_docs:
            n_docs = docs_ids.size(
                1
            )  # TODO: need to get data loader to choose items with same n_docs
            docs_ids = docs_ids.view(
                -1, docs_ids.size(-1))  # [batch * n_docs, len]

        h_init, c_init = self.docs_enc.rnn.state0(docs_ids.size(0))
        h_init, c_init = move_to_cuda(h_init), move_to_cuda(c_init)
        hiddens, cells, outputs = self.docs_enc(docs_ids, h_init, c_init)
        docs_enc_h, docs_enc_c = hiddens[-1], cells[
            -1]  # [_, n_layers, hidden]

        ##########################################################
        # DECODE INTO SUMMARIES AND / OR ORIGINAL REVIEWS
        ##########################################################

        # Autoencoder - decode into original reviews
        if self.hp.autoenc_docs:
            assert (self.hp.concat_docs == False), \
                'Docs must be encoded individually for autoencoder. Set concat_docs=False'
            init_input = torch.LongTensor(
                [EDOC_ID for _ in range(docs_enc_h.size(0))])  # batch * n_docs
            init_input = move_to_cuda(init_input)
            docs_autodec_probs, _, docs_autodec_texts, _ = self.docs_autodec(
                docs_enc_h,
                docs_enc_c,
                init_input,
                targets=docs_ids,
                eos_id=EDOC_ID,
                non_pad_prob_val=1e-14,
                softmax_method='softmax',
                sample_method='greedy',
                tau=tau,
                subwordenc=self.dataset.subwordenc)

            docs_autodec_logprobs = torch.log(docs_autodec_probs)
            autoenc_loss = self.rec_crit(
                docs_autodec_logprobs.view(-1, docs_autodec_logprobs.size(-1)),
                docs_ids.view(-1))
            if self.hp.sum_label_smooth:
                autoenc_loss /= (docs_ids != move_to_cuda(
                    torch.tensor(PAD_ID))).sum().float()
            self.stats['autoenc_loss'] = autoenc_loss

            if minibatch_idx % print_every_nbatches == 0:
                if docs_ids.get_device() == 0:
                    dec_text = docs_autodec_texts[0]
                    print('DECODED REVIEW: ', dec_text.encode('utf8'))
                    print('-' * 100, '\n')
                    if tb_writer:
                        tb_writer.add_text('auto_or_rec/auto_dec_review',
                                           dec_text, tb_step)

            # Early return if we're only computing auto-encoder (don't have to decode into summaries)
            if self.hp.autoenc_only:
                dummy_summ_texts = [
                    'a dummy review' for _ in range(batch_size)
                ]
                return self.stats, dummy_summ_texts

        # Decode into summary
        if not self.hp.concat_docs:
            _, n_layers, hidden_size = docs_enc_h.size()
            docs_enc_h = docs_enc_h.view(batch_size, n_docs, n_layers,
                                         hidden_size)
            docs_enc_c = docs_enc_c.view(batch_size, n_docs, n_layers,
                                         hidden_size)
            if self.hp.combine_encs == 'mean':
                docs_enc_h_comb = docs_enc_h.mean(dim=1)
                docs_enc_c_comb = docs_enc_c.mean(dim=1)
            elif self.hp.combine_encs == 'ff':
                docs_enc_h_comb = docs_enc_h.transpose(1, 2).view(
                    batch_size, n_layers, -1)
                # [batch, n_layers, n_docs * hidden]
                docs_enc_c_comb = docs_enc_c.transpose(1, 2).view(
                    batch_size, n_layers, -1)
                docs_enc_h_comb = self.combine_encs_h_net(
                    docs_enc_h_comb)  # [batch, n_layers, hidden]
                docs_enc_c_comb = self.combine_encs_c_net(docs_enc_c_comb)
            elif self.hp.combine_encs == 'gru':
                n_directions = 2 if self.hp.combine_encs_gru_bi else 1
                init_h = torch.zeros(
                    self.hp.combine_encs_gru_nlayers * n_directions,
                    batch_size, hidden_size)
                init_c = torch.zeros(
                    self.hp.combine_encs_gru_nlayers * n_directions,
                    batch_size, hidden_size)
                init_h = move_to_cuda(init_h)
                init_c = move_to_cuda(init_c)
                docs_enc_h_comb = docs_enc_h.view(batch_size,
                                                  n_docs * n_layers,
                                                  hidden_size)
                docs_enc_c_comb = docs_enc_c.view(batch_size,
                                                  n_docs * n_layers,
                                                  hidden_size)
                # [batch, n_directions * gru_nlayers, hidden]
                # self.combine_encs_h_net.flatten_parameters()
                # self.combine_encs_c_net.flatten_parameters()
                _, docs_enc_h_comb = self.combine_encs_h_net(
                    docs_enc_h_comb, init_h)
                _, docs_enc_c_comb = self.combine_encs_c_net(
                    docs_enc_c_comb, init_c)
                # [n_directions * gru_nlayers, batch, hidden]
                docs_enc_h_comb = docs_enc_h_comb[-1, :, :].unsqueeze(
                    0).transpose(0, 1)  # last layer TODO: last or combine?
                docs_enc_c_comb = docs_enc_c_comb[-1, :, :].unsqueeze(
                    0).transpose(0, 1)  # last layer

        softmax_method = 'gumbel'
        sample_method = 'greedy'
        if self.hp.early_cycle:
            softmax_method = 'softmax'
            # sample_method = 'sample'
        init_input = torch.LongTensor([EDOC_ID for _ in range(batch_size)])
        init_input = move_to_cuda(init_input)
        # Backwards compatibility with models trained before dataset refactoring
        # We could use the code_snapshot saved at the time the model was trained, but I've added some
        # useful things (e.g. tracking NLL of summaries)
        tgt_summ_seq_len = self.dataset.conf.review_max_len if hasattr(self.dataset, 'conf') else \
            self.hp.yelp_review_max_len
        summ_probs, _, summ_texts, _ = self.summ_dec(
            docs_enc_h_comb,
            docs_enc_c_comb,
            init_input,
            seq_len=tgt_summ_seq_len,
            eos_id=EDOC_ID,
            # seq_len=self.dataset.conf.review_max_len, eos_id=EDOC_ID,
            softmax_method=softmax_method,
            sample_method=sample_method,
            tau=tau,
            eps=self.hp.g_eps,
            gumbel_hard=True,
            attend_to_embs=docs_enc_h,
            subwordenc=self.dataset.subwordenc)

        # [batch, max_summ_len, vocab];  [batch] of str's

        # Compute a cosine similarity loss between the (mean) summary representation that's fed to the
        # summary decoder and each of the original encoded reviews.
        # With this setup, there's no need for the summary encoder or back propagating through the summary.
        if self.hp.early_cycle:
            # Repeat each summary representation n_docs times to match shape of tensor with individual reviews
            docs_enc_h_comb_rep = docs_enc_h_comb.repeat(1, n_docs, 1) \
                .view(batch_size * n_docs, docs_enc_h_comb.size(1), docs_enc_h_comb.size(2))
            docs_enc_c_comb_rep = docs_enc_c_comb.repeat(1, n_docs, 1) \
                .view(batch_size * n_docs, docs_enc_c_comb.size(1), docs_enc_c_comb.size(2))

            loss = -self.cos_crit(docs_enc_h_comb_rep.view(batch_size, -1),
                                  docs_enc_h.view(batch_size,
                                                  -1).detach()).mean()
            if not self.hp.cos_honly:
                loss -= self.cos_crit(docs_enc_c_comb_rep.view(batch_size, -1),
                                      docs_enc_c.view(batch_size,
                                                      -1).detach()).mean()
            self.stats['early_cycle_loss'] = loss * self.hp.cos_wgt

        ##########################################################
        # CYCLE LOSS and / or  EXTRACTIVE SUMMARY LOSS
        ##########################################################

        # Encode summaries
        if self.hp.sum_cycle or self.hp.extract_loss:
            init_h, init_c = self.summ_enc.rnn.state0(batch_size)
            init_h, init_c = move_to_cuda(init_h), move_to_cuda(init_c)
            hiddens, cells, outputs = self.summ_enc(summ_probs, init_h, init_c)
            summ_enc_h, summ_enc_c = hiddens[-1], cells[
                -1]  # [batch, n_layers, hidden], ''

        # Extractive vs. abstractive summary loss
        if self.hp.extract_loss:
            # Encode extractive summary
            init_h, init_c = self.summ_enc.rnn.state0(batch_size)
            init_h, init_c = move_to_cuda(init_h), move_to_cuda(init_c)
            ext_hiddens, ext_cells, ext_outputs = self.summ_enc(
                extract_summ_ids, init_h, init_c)
            ext_enc_h, ext_enc_c = ext_hiddens[-1], ext_cells[
                -1]  # [batch, n_layers, hidden], ''
            loss = -self.cos_crit(summ_enc_h.view(batch_size, -1),
                                  ext_enc_h.view(batch_size,
                                                 -1).detach()).mean()
            if not self.hp.cos_honly:
                loss -= self.cos_crit(summ_enc_c.view(batch_size, -1),
                                      ext_enc_c.view(batch_size,
                                                     -1).detach()).mean()
            self.stats['extract_loss'] = loss

        # Reconstruction or encoder cycle loss
        if self.hp.sum_cycle:
            # Repeat each summary representation n_docs times to match shape of tensor with individual reviews
            summ_enc_h_rep = summ_enc_h.repeat(1, n_docs, 1) \
                .view(batch_size * n_docs, summ_enc_h.size(1), summ_enc_h.size(2))
            summ_enc_c_rep = summ_enc_c.repeat(1, n_docs, 1) \
                .view(batch_size * n_docs, summ_enc_c.size(1), summ_enc_c.size(2))

            if self.hp.cycle_loss == 'enc':
                assert (self.hp.concat_docs == False), \
                    'Docs must have been encoded individually for the cycle loss. Set concat_docs=False'
                # (It's possible to have cycle_loss=enc and concat_docs=True, you just have to also encode them
                # separately. Didn't add that b/c I think I'll always have concat_docs=False from now on)
                # docs_enc_h, docs_enc_c: [batch, n_docs, n_layers, hidden]
                loss = -self.cos_crit(summ_enc_h_rep.view(batch_size, -1),
                                      docs_enc_h.view(batch_size,
                                                      -1).detach()).mean()
                if not self.hp.cos_honly:
                    loss -= self.cos_crit(
                        summ_enc_c_rep.view(batch_size, -1),
                        docs_enc_c.view(batch_size, -1).detach()).mean()
                self.stats['cycle_loss'] = loss * self.hp.cos_wgt
            elif self.hp.cycle_loss == 'rec':
                init_input = move_to_cuda(
                    torch.LongTensor(
                        [EDOC_ID for _ in range(batch_size * n_docs)]))
                probs, ids, texts, extra = self.docs_dec(
                    summ_enc_h_rep,
                    summ_enc_c_rep,
                    init_input,
                    targets=cycle_tgt_ids.view(-1, cycle_tgt_ids.size(-1)),
                    eos_id=EDOC_ID,
                    non_pad_prob_val=1e-14,
                    softmax_method='softmax',
                    sample_method='sample',
                    tau=tau,
                    subwordenc=self.dataset.subwordenc)
                vocab_size = probs.size(-1)
                logprobs = torch.log(probs).view(-1, vocab_size)
                loss = self.rec_crit(logprobs, cycle_tgt_ids.view(-1))
                if self.hp.sum_label_smooth:
                    loss /= (cycle_tgt_ids != move_to_cuda(
                        torch.tensor(PAD_ID))).sum().float()
                self.stats['cycle_loss'] = loss * self.hp.cos_wgt

                if minibatch_idx % print_every_nbatches == 0:
                    if docs_ids.get_device() == 0:
                        print('DECODED REVIEW: ', texts[0].encode('utf8'))
                        print('-' * 100, '\n')
                        if tb_writer:
                            tb_writer.add_text('auto_or_rec/rec_review',
                                               texts[0], tb_step)

        ##########################################################
        # DISCRIMINATOR
        ##########################################################
        # TODO: Remove this -- discriminator is not used

        # Goal: self.discrim_model should be good at distinguishing between generated canonical review and
        # original reviews. The adv_loss returns difference between gen and real (plus the gradient penalty)
        # To train the discriminator, we want to minimize this value (gen is small, real is large)
        # To train the rest, want to maximize gen (or minimize -gen)
        if adv_step == 'discrim':
            if (self.hp.discrim_model == 'cnn') and (
                    summ_probs.size(1) < 5):  # conv filters are 3,4,5
                print(
                    'Summary length is less than 5... skipping Discriminator model because it uses a CNN '
                    'with a convolution kernel of size 5')
            else:
                real_ids_onehot = convert_to_onehot(
                    real_ids, self.dataset.subwordenc.vocab_size)
                result = self.discrim_model(
                    real_ids_onehot.float().detach().requires_grad_(),
                    summ_probs.detach().requires_grad_())
                gen_mean, real_mean, grad_pen = result[0], result[1], result[2]
                wass_loss = gen_mean - real_mean
                grad_pen_loss = self.hp.wgan_lam * grad_pen

            adv_loss = wass_loss + grad_pen_loss
            self.stats.update({
                'wass_loss': wass_loss,
                'grad_pen_loss': grad_pen_loss,
                'adv_loss': adv_loss
            })
        elif adv_step == 'gen':
            dummy_real = torch.zeros_like(summ_probs).long()
            result = self.discrim_model(dummy_real, summ_probs)
            gen_mean = result[0]
            self.stats['adv_gen_loss'] = -1 * gen_mean

        ##########################################################
        # CLASSIFIER
        ##########################################################

        if self.hp.sum_clf:
            if summ_probs.size(1) < 5:  # conv filters are 3,4,5
                print(
                    'Summary length is less than 5... skipping classification model because it uses a CNN '
                    'with a convolution kernel of size 5')
            else:
                logits = self.clf_model(summ_probs.long())
                clf_loss = self.clf_crit(logits, labels)

                _, indices = torch.max(logits, dim=1)
                clf_avg_diff = (labels - indices).float().mean()
                clf_acc = torch.eq(indices, labels).sum().float() / batch_size

            self.stats.update({
                'clf_loss': clf_loss,
                'clf_acc': clf_acc,
                'clf_avg_diff': clf_avg_diff
            })

        return self.stats, summ_texts
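A self-contained sketch (toy shapes) of the encoder-cycle cosine term above: each summary representation is repeated n_docs times, flattened per batch item, and compared against the detached review encodings. cos_crit is assumed to behave like nn.CosineSimilarity.

import torch
import torch.nn as nn

batch_size, n_docs, n_layers, hidden = 2, 3, 1, 4  # toy sizes
cos_crit = nn.CosineSimilarity(dim=1)

summ_enc_h = torch.randn(batch_size, n_layers, hidden)
docs_enc_h = torch.randn(batch_size, n_docs, n_layers, hidden)

# repeat each summary state n_docs times -> [batch * n_docs, n_layers, hidden]
summ_enc_h_rep = summ_enc_h.repeat(1, n_docs, 1) \
    .view(batch_size * n_docs, summ_enc_h.size(1), summ_enc_h.size(2))

loss = -cos_crit(summ_enc_h_rep.view(batch_size, -1),
                 docs_enc_h.view(batch_size, -1).detach()).mean()
print(loss.shape)  # torch.Size([]) -- a scalar; more similar encodings give a lower loss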
Example #6
#
# Prepare initial input text
#
subwordenc = load_file(ds_conf.subwordenc_path)
init_texts = [init for init in opt.init.split('|')]
init_tokens = [subwordenc.encode(init) for init in init_texts]
init_lens = [len(init) for init in init_tokens]
max_len = max(init_lens)
init_tokens_padded = [
    tokens + [PAD_ID for _ in range(max_len - len(tokens))]
    for tokens in init_tokens
]
init_tensor = [
    batchify(torch.LongTensor(init), 1) for init in init_tokens_padded
]
init_tensor = torch.cat(init_tensor, dim=0)  # [batch, max_len]
init_tensor = move_to_cuda(init_tensor)
batch_size = init_tensor.size(0)
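A toy sketch of the prompt preparation above; batchify is a repo helper assumed to add a batch dimension, so torch.stack is used here instead, and the pad id is illustrative.

import torch

PAD_ID = 0
init_tokens = [[5, 6, 7], [8, 9]]             # toy tokenized prompts
max_len = max(len(t) for t in init_tokens)
padded = [t + [PAD_ID] * (max_len - len(t)) for t in init_tokens]
init_tensor = torch.stack([torch.LongTensor(t) for t in padded], dim=0)
print(init_tensor)                            # tensor([[5, 6, 7], [8, 9, 0]])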

#
# Load and set up model
#
checkpoint = torch.load(opt.load_model)
model = checkpoint['model']
if isinstance(model, nn.DataParallel):
    model = model.module

ngpus = 1 if len(opt.gpus) == 1 else len(opt.gpus.split(','))

#
# Generate
#