Ejemplo n.º 1
0
 def __init__(self, base_source_token_embedder, encoder, decoder_cell,
              copy_lens):
     """Store the encoder and build the train and beam decoders.

     Args:
         base_source_token_embedder: embedder whose ``vocab`` attribute
             is also kept as ``self.base_vocab``.
         encoder: kept as ``self.encoder``.
         decoder_cell: cell handed to both decoders.
         copy_lens: kept as ``self.copy_lens``.
     """
     super(Editor, self).__init__()
     self.encoder = encoder

     # A single combiner instance is shared by both decoders.
     combiner = AttentionContextCombiner()
     self.train_decoder = TrainDecoder(decoder_cell, combiner)
     self.test_decoder_beam = BeamDecoder(decoder_cell, combiner)

     self.base_vocab = base_source_token_embedder.vocab
     self.base_source_token_embedder = base_source_token_embedder
     self.copy_lens = copy_lens
Ejemplo n.º 2
0
    def __init__(self, token_embedder, hidden_dim, agenda_dim, edit_dim,
                 lamb_reg, norm_eps, norm_max, kill_edit, decoder_cell,
                 encoder_layers):
        """Build the Editor's encoder and its train/beam decoders.

        Args:
            token_embedder (TokenEmbedder): given to the encoder and to
                both decoders
            hidden_dim (int)
            agenda_dim (int)
            edit_dim (int)
            lamb_reg: forwarded unchanged to Encoder
            norm_eps: forwarded unchanged to Encoder
            norm_max: forwarded unchanged to Encoder
            kill_edit: forwarded unchanged to Encoder
            decoder_cell (DecoderCell): given to both decoders
            encoder_layers (int)
        """
        super(Editor, self).__init__()
        self.encoder = Encoder(
            token_embedder, agenda_dim, edit_dim, hidden_dim,
            lamb_reg, norm_eps, norm_max, kill_edit, encoder_layers,
            rnn_cell_factory=LSTMCell)

        # A single combiner instance is shared by both decoders.
        combiner = AttentionContextCombiner()
        self.train_decoder = TrainDecoder(
            decoder_cell, token_embedder, combiner)
        self.test_decoder_beam = BeamDecoder(
            decoder_cell, token_embedder, combiner)
Ejemplo n.º 3
0
 def __init__(self, token_embedder, hidden_dim, agenda_dim, num_layers,
              logger):
     """Build the decoders, the agenda parameter and the vocab handle.

     Args:
         token_embedder: supplies ``embed_dim`` (used as the decoder
             cell's input dimension) and ``vocab``.
         hidden_dim: forwarded to MultilayeredDecoderCell.
         agenda_dim: length of the agenda parameter vector.
         num_layers: forwarded to MultilayeredDecoderCell.
         logger: kept as ``self.logger``.
     """
     super(LanguageModel, self).__init__()

     cell = MultilayeredDecoderCell(
         token_embedder, hidden_dim, token_embedder.embed_dim, agenda_dim,
         num_layers)
     combiner = MultilayeredLMContextCombiner()

     # All three decoders share the same cell and combiner.
     self.train_decoder = TrainDecoder(cell, token_embedder, combiner)
     self.sample_decoder = SampleDecoder(cell, token_embedder, combiner)
     self.beam_decoder = BeamDecoder(cell, token_embedder, combiner)

     self.agenda_dim = agenda_dim
     # The agenda starts out as an all-zero vector of length agenda_dim.
     self.agenda = Parameter(torch.zeros(agenda_dim))
     self.vocab = token_embedder.vocab
     self.logger = logger
Ejemplo n.º 4
0
    def __init__(self, token_embedder, hidden_dim, agenda_dim, edit_dim,
                 lamb_reg, norm_eps, norm_max, kill_edit, decoder_cell,
                 encoder_layers, num_iter=None, eps=None, momentum=None):
        """Build the Editor's encoder, decoders and meta optimizer.

        Args:
            token_embedder (TokenEmbedder): given to the encoder and to
                both decoders
            hidden_dim (int)
            agenda_dim (int)
            edit_dim (int)
            lamb_reg: forwarded unchanged to Encoder
            norm_eps: forwarded unchanged to Encoder
            norm_max: forwarded unchanged to Encoder
            kill_edit: forwarded unchanged to Encoder
            decoder_cell (DecoderCell): given to both decoders
            encoder_layers (int)
            num_iter: passed to OptimN2N as ``iters``
            eps: passed to OptimN2N as ``eps``
            momentum: passed to OptimN2N as ``momentum``
        """
        super(Editor, self).__init__()
        self.encoder = Encoder(
            token_embedder, agenda_dim, edit_dim, hidden_dim,
            lamb_reg, norm_eps, norm_max, kill_edit, encoder_layers,
            rnn_cell_factory=LSTMCell)

        # A single combiner instance is shared by both decoders.
        combiner = AttentionContextCombiner()
        self.train_decoder = TrainDecoder(
            decoder_cell, token_embedder, combiner)
        self.test_decoder_beam = BeamDecoder(
            decoder_cell, token_embedder, combiner)

        # The meta optimizer is handed the train decoder's parameters.
        decoder_params = list(self.train_decoder.parameters())
        self.meta_optimizer = OptimN2N(
            self.encoder, self.train_decoder, decoder_params,
            iters=num_iter, eps=eps, momentum=momentum)
