Example #1
    def train_batch(self, samples, use_cuda, tf_ratio=0.5, backprop=True, coverage_lambda=-1):
        start = time.time()
        if len(samples) == 0: return 0, 0

        target_length = min(self.config['target_length'], max([len(pair.masked_target_tokens) for pair in samples]))
        nb_unks = max([len(s.unknown_tokens) for s in samples])
        input_variable, full_input_variable, target_variable, full_target_variable, decoder_input = \
            utils.get_batch_variables(samples, self.config['input_length'], target_length, use_cuda,
                                      self.vocab.word2index['SOS'])

        encoder_hidden = self.encoder.init_hidden(len(samples), use_cuda)
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        loss = 0

        encoder_outputs, encoder_hidden = self.encoder(input_variable, encoder_hidden)
        decoder_hidden = torch.cat((encoder_hidden[0], encoder_hidden[1]), -1)
        decoder_h_states = torch.cat((encoder_hidden[0], encoder_hidden[1]), -1).unsqueeze(1)
        previous_att = None

        for token_i in range(target_length):
            p_final, p_gen, p_vocab, att_dist, decoder_h_states, decoder_hidden, previous_att = \
                self.decoder(decoder_input, decoder_h_states, decoder_hidden, encoder_outputs,
                             full_input_variable, previous_att, nb_unks, use_cuda)

            if coverage_lambda < 0 or token_i == 0:
                loss += self.criterion(torch.log(p_final.clamp(min=1e-8)),
                                       full_target_variable.narrow(1, token_i, 1).squeeze(-1))
            else:
                # Coverage penalty: sum over source positions of min(attention, coverage),
                # where coverage is the running sum of all previous attention distributions.
                coverage = previous_att.narrow(1, 0, previous_att.size()[1] - 1).sum(dim=1)
                coverage_min, _ = torch.cat((att_dist.unsqueeze(1), coverage.unsqueeze(1)), dim=1).min(dim=1)
                # Reduce the penalty to a scalar (mean over the batch) so it can be
                # added to the NLL term and the total loss stays a scalar.
                coverage_loss = coverage_min.sum(-1).mean()
                loss += self.criterion(torch.log(p_final.clamp(min=1e-8)),
                                       full_target_variable.narrow(1, token_i, 1).squeeze(-1)) \
                        + coverage_lambda * coverage_loss

            # Teacher forcing: with probability tf_ratio feed the ground-truth token,
            # otherwise feed the model's own prediction (OOV indexes masked to UNK).
            if random.uniform(0, 1) < tf_ratio:
                decoder_input = target_variable.narrow(1, token_i, 1)
            else:
                _, max_tokens = p_final.max(1)
                for i in range(max_tokens.size()[0]):
                    if max_tokens.data[i] >= self.vocab.vocab_size:
                        max_tokens.data[i] = self.vocab.word2index['UNK']
                decoder_input = max_tokens.unsqueeze(1)
        if backprop:
            loss.backward()
            torch.nn.utils.clip_grad_norm(self.encoder.parameters(), 2)
            torch.nn.utils.clip_grad_norm(self.decoder.parameters(), 2)
            self.encoder_optimizer.step()
            self.decoder_optimizer.step()

        '''
        print(" ", [[t for t in pair.full_target_tokens if t not in pair.full_source_tokens and t >= self.vocab.vocab_size]
                    for pair in samples])
        '''
        return loss.data[0] / target_length, time.time() - start
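A minimal sketch of how this training step might be driven from an epoch loop; the `model` and `batches` names below are assumptions for illustration, not part of the original snippet:

# Hypothetical driver loop for the train_batch method above; `model` is assumed
# to be the object exposing train_batch, and `batches` an iterable of lists of
# sample pairs.
def run_epoch(model, batches, use_cuda, tf_ratio=0.5, coverage_lambda=1.0):
    total_loss, total_time = 0.0, 0.0
    for samples in batches:
        loss, elapsed = model.train_batch(samples, use_cuda, tf_ratio=tf_ratio,
                                          coverage_lambda=coverage_lambda)
        total_loss += loss
        total_time += elapsed
    return total_loss / max(len(batches), 1), total_time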
Example #2
    def train_batch(self, samples, use_cuda, backprop=True):
        start = time.time()
        input_variable, full_input_variable, target_variable, full_target_variable, decoder_input = \
                        utils.get_batch_variables(samples, 400, 50, use_cuda, self.vocab.word2index['SOS'])

        novelty_score = Variable(torch.FloatTensor(self.compute_novelty_n_grams(samples, 3)))
        if use_cuda: novelty_score = novelty_score.cuda()

        loss = self.criterion(self.model(full_input_variable, full_target_variable), novelty_score)
        if backprop:
            loss.backward()
            self.optimizer.step()
        return loss.data[0], time.time() - start
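Unlike the first example, this snippet never calls `self.optimizer.zero_grad()`, so gradients accumulate across calls unless they are cleared elsewhere. It also relies on a `compute_novelty_n_grams` helper that is not shown; a plausible reading, sketched below purely as an assumption (written as a free function, with the `full_target_tokens`/`full_source_tokens` attribute names borrowed from the other examples), is the fraction of target n-grams that never occur in the source:

# Hypothetical sketch of the unseen compute_novelty_n_grams helper: one score per
# sample, the fraction of target n-grams that do not appear in the source text.
def compute_novelty_n_grams(samples, n):
    def ngrams(tokens):
        return set(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    scores = []
    for pair in samples:
        target = ngrams(pair.full_target_tokens)   # attribute names assumed
        source = ngrams(pair.full_source_tokens)
        novel = [g for g in target if g not in source]
        scores.append(len(novel) / max(len(target), 1))
    return scores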

    def predict(self, samples, target_length, beam_size, use_cuda):
        # this only works with one sample at a time
        nb_unks = max([len(s.unknown_tokens) for s in samples])
        input_variable, full_input_variable, target_variable, full_target_variable, decoder_input = \
            utils.get_batch_variables(samples, self.config['input_length'], target_length, use_cuda,
                                      self.vocab.word2index['SOS'])
        encoder_hidden = self.encoder.init_hidden(len(samples), use_cuda)
        encoder_outputs, encoder_hidden = self.encoder(input_variable,
                                                       encoder_hidden)
        decoder_hidden = torch.cat((encoder_hidden[0], encoder_hidden[1]), -1)
        decoder_h_states = torch.cat((encoder_hidden[0], encoder_hidden[1]),
                                     -1).unsqueeze(1)
        previous_att = None

        if not beam_size:
            result = []
            for token_i in range(target_length):

                p_final, p_gen, p_vocab, att_dist, decoder_h_states, decoder_hidden, previous_att = \
                    self.decoder(decoder_input, decoder_h_states, decoder_hidden,
                                 encoder_outputs, full_input_variable, previous_att, nb_unks, use_cuda)

                p_vocab_word, vocab_word_idx = p_final.max(1)
                result.append([{
                    'token_idx': vocab_word_idx.data[i],
                    'word': utils.translate_word(vocab_word_idx.data[i], samples[i], self.vocab),
                    'p_gen': round(p_gen.data[i][0], 3)
                } for i in range(len(samples))])
                _, max_tokens = p_final.max(1)
                for i in range(max_tokens.size()[0]):
                    if max_tokens.data[i] >= self.vocab.vocab_size:
                        max_tokens.data[i] = self.vocab.word2index['UNK']
                decoder_input = max_tokens.unsqueeze(1)

            return result
        else:
            search_complete = False
            top_beams = [
                Beam(decoder_input, decoder_h_states, decoder_hidden,
                     previous_att, [], [])
            ]

            while not search_complete:
                new_beams = []
                for beam in top_beams:
                    if beam.complete: new_beams.append(beam)
                    else:
                        p_final, p_gen, p_vocab, att_dist, decoder_h_states, decoder_hidden, previous_att = \
                            self.decoder(beam.decoder_input, beam.decoder_h_states, beam.decoder_hidden,
                                         encoder_outputs, full_input_variable, beam.previous_att, nb_unks, use_cuda)
                        for k in range(beam_size):
                            # Pick the current best word, then zero out its probability so
                            # the next iteration of k selects the next-best candidate.
                            p_vocab_word, vocab_word_idx = p_final.max(1)
                            _, max_tokens = p_final.max(1)
                            # OOV indexes are masked back to UNK before being fed to the
                            # decoder; the unmasked index is kept in the beam's sequence.
                            if max_tokens.data[0] >= self.vocab.vocab_size:
                                max_tokens.data[0] = self.vocab.word2index['UNK']
                            new_beams.append(
                                Beam(max_tokens.unsqueeze(1), decoder_h_states,
                                     decoder_hidden, previous_att,
                                     beam.log_probs + [p_vocab_word.data[0]],
                                     beam.sequence + [vocab_word_idx.data[0]]))
                            p_final[0, vocab_word_idx.data[0]] = 0

                            # A beam is finished once it reaches the target length or emits EOS.
                            if len(new_beams[-1].sequence) == target_length or \
                                    vocab_word_idx.data[0] == self.vocab.word2index['EOS']:
                                new_beams[-1].complete = True

                all_beams = sorted([(b, b.compute_score()) for b in new_beams],
                                   key=lambda tup: tup[1])
                if len(all_beams) > beam_size:
                    all_beams = all_beams[:beam_size]
                top_beams = [beam[0] for beam in all_beams]

                if len([True for b in top_beams if b.complete]) == beam_size:
                    search_complete = True

            return [[
                " ".join([
                    utils.translate_word(t, samples[0], self.vocab)
                    for t in b.sequence
                ]),
                b.compute_score()
            ] for b in top_beams]
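Both beam-search branches lean on a `Beam` container that is not shown. The sketch below only mirrors the constructor arguments and the `complete`/`compute_score()` usage visible in the snippets; the scoring rule (mean negative log of the stored probabilities, so that the ascending sort keeps the best candidates) is an assumption:

import math

# Hypothetical sketch of the Beam container used above. Attribute names mirror
# the calls in the snippets; compute_score() is an assumed scoring rule.
class Beam:
    def __init__(self, decoder_input, decoder_h_states, decoder_hidden,
                 previous_att, log_probs, sequence):
        self.decoder_input = decoder_input
        self.decoder_h_states = decoder_h_states
        self.decoder_hidden = decoder_hidden
        self.previous_att = previous_att
        self.log_probs = log_probs    # per-step word probabilities so far
        self.sequence = sequence      # generated token indexes so far
        self.complete = False

    def compute_score(self):
        # Lower is better: mean negative log of the stored probabilities.
        if not self.log_probs:
            return 0.0
        return -sum(math.log(max(p, 1e-12)) for p in self.log_probs) / len(self.log_probs)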
Example #4
    def predict(self, samples, target_length, beam_size, use_cuda):
        # this only works with one sample at a time
        nb_unks = max([len(s.unknown_tokens) for s in samples])
        input_variable, full_input_variable, target_variable, full_target_variable, decoder_input = \
            utils.get_batch_variables(samples, self.config['input_length'], target_length, use_cuda,
                                      self.vocab.word2index['SOS'])

        encoder_hidden = self.encoder.init_hidden(len(samples), use_cuda)
        encoder_outputs, encoder_hidden = self.encoder(input_variable,
                                                       encoder_hidden)
        decoder_hidden = torch.cat((encoder_hidden[0], encoder_hidden[1]), -1)
        decoder_h_states = torch.cat((encoder_hidden[0], encoder_hidden[1]),
                                     -1).unsqueeze(1)
        previous_att = None

        if not beam_size:
            result = []
            for token_i in range(target_length):

                p_final, p_gen, p_vocab, att_dist, decoder_h_states, decoder_hidden, previous_att = \
                    self.decoder(decoder_input, decoder_h_states, decoder_hidden,
                                 encoder_outputs, full_input_variable, previous_att, nb_unks, use_cuda)

                p_vocab_word, vocab_word_idx = p_final.max(1)
                result.append([{
                    'token_idx': vocab_word_idx.data[i],
                    'word': utils.translate_word(vocab_word_idx.data[i], samples[i], self.vocab),
                    'p_gen': round(p_gen.data[i][0], 3)
                } for i in range(len(samples))])
                _, max_tokens = p_final.max(1)
                for i in range(max_tokens.size()[0]):
                    if max_tokens.data[i] >= self.vocab.vocab_size:
                        max_tokens.data[i] = self.vocab.word2index['UNK']
                decoder_input = max_tokens.unsqueeze(1)

            return result

        else:
            search_complete = False
            top_beams = [
                Beam(decoder_input, decoder_h_states, decoder_hidden,
                     previous_att, [], [])
            ]

            def predict_for_beams(beams, encoder_outputs, full_input_variable):
                # Decode all active beams in one forward pass: tile the encoder
                # outputs/inputs and stack each beam's decoder state along the batch axis.
                results = []

                encoder_outputs = torch.stack([
                    encoder_outputs[i] for beam in beams
                    for i in range(len(samples))
                ], 0)
                full_input_variable = torch.stack([
                    full_input_variable[i] for beam in beams
                    for i in range(len(samples))
                ], 0)
                decoder_input = torch.stack([
                    beam.decoder_input[i] for beam in beams
                    for i in range(len(samples))
                ], 0)
                decoder_h_states = torch.stack([
                    beam.decoder_h_states[i] for beam in beams
                    for i in range(len(samples))
                ], 0)
                decoder_hidden = torch.stack([
                    beam.decoder_hidden[i] for beam in beams
                    for i in range(len(samples))
                ], 0)

                if beams[0].previous_att is not None:
                    previous_att = torch.stack([
                        beam.previous_att[i] for beam in beams
                        for i in range(len(samples))
                    ], 0)
                else:
                    previous_att = None

                p_final, p_gen, p_vocab, att_dist, decoder_h_states, decoder_hidden, previous_att = \
                    self.decoder(decoder_input, decoder_h_states, decoder_hidden, encoder_outputs, full_input_variable,
                                 previous_att, nb_unks, use_cuda)

                # Slice the batched decoder outputs back into one chunk per beam.
                for b in range(len(beams)):
                    results.append([beams[b]] + [
                        tensor.narrow(0, b * len(samples), len(samples))
                        for tensor in [p_final, decoder_h_states, decoder_hidden, previous_att]
                    ])

                return results

            while not search_complete:
                new_beams = []
                beams_to_predict = []

                for beam in top_beams:
                    if beam.complete:
                        new_beams.append(beam)
                    else:
                        beams_to_predict.append(beam)

                predictions = predict_for_beams(beams_to_predict,
                                                encoder_outputs,
                                                full_input_variable)
                for b in predictions:
                    beam = b[0]
                    p_final, decoder_h_states, decoder_hidden, previous_att = b[1], b[2], b[3], b[4]

                    # Expand this beam with its beam_size most probable next words.
                    p_top_words, top_indexes = p_final.topk(beam_size)

                    for k in range(beam_size):
                        # Keep the (possibly OOV) index for the output sequence, but mask
                        # it back to UNK before it is fed to the decoder.
                        non_masked_word = top_indexes.data[0][k]
                        if top_indexes.data[0][k] >= self.vocab.vocab_size:
                            top_indexes.data[0][k] = self.vocab.word2index['UNK']

                        new_beams.append(
                            Beam(top_indexes.narrow(1, k, 1), decoder_h_states,
                                 decoder_hidden, previous_att,
                                 beam.log_probs + [p_top_words.data[0][k]],
                                 beam.sequence + [non_masked_word]))

                        # A beam is finished once it reaches the target length or emits EOS.
                        if len(new_beams[-1].sequence) == target_length or \
                                top_indexes.data[0][k] == self.vocab.word2index['EOS']:
                            new_beams[-1].complete = True

                all_beams = sorted([(b, b.compute_score()) for b in new_beams],
                                   key=lambda tup: tup[1])
                if len(all_beams) > beam_size:
                    all_beams = all_beams[:beam_size]
                top_beams = [beam[0] for beam in all_beams]

                if len([True for b in top_beams if b.complete]) == beam_size:
                    search_complete = True

            return [[
                " ".join([
                    utils.translate_word(t, samples[0], self.vocab)
                    for t in b.sequence
                ]),
                b.compute_score()
            ] for b in top_beams]
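A short usage sketch for the predict method above: with a falsy `beam_size` it returns, for each decoding step, a list of per-sample dictionaries (greedy decoding); with a positive `beam_size` it returns `[decoded_string, score]` pairs for the surviving beams. The `summarizer` and `sample` names below are assumptions, and the 'lower score is better' reading follows the ascending sort in the snippet:

# Hypothetical call sites; `summarizer` is assumed to be the object exposing
# predict, and `sample` a single preprocessed source/target pair.
greedy_steps = summarizer.predict([sample], target_length=50, beam_size=None,
                                  use_cuda=False)
greedy_words = [step[0]['word'] for step in greedy_steps]

beam_results = summarizer.predict([sample], target_length=50, beam_size=4,
                                  use_cuda=False)
best_summary, best_score = min(beam_results, key=lambda r: r[1])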