Ejemplo n.º 5
0
class VAERetriever(Module):
    """Encoder-based retriever bundled with the editor's train/decode code.

    The first group of methods embeds examples with the encoder and
    builds/queries an Annoy index over the resulting agenda vectors. The
    second group mirrors the editor: preprocessing, loss and beam decoding.

    Attributes:
        encoder: module used to preprocess and encode examples
        train_decoder (TrainDecoder)
        test_decoder_beam (BeamDecoder)
        base_vocab: ``vocab`` of the base source token embedder
        base_source_token_embedder
        copy_lens: forwarded to HardCopyDynamicVocab
    """

    def __init__(self, base_source_token_embedder, encoder, decoder_cell,
                 copy_lens):
        """Store the encoder and build the train and beam decoders."""
        super(VAERetriever, self).__init__()
        self.encoder = encoder
        # A single combiner instance is shared by both decoders.
        context_combiner = AttentionContextCombiner()
        self.train_decoder = TrainDecoder(decoder_cell, context_combiner)
        self.test_decoder_beam = BeamDecoder(decoder_cell, context_combiner)
        self.base_vocab = base_source_token_embedder.vocab
        self.base_source_token_embedder = base_source_token_embedder
        self.copy_lens = copy_lens

    @classmethod
    def _batch_editor_examples(cls, examples):
        """Split examples into parallel lists of inputs and targets.

        Args:
            examples: iterable of objects with ``input_words`` and
                ``target_words`` attributes

        Returns:
            (input_words, target_words): two lists with one entry per example
        """
        input_words = [ex.input_words for ex in examples]
        target_words = [ex.target_words for ex in examples]
        return input_words, target_words

    ####
    # Retriever code
    def batch_embed(self, exes, train_mode=True):
        """Encode examples in chunks of 128.

        Args:
            exes: sequence of edit examples
            train_mode (bool): forwarded to the encoder

        Returns:
            list: the rows of every chunk's agenda matrix, in input order
        """
        ret_list = []
        for batch in chunks(exes, 128):
            encin = self.encode(batch, train_mode).data.cpu().numpy()
            ret_list.extend(encin)
        return ret_list

    def make_lsh(self, veclist):
        """
        :param veclist: list of vectors to be indexed (must be non-empty,
            the first vector fixes the index dimension)
        :return: an annoy LSH index structure.
        """
        lshind = AnnoyIndex(len(veclist[0]), metric='angular')
        for num, vec in enumerate(veclist):
            lshind.add_item(num, vec)
        lshind.build(10)
        return lshind

    def encode(self, in_data, train_mode=True):
        """
        :param in_data: sequence of edit examples - only inputs are encoded!
        :param train_mode: forwarded to the encoder
        :return: batch of agenda vectors of size (batch_size, agenda_dim)
        """
        encoder_input = self.preprocess(in_data).encoder_input
        # The encoder also returns a loss term, which is not needed here.
        encoder_output, _ = self.encoder(encoder_input,
                                         train_mode=train_mode)
        return encoder_output.agenda

    def ret_lsh(self, veclist, lsh, topk=100, startat=0):
        """Query the index once per vector.

        For each vector, ``topk + startat`` neighbours are requested and the
        first ``startat`` of them are dropped.

        :return: list (one entry per vector) of neighbour index lists
        """
        return [
            lsh.get_nns_by_vector(vec, topk + startat)[startat:]
            for vec in veclist
        ]

    def ret_idx(self, in_data, lsh, train_mode=True):
        """
        :param in_data: sequence of edit examples to encode and look up
        :param lsh: index as built by make_lsh
        :param train_mode: forwarded to the encoder
        :return: for each example, the neighbour indices from ret_lsh
            (with its default topk/startat)
        """
        encoded = self.encode(in_data,
                              train_mode=train_mode)  # batch x agenda_size
        encoded_np = encoded.data.cpu().numpy()
        return self.ret_lsh(encoded_np, lsh)

    def ret_and_make_ex(self, input, lsh, ex_list, startat, train_mode=True):
        """Retrieve one prototype per input example and fuse the pairs.

        For every example the neighbour at rank ``startat`` is looked up in
        ``ex_list``; the result is passed to make_editexamples together with
        the inputs themselves.
        """
        ret_list = []
        for batch in chunks(input, 128):
            idxlist = self.ret_idx(batch, lsh, train_mode=train_mode)
            ret_list.extend(ex_list[idx[startat]] for idx in idxlist)
        return self.make_editexamples(ret_list, input)

    def make_editexamples(self, proto_list, edit_list):
        """Fuse retrieved prototypes into the examples they belong to.

        Each result takes the query's input words, followed by the
        prototype's input words plus the prototype's target words as one
        extra input sequence; the target is the query's target.

        Args:
            proto_list (list[EditExample]): retrieved prototypes
            edit_list (list[EditExample]): query examples. When it has as
                many entries as proto_list, the two are paired index-wise
                (the batched case of ret_and_make_ex). Otherwise its first
                entry is paired with every prototype (the single-query case
                of ret_and_make_ex_one).

        Returns:
            list[EditExample]
        """
        # Previously every prototype was paired with edit_list[0], which
        # gave wrong pairs for batched retrieval.
        paired = len(edit_list) == len(proto_list)
        example_list = []
        for i, proto in enumerate(proto_list):
            query = edit_list[i] if paired else edit_list[0]
            example_list.append(EditExample(
                query.input_words + proto.input_words +
                [proto.target_words], query.target_words))
        return example_list

    def ret_and_make_ex_one(self,
                            input,
                            lsh,
                            ex_list,
                            startat,
                            endat,
                            train_mode=True):
        """Retrieve the neighbours ranked [startat, endat) for ONE example.

        Returns:
            list[EditExample]: ``input`` fused with each retrieved prototype
        """
        idxlist = self.ret_idx([input], lsh, train_mode=train_mode)
        ret_list = [
            ex_list[idx[i]] for idx in idxlist for i in range(startat, endat)
        ]
        return self.make_editexamples(ret_list, [input])

    ####
    # Editor code (identical to editor)

    def preprocess(self, examples, volatile=False):
        """Preprocess a batch of EditExamples, converting them into arrays.

        Args:
            examples (list[EditExample])
            volatile (bool): forwarded to the encoder's preprocess

        Returns:
            EditorInput
        """
        input_words, target_words = self._batch_editor_examples(examples)
        dynamic_vocabs = self._compute_dynamic_vocabs(input_words,
                                                      self.base_vocab)

        dynamic_token_embedder = DynamicMultiVocabTokenEmbedder(
            self.base_source_token_embedder, dynamic_vocabs, self.base_vocab)

        # WARNING:
        # Note that we currently use the same token embedder for both inputs to the encoder and
        # inputs to the decoder. In the future, we may use a different embedder for the decoder.

        encoder_input = self.encoder.preprocess(input_words,
                                                target_words,
                                                dynamic_token_embedder,
                                                volatile=volatile)
        train_decoder_input = HardCopyTrainDecoderInput(
            target_words, dynamic_vocabs, self.base_vocab)
        return EditorInput(encoder_input, train_decoder_input)

    def _compute_dynamic_vocabs(self, input_batches, vocab):
        """Compute dynamic vocabs for each example.

        Args:
            input_batches (list[list[list[unicode]]]): a batch of input lists,
                where each input list is a list of sentences
            vocab (HardCopyVocab)

        Returns:
            list[HardCopyDynamicVocab]: a batch of dynamic vocabs, one for each example
        """
        return [
            HardCopyDynamicVocab(vocab, input_words, self.copy_lens)
            for input_words in input_batches
        ]

    def forward(self, editor_input):
        """Return the training loss.

        Args:
            editor_input (EditorInput)

        Returns:
            loss (Variable): decoder loss plus encoder loss
            losses (SequenceBatch): of shape (batch_size, seq_length)
            enc_loss: the encoder loss on its own
        """
        encoder_output, enc_loss = self.encoder(editor_input.encoder_input)
        total_loss, losses = self.train_decoder.loss(
            encoder_output, editor_input.train_decoder_input)
        return total_loss + enc_loss, losses, enc_loss

    def vocab_probs(self, editor_input):
        """Return raw vocab_probs

        Args:
            editor_input (EditorInput)

        Returns:
            vocab_probs (list[Variable]) contains softmax variables
        """
        encoder_output, _ = self.encoder(editor_input.encoder_input)
        return self.train_decoder.vocab_probs(
            encoder_output, editor_input.train_decoder_input)

    def loss(self, examples, assert_train=True):
        """Compute loss Variable.

        Args:
            examples (list[EditExample])
            assert_train (bool): accepted for interface compatibility; it is
                not read by this method

        Returns:
            loss (Variable): of shape 1
            loss_traces (list[LossTrace])
            enc_loss (numpy array): encoder loss moved to the CPU
        """
        editor_input = self.preprocess(examples)
        total_loss, losses, enc_loss = self(editor_input)
        # list of length batch_size. Each element is a 1D numpy array of per-token losses
        per_ex_losses = list(losses.values.data.cpu().numpy())
        loss_traces = [
            LossTrace(ex, l, self.base_vocab)
            for ex, l in zip(examples, per_ex_losses)
        ]
        return total_loss, loss_traces, enc_loss.data.cpu().numpy()

    def test_batch(self, examples):
        """simple batching test

        NOTE: the check is switched off by the unconditional ``return``
        below, so this method currently always returns None. The code after
        it is kept so the check can be re-enabled by deleting that line.
        """
        return
        if len(examples) > 1:
            ex1, ex2 = examples[0:2]

            loss = lambda batch: self.loss(batch, assert_train=False)[
                0].data.cpu().numpy()

            self.eval()  # test mode, to disable randomness of dropout
            np.random.seed(0)
            torch.manual_seed(0)
            lindivid = loss([ex1]) + loss([ex2])
            np.random.seed(0)
            torch.manual_seed(0)
            ltogether = loss([ex1, ex2]) * 2.0
            self.train()  # set back to train mode

            if abs(lindivid - ltogether) > 1e-3:
                # Single-argument print(...) behaves the same on Python 2
                # and Python 3.
                print(examples[0:2])
                print('individually:')
                print(lindivid)
                print('batched:')
                print(ltogether)
                raise Exception(
                    'batching error - examples do not produce identical results under batching'
                )
        else:
            raise Exception(
                'test_batch called with example list of length < 2')
        print('Passed batching test')

    def edit(self,
             examples,
             max_seq_length=150,
             beam_size=5,
             batch_size=64,
             constrain_vocab=False,
             verbose=False):
        """Performs edits on a batch of source sentences.

        The model is put in evaluation mode while decoding and is switched
        back to train mode afterwards, even if decoding raises.

        Args:
            examples (list[EditExample])
            max_seq_length (int): max # timesteps to generate for
            beam_size (int): for beam decoding
            batch_size (int): max number of examples to pass into the RNN decoder at a time.
                The total # examples decoded in parallel = batch_size / beam_size
                (rounded down, but never less than 1).
            constrain_vocab (bool): default is False
            verbose (bool): wrap the batches with verboserate when True

        Returns:
            beam_list (list[list[list[unicode]]]): a batch of beams.
            edit_traces (list[EditTrace])
        """
        # Floor division keeps the chunk size an int on Python 3 too; the
        # lower bound avoids a zero chunk size when beam_size > batch_size.
        examples_per_chunk = max(1, batch_size // beam_size)

        beam_list = []
        edit_traces = []
        self.eval()  # set to evaluation mode, for dropout to work correctly
        try:
            batches = chunks(examples, examples_per_chunk)
            if verbose:
                batches = verboserate(batches, desc='Decoding examples')
            for batch in batches:
                beams, traces = self._edit_batch(batch, max_seq_length,
                                                 beam_size, constrain_vocab)
                beam_list.extend(beams)
                edit_traces.extend(traces)
        finally:
            self.train()  # set back to train mode
        return beam_list, edit_traces

    def _edit_batch(self, examples, max_seq_length, beam_size,
                    constrain_vocab):
        """Beam-decode one chunk of examples.

        Returns:
            beams: decoder beams with copy tokens replaced by actual words
            traces (list[EditTrace]): one per example
        """
        # should only run in evaluation mode
        assert not self.training

        input_words, output_words = self._batch_editor_examples(examples)
        base_vocab = self.base_vocab
        dynamic_vocabs = self._compute_dynamic_vocabs(input_words, base_vocab)
        dynamic_token_embedder = DynamicMultiVocabTokenEmbedder(
            self.base_source_token_embedder, dynamic_vocabs, base_vocab)

        encoder_input = self.encoder.preprocess(input_words,
                                                output_words,
                                                dynamic_token_embedder,
                                                volatile=True)
        encoder_output, _ = self.encoder(encoder_input)

        extension_probs_modifiers = []

        if constrain_vocab:
            # will contain duplicates, that's ok
            whitelists = [flatten(ex.input_words) for ex in examples]
            vocab_constrainer = LexicalWhitelister(whitelists, self.base_vocab,
                                                   word_to_forms)
            extension_probs_modifiers.append(vocab_constrainer)

        beams, decoder_traces = self.test_decoder_beam.decode(
            examples,
            encoder_output,
            beam_size=beam_size,
            max_seq_length=max_seq_length,
            extension_probs_modifiers=extension_probs_modifiers)

        # replace copy tokens in predictions with actual words, modifying beams in-place
        # (builtin zip works on both Python 2 and Python 3)
        for beam, dyna_vocab in zip(beams, dynamic_vocabs):
            copy_to_word = dyna_vocab.copy_token_to_word
            for i, seq in enumerate(beam):
                beam[i] = [copy_to_word.get(w, w) for w in seq]

        return beams, [
            EditTrace(ex, d_trace.beam_traces[-1], dyna_vocab)
            for ex, d_trace, dyna_vocab in zip(examples, decoder_traces,
                                               dynamic_vocabs)
        ]

    def interact(self, beam_size=8, constrain_vocab=False, verbose=True):
        """Read one example from the prompt, edit it and print the top result.

        Args:
            beam_size (int): for beam decoding
            constrain_vocab (bool): forwarded to edit
            verbose (bool): also print the edit trace when True
        """
        ex = EditExample.from_prompt()
        beam_list, edit_traces = self.edit([ex],
                                           beam_size=beam_size,
                                           constrain_vocab=constrain_vocab)
        beam = beam_list[0]
        output_words = beam[0]
        edit_trace = edit_traces[0]

        # TODO: make this fully generative in the right way.. current NLL is wrong, disabled for now.
        # (the gold-vs-predicted NLL comparison that used to be printed here is disabled)

        print('output:')
        print(' '.join(output_words))

        if verbose:
            print(edit_trace)

    def get_vectors(self, tset):
        """
        :param tset: list of training examples
        :return: vec_list (joint encoding) and vec_list_in (context encoding),
            each vector divided by its own L2 norm
        """
        vec_list = []
        vec_list_in = []
        for titem in chunks(tset, 128):
            edit_proc = self.preprocess(titem, volatile=True)
            agenda_out = self.encoder.target_out(edit_proc.encoder_input)
            agenda_in, _ = self.encoder.ctx_code_out(edit_proc.encoder_input)
            amat = agenda_out.data.cpu().numpy()
            amat_in = agenda_in.data.cpu().numpy()
            for i in range(amat.shape[0]):
                # joint encoding = target encoding + context encoding
                avec = amat[i] + amat_in[i]
                vec_list.append(avec / np.linalg.norm(avec))
                vec_list_in.append(amat_in[i] / np.linalg.norm(amat_in[i]))
        return vec_list, vec_list_in
Ejemplo n.º 6
0
class Editor(Module):
    """Editor.

    Attributes:
        encoder (Encoder)
        train_decoder (TrainDecoder)
        test_decoder_beam (BeamDecoder)
    """

    def __init__(self, token_embedder, hidden_dim, agenda_dim, edit_dim, lamb_reg, norm_eps, norm_max, kill_edit, decoder_cell, encoder_layers):
        """Construct Editor.

        Args:
            token_embedder (TokenEmbedder)
            hidden_dim (int)
            agenda_dim (int)
            edit_dim (int)
            lamb_reg: forwarded unchanged to Encoder
            norm_eps: forwarded unchanged to Encoder
            norm_max: forwarded unchanged to Encoder
            kill_edit: forwarded unchanged to Encoder
            decoder_cell (DecoderCell)
            encoder_layers (int)
        """
        super(Editor, self).__init__()
        self.encoder = Encoder(token_embedder, agenda_dim, edit_dim, hidden_dim, lamb_reg, norm_eps, norm_max, kill_edit, encoder_layers,
                               rnn_cell_factory=LSTMCell)
        # A single combiner instance is shared by both decoders.
        context_combiner = AttentionContextCombiner()
        self.train_decoder = TrainDecoder(decoder_cell, token_embedder, context_combiner)
        self.test_decoder_beam = BeamDecoder(decoder_cell, token_embedder, context_combiner)

    @classmethod
    def _batch_editor_examples(cls, examples):
        """Turn a list of examples into per-attribute batches.

        Args:
            examples (list[EditExample])

        Returns:
            tuple: source_words, insert_words, insert_exact_words,
            delete_words, delete_exact_words, target_words (each a list with
            one entry per example) and edit_embed (None when the examples
            carry no edit embeds, else the embeds stacked along axis 0).
        """
        batch = lambda attr: [getattr(ex, attr) for ex in examples]
        source_words = batch('source_words')
        insert_words = batch('insert_words')
        insert_exact_words = batch('insert_exact_words')
        delete_words = batch('delete_words')
        delete_exact_words = batch('delete_exact_words')
        target_words = batch('target_words')

        edit_embed_list = batch('edit_embed')

        # either they all have edit embeds, or they all don't.
        # An empty batch is treated as "none have one" instead of failing
        # with an IndexError on edit_embed_list[0].
        if not edit_embed_list or edit_embed_list[0] is None:
            assert all(e is None for e in edit_embed_list)
            edit_embed = None
        else:
            assert all(e is not None for e in edit_embed_list)
            edit_embed = np.stack(edit_embed_list, axis=0)

        return source_words, insert_words, insert_exact_words, delete_words, delete_exact_words, target_words, edit_embed

    def preprocess(self, examples):
        """Preprocess a batch of EditExamples, converting them into arrays.

        Args:
            examples (list[EditExample])

        Returns:
            EditorInput
        """
        source_words, insert_words, insert_exact_words, delete_words, delete_exact_words, target_words, edit_embed = self._batch_editor_examples(
            examples)
        encoder_input = self.encoder.preprocess(source_words, insert_words, insert_exact_words, delete_words,
                                                delete_exact_words, edit_embed)
        train_decoder_input = TrainDecoderInput(target_words, self.train_decoder.word_vocab)
        return EditorInput(encoder_input, train_decoder_input)

    def forward(self, editor_input, draw_samples, draw_p=False):
        """Return the training loss.

        Args:
            editor_input (EditorInput)
            draw_samples (bool) : flag for whether to add noise for variational approx.
            draw_p (bool): forwarded unchanged to the encoder

        Returns:
            loss (Variable): scalar
        """
        encoder_output = self.encoder(editor_input.encoder_input, draw_samples, draw_p)
        return self.train_decoder.loss(encoder_output, editor_input.train_decoder_input)

    def loss(self, examples, draw_samples=False, draw_p=False):
        """Compute loss Variable.

        Args:
            examples (list[EditExample])
            draw_samples (bool) : flag for whether to add noise for variational approx. disable at test time.
            draw_p (bool): forwarded unchanged to the encoder

        Returns:
            loss (Variable): of shape 1
        """
        editor_input = self.preprocess(examples)
        total_loss = self(editor_input, draw_samples, draw_p)
        if draw_samples:
            # the encoder regularizer is only added when sampling is on
            total_loss += self.encoder.regularizer(editor_input.encoder_input)
        return total_loss

    def per_instance_losses(self, examples, draw_samples=False, batch_size=128):
        """Compute per-instance losses.

        Examples are processed in chunks of ``batch_size``.

        Returns:
            list: one loss value per example, in input order
        """
        per_instance_loss_list = []
        for batch in chunks(examples, batch_size):
            editor_input = self.preprocess(batch)
            encoder_output = self.encoder(editor_input.encoder_input, draw_samples)
            ilosses = self.train_decoder.per_instance_losses(encoder_output, editor_input.train_decoder_input)
            per_instance_loss_list.extend(loss.data.cpu().numpy()[0] for loss in ilosses)
        return per_instance_loss_list

    def test_batch(self, examples):
        """simple batching test

        Compares the summed loss of the first two examples computed one at
        a time against twice their loss computed as one batch. A mismatch
        is only reported on stdout (it does not raise).

        Raises:
            Exception: if fewer than two examples are given
        """
        if len(examples) < 2:
            raise Exception('test_batch called with example list of length < 2')

        lindivid = self.loss([examples[0]]) + self.loss([examples[1]])
        ltogether = self.loss(examples[0:2]) * 2.0
        if abs(lindivid.data.cpu().numpy() - ltogether.data.cpu().numpy()) > 1e-5:
            # Single-argument print(...) behaves the same on Python 2 and 3.
            print(examples[0:2])
            print('individually:')
            print(lindivid)
            print('batched:')
            print(ltogether)
            print('possible batching issue')
            # Do not fall through to the success message after a mismatch.
            return
        print('Passed batching test')

    def edit(self, examples, max_seq_length=35, beam_size=10, batch_size=256):
        """Performs edits on a batch of source sentences.

        Args:
            examples (list[EditExample])
            max_seq_length (int): max # timesteps to generate for
            beam_size (int): for beam decoding
            batch_size (int): max number of examples to pass into the RNN decoder at a time.
                The total # examples decoded in parallel = batch_size / beam_size
                (rounded down, but never less than 1).

        Returns:
            beam_list (list[list[list[unicode]]]): a batch of beams.
            edit_traces (list[EditTrace])
        """
        # Floor division keeps the chunk size an int on Python 3 too; the
        # lower bound avoids a zero chunk size when beam_size > batch_size.
        examples_per_chunk = max(1, batch_size // beam_size)

        beam_list = []
        edit_traces = []
        for batch in chunks(examples, examples_per_chunk):
            beams, traces = self._edit_batch(batch, max_seq_length, beam_size)
            beam_list.extend(beams)
            edit_traces.extend(traces)
        return beam_list, edit_traces

    def _edit_batch(self, examples, max_seq_length, beam_size):
        """Beam-decode one chunk of examples.

        Returns:
            beams: as returned by the beam decoder
            traces (list[EditTrace]): one per example
        """
        source_words, insert_words, insert_exact_words, delete_words, delete_exact_words, _, edit_embed = self._batch_editor_examples(
            examples)
        encoder_input = self.encoder.preprocess(source_words, insert_words, insert_exact_words, delete_words,
                                                delete_exact_words, edit_embed)
        encoder_output = self.encoder(encoder_input)

        beams, decoder_traces = self.test_decoder_beam.decode(
            examples, encoder_output, weighted_value_estimators=[],
            beam_size=beam_size, prefix_hints=[[]], sibling_penalty=0,
            max_seq_length=max_seq_length)

        # builtin zip works on both Python 2 and Python 3
        return beams, [EditTrace(ex, d_trace.beam_traces[-1]) for ex, d_trace in zip(examples, decoder_traces)]

    def interact(self, beam_size=8, verbose=True):
        """Read one example from the prompt, edit it and print the top result.

        Args:
            beam_size (int): for beam decoding
            verbose (bool): also print the edit trace when True
        """
        ex = EditExample.from_prompt()
        beam_list, edit_traces = self.edit([ex], beam_size=beam_size)
        # edit() returns one beam (a list of word sequences) per example;
        # the printed output is the first sequence of the first beam.
        beam = beam_list[0]
        output_words = beam[0]
        edit_trace = edit_traces[0]

        # TODO: make this fully generative in the right way.. current NLL is wrong, disabled for now.
        # (the gold-vs-predicted NLL comparison that used to be printed here is disabled)

        print('output:')
        print(' '.join(output_words))

        if verbose:
            print(edit_trace)
Ejemplo n.º 7
0
class Editor(Module):
    """Editor.

    Attributes:
        encoder (Encoder)
        train_decoder (TrainDecoder)
    """

    def __init__(self, token_embedder, hidden_dim, agenda_dim, edit_dim,
                 lamb_reg, norm_eps, norm_max, kill_edit, decoder_cell,
                 encoder_layers, num_iter=None, eps=None, momentum=None):
        """Build the Editor's encoder, decoders and meta optimizer.

        Args:
            token_embedder (TokenEmbedder): given to the encoder and to
                both decoders
            hidden_dim (int)
            agenda_dim (int)
            edit_dim (int)
            lamb_reg: forwarded unchanged to Encoder
            norm_eps: forwarded unchanged to Encoder
            norm_max: forwarded unchanged to Encoder
            kill_edit: forwarded unchanged to Encoder
            decoder_cell (DecoderCell): given to both decoders
            encoder_layers (int)
            num_iter: passed to OptimN2N as ``iters``
            eps: passed to OptimN2N as ``eps``
            momentum: passed to OptimN2N as ``momentum``
        """
        super(Editor, self).__init__()
        self.encoder = Encoder(
            token_embedder, agenda_dim, edit_dim, hidden_dim,
            lamb_reg, norm_eps, norm_max, kill_edit, encoder_layers,
            rnn_cell_factory=LSTMCell)

        # A single combiner instance is shared by both decoders.
        combiner = AttentionContextCombiner()
        self.train_decoder = TrainDecoder(
            decoder_cell, token_embedder, combiner)
        self.test_decoder_beam = BeamDecoder(
            decoder_cell, token_embedder, combiner)

        # The meta optimizer is handed the train decoder's parameters.
        decoder_params = list(self.train_decoder.parameters())
        self.meta_optimizer = OptimN2N(
            self.encoder, self.train_decoder, decoder_params,
            iters=num_iter, eps=eps, momentum=momentum)
        # With the default arguments OptimN2N receives None for
        # iters/eps/momentum.
        # TODO: expose further meta optimizer hyperparameters for tuning.

    @classmethod
    def _batch_editor_examples(cls, examples):
        batch = lambda attr: [getattr(ex, attr) for ex in examples]
        source_words = batch('source_words')
        insert_words = batch('insert_words')
        insert_exact_words = batch('insert_exact_words')
        delete_words = batch('delete_words')
        delete_exact_words = batch('delete_exact_words')
        target_words = batch('target_words')

        edit_embed_list = [ex.edit_embed for ex in examples]

        # either they all have edit embeds, or they all don't.
        if edit_embed_list[0] is None:
            assert all(e is None for e in edit_embed_list)
            edit_embed = None
        else:
            assert all(e is not None for e in edit_embed_list)
            edit_embed = np.stack(edit_embed_list, axis=0)

        return source_words, insert_words, insert_exact_words, delete_words, delete_exact_words, target_words, edit_embed

    def preprocess(self, examples):
        """Convert a batch of EditExamples into the model's input structure.

        Args:
            examples (list[EditExample])

        Returns:
            EditorInput: the encoder input paired with the train decoder input
        """
        (source_words, insert_words, insert_exact_words, delete_words,
         delete_exact_words, target_words,
         edit_embed) = self._batch_editor_examples(examples)

        # The target words only feed the decoder input; everything else
        # goes to the encoder.
        return EditorInput(
            self.encoder.preprocess(
                source_words, insert_words, insert_exact_words,
                delete_words, delete_exact_words, edit_embed),
            TrainDecoderInput(target_words, self.train_decoder.word_vocab))

    def forward(self, editor_input, draw_samples, draw_p=False):
        """Encode, refine the agenda with the meta optimizer, and backpropagate.

        Unlike a plain forward pass, this method calls ``backward`` itself:
        once on the decoder loss and once on the concatenated encoder agenda.
        It also stores intermediate results on ``self`` (``encoder_output``,
        ``mean_svi_final``, ``logvar_svi_final``, ``z_samples``).

        Args:
            editor_input (EditorInput)
            draw_samples (bool) : flag for whether to add noise for variational approx. 
            draw_p (bool): forwarded unchanged to the encoder

        Returns:
            var_loss (Variable): decoder loss computed from the refined agenda
            var_params (Variable): encoder mean and logvar concatenated along dim 1
            var_param_grads: the gradients returned by meta_optimizer.backward,
                concatenated along dim 1
        """
        self.encoder_output = self.encoder(editor_input.encoder_input, draw_samples, draw_p)
        #total_loss = self.train_decoder.loss(encoder_output, editor_input.train_decoder_input) <-- original.

        # The encoder's agenda is unpacked as a (mean, logvar) pair here.
        mean, logvar = self.encoder_output.agenda
        var_params = torch.cat([mean, logvar], 1) 
        # Fresh Variables built from the raw data, so they are separate
        # from the encoder's graph and collect their own gradients.
        mean_svi = Variable(self.encoder_output.agenda[0].data, requires_grad=True)
        logvar_svi = Variable(self.encoder_output.agenda[1].data, requires_grad=True)

        # Let the meta optimizer produce the refined (mean, logvar) pair.
        var_params_svi = self.meta_optimizer.forward([mean_svi, logvar_svi], self.encoder_output, editor_input.train_decoder_input)
        # encoder_output.source_embeds or  editor_input.train_decoder_input?
        # verbose False above

        self.mean_svi_final, self.logvar_svi_final = var_params_svi
        self.z_samples = self.encoder._reparameterize(self.mean_svi_final, self.logvar_svi_final)
        #self.encoder_output.agenda = self.z_samples
        # Rebuild the encoder output with the sampled z in place of the agenda.
        self.encoder_output = EncoderOutput(self.encoder_output.source_embeds, self.encoder_output.insert_embeds, self.encoder_output.delete_embeds, self.z_samples)
        var_loss = self.train_decoder.loss(self.encoder_output, editor_input.train_decoder_input)
        # First backward pass: fills .grad of the refined mean/logvar used below.
        var_loss.backward(retain_variables=True)
        # the above is nll_svi.
        
        var_param_grads = self.meta_optimizer.backward([self.mean_svi_final.grad, self.logvar_svi_final.grad])
        # verbose set to False above
        var_param_grads = torch.cat(var_param_grads, 1)
        # Second backward pass: push the meta optimizer's gradients into the
        # encoder through the concatenated (mean, logvar).
        var_params.backward(var_param_grads, retain_variables=True)

        #return total_loss
        return var_loss, var_params, var_param_grads

    def loss(self, examples, draw_samples=False, draw_p=False):
        """Preprocess the examples and run them through the model.

        Args:
            examples (list[EditExample])
            draw_samples (bool) : flag for whether to add noise for variational approx. disable at test time.
            draw_p (bool): forwarded unchanged to forward

        Returns:
            tuple: (var_loss, var_params, var_param_grads) exactly as
            produced by forward
        """
        var_loss, var_params, var_param_grads = self(
            self.preprocess(examples), draw_samples, draw_p)
        # NOTE: the encoder regularizer term is not added here; that code
        # path is disabled in this class.
        return var_loss, var_params, var_param_grads

    def per_instance_losses(self, examples, draw_samples=False, batch_size=128):
        """Return one loss value per example, computed chunk by chunk.

        Args:
            examples (list[EditExample])
            draw_samples (bool): forwarded to the encoder
            batch_size (int): number of examples per chunk

        Returns:
            list: per-example loss values, in input order
        """
        collected = []
        for chunk in chunks(examples, batch_size):
            prepared = self.preprocess(chunk)
            encoded = self.encoder(prepared.encoder_input, draw_samples)
            chunk_losses = self.train_decoder.per_instance_losses(
                encoded, prepared.train_decoder_input)
            # Each loss is moved to the CPU and reduced to its first element.
            collected.extend(
                [item.data.cpu().numpy()[0] for item in chunk_losses])
        return collected

    def test_batch(self, examples):
        """Check that the first two examples give the same result alone and batched.

        For each call to self.loss the quantity |var_loss| + |mean(var_params)| is
        computed. The sum of that quantity over the two single-example calls is
        compared with twice the quantity from the two-example call; a difference
        above 1e-5 is treated as a batching error.

        Args:
            examples: sequence of examples; only the first two are used.

        Raises:
            Exception: if fewer than two examples are given, or if the individual
                and batched results differ by more than 1e-5.
        """
        if not len(examples) > 1:
            raise Exception('test_batch called with example list of length < 2')

        def magnitude(batch):
            # |loss| + |mean of the variational params| for one self.loss call.
            batch_loss, batch_params, _ = self.loss(batch)
            return torch.abs(batch_loss) + torch.abs(torch.mean(batch_params))

        separate = magnitude([examples[0]])
        separate += magnitude([examples[1]])
        together = magnitude(examples[0:2]) * 2.0

        gap = abs(separate.data.cpu().numpy() - together.data.cpu().numpy())
        if gap > 1e-5:
            # Single-argument print(...) is equivalent to the print statement.
            print(examples[0:2])
            print('individually:')
            print(separate)
            print('batched:')
            print(together)
            raise Exception('batching error - examples do not produce identical results under batching')

        print('Passed batching test')

    def edit(self, examples, max_seq_length=35, beam_size=5, batch_size=256):
        """Performs edits on a batch of source sentences.

        Args:
            examples (list[EditExample])
            max_seq_length (int): max # timesteps to generate for
            beam_size (int): for beam decoding
            batch_size (int): max number of examples to pass into the RNN decoder at a time.
                The total # examples decoded in parallel = batch_size // beam_size,
                but never fewer than 1.

        Returns:
            beam_list (list[list[list[unicode]]]): a batch of beams.
            edit_traces (list[EditTrace])
        """
        # Explicit floor division keeps the chunk size an integer regardless of the
        # division semantics in effect. Clamp to 1 so that beam_size > batch_size
        # decodes one example at a time instead of asking for chunks of size 0.
        examples_per_chunk = max(1, batch_size // beam_size)

        beam_list = []
        edit_traces = []
        for batch in chunks(examples, examples_per_chunk):
            beams, traces = self._edit_batch(batch, max_seq_length, beam_size)
            beam_list.extend(beams)
            edit_traces.extend(traces)
        return beam_list, edit_traces

    def _edit_batch(self, examples, max_seq_length, beam_size):
        """Decode one batch of examples with the beam decoder after refining the agenda.

        The encoder output's agenda is unpacked as a (mean, logvar) pair, handed to
        self.meta_optimizer.forward, and the refined pair is turned into samples by
        self.encoder._reparameterize. Those samples are what the beam decoder sees.

        Side effects: overwrites self.encoder_output, self.mean_svi_final,
        self.logvar_svi_final and self.z_samples.

        Args:
            examples: batch of examples; passed to self.preprocess and to the decoder.
            max_seq_length: forwarded to the beam decoder.
            beam_size: forwarded to the beam decoder.

        Returns:
            tuple: (beams, edit_traces) where beams is the first value returned by
                self.test_decoder_beam.decode and edit_traces holds one EditTrace per
                (example, decoder trace) pair, built from the last entry of that
                trace's beam_traces.
        """
        #source_words, insert_words, insert_exact_words, delete_words, delete_exact_words, target_words, edit_embed = self._batch_editor_examples(examples)
        #encoder_input = self.encoder.preprocess(source_words, insert_words, insert_exact_words, delete_words, delete_exact_words, edit_embed)

        #############
        # New inference during eval,
        editor_input = self.preprocess(examples)
        # Stored on self (not a local) -- the pre-optimization encoder output.
        self.encoder_output = self.encoder(editor_input.encoder_input)
        mean, logvar = self.encoder_output.agenda
        # NOTE(review): var_params is not read again in this method -- confirm whether it is still needed.
        var_params = torch.cat([mean, logvar], 1) 
        # Fresh Variables wrapping the agenda's underlying data, with gradients enabled.
        mean_svi = Variable(self.encoder_output.agenda[0].data, requires_grad=True)
        logvar_svi = Variable(self.encoder_output.agenda[1].data, requires_grad=True)
        # NOTE(review): the meta optimizer receives train_decoder_input at decode time -- confirm this is intended.
        var_params_svi = self.meta_optimizer.forward([mean_svi, logvar_svi], self.encoder_output, editor_input.train_decoder_input)

        self.mean_svi_final, self.logvar_svi_final = var_params_svi
        self.z_samples = self.encoder._reparameterize(self.mean_svi_final, self.logvar_svi_final)
        #self.encoder_output.agenda = self.z_samples
        # Build a new EncoderOutput instead of mutating the stored one; z_samples goes
        # in the fourth position (presumably the agenda slot -- verify against EncoderOutput).
        encoder_output = EncoderOutput(self.encoder_output.source_embeds, self.encoder_output.insert_embeds, self.encoder_output.delete_embeds, self.z_samples)
        #############
        #encoder_output = self.encoder(encoder_input)
        # NOTE(review): prefix_hints holds a single empty hint whatever the batch size --
        # confirm the decoder accepts that for multi-example batches.
        beams, decoder_traces = self.test_decoder_beam.decode(examples, 
            encoder_output, 
            weighted_value_estimators=[], 
            beam_size=beam_size, 
            prefix_hints = [[]], 
            sibling_penalty=0, 
            max_seq_length=max_seq_length)

        return beams, [EditTrace(ex, d_trace.beam_traces[-1]) for ex, d_trace in izip(examples, decoder_traces)]

    def interact(self, beam_size=8, verbose=True):
        """Prompt for a single example, edit it, and print the resulting sentence.

        Args:
            beam_size (int): beam width forwarded to self.edit.
            verbose (bool): when true, the edit trace is printed as well.
        """
        ex = EditExample.from_prompt()
        beam_list, edit_traces = self.edit([ex], beam_size=beam_size)

        # self.edit returns one beam per example, and a beam is a list of candidate
        # token sequences. Joining the beam itself would try to join lists, so take
        # the first candidate (presumably the best-scoring one -- confirm with the
        # decoder). An empty beam prints an empty output line instead of crashing.
        beam = beam_list[0]
        output_words = beam[0] if beam else []
        edit_trace = edit_traces[0]

        # TODO: make this fully generative in the right way.. current NLL is wrong, disabled for now.
        # The comparison of the gold NLL against the predicted output's NLL used to
        # live here and should be restored once the NLL is fixed.

        # Single-argument print(...) is equivalent to the print statement.
        print('output:')
        print(' '.join(output_words))

        if verbose:
            print(edit_trace)
Ejemplo n.º 8
0
class LanguageModel(Module):
    def __init__(self, token_embedder, hidden_dim, agenda_dim, num_layers,
                 logger):
        """Build a language model whose three decoders share one decoder cell.

        Args:
            token_embedder: supplies `embed_dim` (used as the cell's input size) and
                `vocab`; also passed to the cell and to every decoder.
            hidden_dim: forwarded to MultilayeredDecoderCell.
            agenda_dim: forwarded to the cell, and the length of the zero-initialized
                `agenda` parameter.
            num_layers: forwarded to MultilayeredDecoderCell.
            logger: stored as self.logger.
        """
        super(LanguageModel, self).__init__()

        cell = MultilayeredDecoderCell(
            token_embedder, hidden_dim, token_embedder.embed_dim, agenda_dim,
            num_layers)
        combiner = MultilayeredLMContextCombiner()

        # The same cell, embedder and combiner instances back all three decoders.
        decoder_args = (cell, token_embedder, combiner)
        self.train_decoder = TrainDecoder(*decoder_args)
        self.sample_decoder = SampleDecoder(*decoder_args)
        self.beam_decoder = BeamDecoder(*decoder_args)

        self.agenda_dim = agenda_dim
        self.agenda = Parameter(torch.zeros(agenda_dim))
        self.vocab = token_embedder.vocab
        self.logger = logger

    def _encoder_output(self, batch_size):
        """Return the result of `tile_state` applied to the agenda parameter.

        Args:
            batch_size: passed to `tile_state` as its second argument.
        """
        tiled_agenda = tile_state(self.agenda, batch_size)
        return tiled_agenda

    def per_instance_losses(self, examples):
        """Return the train decoder's per-instance losses for `examples`.

        Args:
            examples: batch of examples; wrapped in a TrainDecoderInput together
                with this model's vocab.
        """
        num_examples = len(examples)
        targets = TrainDecoderInput(examples, self.vocab)
        tiled_agenda = self._encoder_output(num_examples)
        return self.train_decoder.per_instance_losses(tiled_agenda, targets)

    def loss(self, examples, train_step):
        """Compute the training loss for a batch.

        Args:
            examples (list[list[unicode]])
            train_step: accepted but not used by this method.

        Returns:
            Variable: a scalar
        """
        num_examples = len(examples)
        targets = TrainDecoderInput(examples, self.vocab)
        tiled_agenda = self._encoder_output(num_examples)
        return self.train_decoder.loss(tiled_agenda, targets)

    def generate(self, num_samples, decode_method='argmax'):
        """Generate `num_samples` sequences from the language model.

        Args:
            num_samples (int): how many sequences to produce.
            decode_method (str): 'sample' decodes with self.sample_decoder;
                'argmax' decodes with self.beam_decoder using a beam of size 1
                and no sibling penalty.

        Returns:
            list: the first entry of each output beam, one per sample.

        Raises:
            ValueError: if `decode_method` is neither 'sample' nor 'argmax'.
        """
        # The decoders are given integer indices in place of real examples.
        examples = list(range(num_samples))
        # One independent empty hint per sample. `[[]] * n` would repeat a single
        # list object, so a mutation of one hint would show up in all of them.
        prefix_hints = [[] for _ in range(num_samples)]
        encoder_output = self._encoder_output(num_samples)

        if decode_method == 'sample':
            output_beams, _ = self.sample_decoder.decode(
                examples,
                encoder_output,
                beam_size=1,
                prefix_hints=prefix_hints)
        elif decode_method == 'argmax':
            # A beam of width 1 with no value estimators and no sibling penalty.
            output_beams, _ = self.beam_decoder.decode(
                examples,
                encoder_output,
                weighted_value_estimators=[],
                beam_size=1,
                prefix_hints=prefix_hints,
                sibling_penalty=0.)
        else:
            raise ValueError(decode_method)

        return [beam[0] for beam in output_beams]