Example #1
class NMT(object):
    def __init__(self,
                 embed_size,
                 hidden_size,
                 vocab,
                 dropout_rate=0.2,
                 keep_train=False):
        super(NMT, self).__init__()

        self.nvocab_src = len(vocab.src)
        self.nvocab_tgt = len(vocab.tgt)
        self.vocab = vocab
        self.encoder = Encoder(self.nvocab_src,
                               hidden_size,
                               embed_size,
                               input_dropout=dropout_rate,
                               n_layers=2)
        self.decoder = Decoder(self.nvocab_tgt,
                               2 * hidden_size,
                               embed_size,
                               output_dropout=dropout_rate,
                               n_layers=2,
                               tf_rate=1.0)
        if keep_train:
            self.load('model')
        LAS_params = list(self.encoder.parameters()) + list(
            self.decoder.parameters())
        self.optimizer = optim.Adam(LAS_params, lr=0.001)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                         step_size=1,
                                                         gamma=0.5)
        weight = torch.ones(self.nvocab_tgt)
        self.loss = NLLLoss(weight=weight, mask=0, size_average=False)
        # TODO: Perplexity or NLLLoss
        # TODO: pass in mask to loss function
        #self.loss = Perplexity(weight, 0)

        if torch.cuda.is_available():
            # Move the network and the optimizer to the GPU
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.loss.cuda()

    def __call__(self, src_sents, tgt_sents):
        """
        Take a mini-batch of source and target sentences and compute the log-likelihood
        of the target sentences.

        Args:
            src_sents: list of source sentence tokens
            tgt_sents: list of target sentence tokens, wrapped by `<s>` and `</s>`

        Returns:
            scores: a variable/tensor of shape (batch_size, ) representing the 
                log-likelihood of generating the gold-standard target sentence for 
                each example in the input batch
        """
        src_sents = self.vocab.src.words2indices(src_sents)
        tgt_sents = self.vocab.tgt.words2indices(tgt_sents)
        src_sents, src_len, y_input, y_tgt, tgt_len = sent_padding(
            src_sents, tgt_sents)
        src_encodings, decoder_init_state = self.encode(src_sents, src_len)
        scores, symbols = self.decode(src_encodings,
                                      decoder_init_state, [y_input, y_tgt],
                                      stage="train")

        return scores

    def encode(self, src_sents, input_lengths):
        """
        Use a GRU/LSTM to encode source sentences into hidden states

        Args:
            src_sents: list of source sentence tokens

        Returns:
            src_encodings: hidden states of tokens in source sentences; this could be a variable
                with shape (batch_size, source_sentence_length, encoding_dim), or in other formats
            decoder_init_state: decoder GRU/LSTM's initial state, computed from source encodings
        """
        encoder_outputs, encoder_hidden = self.encoder(src_sents,
                                                       input_lengths)

        return encoder_outputs, encoder_hidden

    def decode(self,
               src_encodings,
               decoder_init_state,
               tgt_sents,
               stage="train"):
        """
        Given source encodings, compute the log-likelihood of predicting the gold-standard target
        sentence tokens

        Args:
            src_encodings: hidden states of tokens in source sentences
            decoder_init_state: decoder GRU/LSTM's initial state
            tgt_sents: list of gold-standard target sentences, wrapped by `<s>` and `</s>`

        Returns:
            scores: could be a variable of shape (batch_size, ) representing the 
                log-likelihood of generating the gold-standard target sentence for 
                each example in the input batch
        """
        tgt_input, tgt_target = tgt_sents
        loss = self.loss
        decoder_outputs, decoder_hidden, symbols = self.decoder(
            tgt_input, decoder_init_state, src_encodings)
        loss.reset()
        for step, step_output in enumerate(decoder_outputs):
            batch_size = tgt_input.size(0)
            loss.eval_batch(step_output.contiguous().view(batch_size, -1),
                            tgt_target[:, step])
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), 5.0)
        torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), 5.0)
        self.optimizer.step()
        scores = loss.get_loss()

        return scores, symbols

    def decode_without_bp(self, src_encodings, decoder_init_state, tgt_sents):
        """
        Given source encodings, compute the log-likelihood of predicting the gold-standard target
        sentence tokens

        Args:
            src_encodings: hidden states of tokens in source sentences
            decoder_init_state: decoder GRU/LSTM's initial state
            tgt_sents: list of gold-standard target sentences, wrapped by `<s>` and `</s>`

        Returns:
            scores: could be a variable of shape (batch_size, ) representing the
                log-likelihood of generating the gold-standard target sentence for
                each example in the input batch
        """
        tgt_input, tgt_target = tgt_sents
        loss = self.loss
        decoder_outputs, decoder_hidden, symbols = self.decoder(
            tgt_input, decoder_init_state, src_encodings, stage="valid")
        loss.reset()
        for step, step_output in enumerate(decoder_outputs):
            batch_size = tgt_input.size(0)
            loss.eval_batch(step_output.contiguous().view(batch_size, -1),
                            tgt_target[:, step])

        scores = loss.get_loss()

        return scores, symbols

    # TODO: sent_padding for only src
    # def beam_search(self, src_sent: List[str], beam_size: int=5, max_decoding_time_step: int=70) -> List[Hypothesis]:
    def beam_search(self, src_sent, beam_size, max_decoding_time_step):
        """
        Given a single source sentence, perform beam search

        Args:
            src_sent: a single tokenized source sentence
            beam_size: beam size
            max_decoding_time_step: maximum number of time steps to unroll the decoding RNN

        Returns:
            hypotheses: a list of hypotheses, each with two fields:
                value: List[str]: the decoded target sentence, represented as a list of words
                score: float: the log-likelihood of the target sentence
        """

        hypotheses = []  # TODO: implement beam search
        return hypotheses

    # def evaluate_ppl(self, dev_data: List[Any], batch_size: int=32):
    def evaluate_ppl(self, dev_data, batch_size):
        """
        Evaluate perplexity on dev sentences

        Args:
            dev_data: a list of dev sentences
            batch_size: batch size
        
        Returns:
            ppl: the perplexity on dev sentences
        """

        ref_corpus = []
        hyp_corpus = []
        cum_loss = 0
        count = 0
        hyp_corpus_ordered = []
        with torch.no_grad():
            for src_sents, tgt_sents, orig_indices in batch_iter(
                    dev_data, batch_size):
                ref_corpus.extend(tgt_sents)
                actual_size = len(src_sents)
                src_sents = self.vocab.src.words2indices(src_sents)
                tgt_sents = self.vocab.tgt.words2indices(tgt_sents)
                src_sents, src_len, y_input, y_tgt, tgt_len = sent_padding(
                    src_sents, tgt_sents)
                src_encodings, decoder_init_state = self.encode(
                    src_sents, src_len)
                scores, symbols = self.decode_without_bp(
                    src_encodings, decoder_init_state, [y_input, y_tgt])
                #sents = np.zeros((len(symbols),actual_size))
                #for i,symbol in enumerate(symbols):
                #    sents[i,:] = symbol.data.cpu().numpy()
                # print(sents.T)

                batch_hyp_ordered = [None] * symbols.size(0)
                for index, sent in enumerate(symbols):
                    word_seq = []
                    for idx in sent:
                        if idx == 2:  # assume id 2 is the </s> token
                            break
                        word_seq.append(self.vocab.tgt.id2word[idx.item()])
                    hyp_corpus.append(word_seq)
                    batch_hyp_ordered[orig_indices[index]] = word_seq
                hyp_corpus_ordered.extend(batch_hyp_ordered)
                cum_loss += scores
                count += 1
        with open('decode.txt', 'a') as f:
            for r, h in zip(ref_corpus, hyp_corpus_ordered):
                f.write(" ".join(h) + '\n')
        bleu = compute_corpus_level_bleu_score(ref_corpus, hyp_corpus)
        print('bleu score: ', bleu)

        # Note: this returns the average per-batch NLL rather than a true
        # perplexity; perplexity is typically exp(total NLL / number of target tokens).
        return cum_loss / count

    # @staticmethod
    def load(self, model_path):

        self.encoder.load_state_dict(torch.load(model_path + '-encoder'))
        self.decoder.load_state_dict(torch.load(model_path + '-decoder'))
        # self.encoder.eval()
        # self.decoder.eval()

    def save(self, model_save_path):
        """
        Save current model to file
        """
        torch.save(self.encoder.state_dict(), model_save_path + '-encoder')
        torch.save(self.decoder.state_dict(), model_save_path + '-decoder')
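
A minimal usage sketch for the NMT class above. The vocab, train_data, and batch_iter names are assumptions about the surrounding project (a vocabulary with src/tgt fields and a batching helper like the one used in evaluate_ppl); the hyperparameter values are illustrative.

# Hypothetical driver code; not part of the original example.
model = NMT(embed_size=256, hidden_size=256, vocab=vocab, dropout_rate=0.2)

for epoch in range(10):
    for src_sents, tgt_sents, _ in batch_iter(train_data, batch_size=32):
        # __call__ encodes the batch and decodes it; decode() also performs
        # the backward pass and optimizer step internally.
        batch_loss = model(src_sents, tgt_sents)
    model.scheduler.step()  # halve the learning rate after each epoch
    model.save('model')
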
Example #2
def main(args):
    # Config logging
    logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
    logger = logging.getLogger()

    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load vocabulary wrapper.
    vocab = load_vocab(args.vocab_path)

    # Build data loader
    logger.info("Building data loader...")
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)
    logger.info("Done")

    # Build the models
    logger.info("Building image captioning models...")
    vqg = VQGModel(len(vocab),
                   args.max_length,
                   args.hidden_size,
                   vocab(vocab.sos),
                   vocab(vocab.eos),
                   rnn_cell=args.rnn_cell)

    logger.info("Done")

    if torch.cuda.is_available():
        vqg.cuda()

    # Loss and Optimizer
    weight = torch.ones(len(vocab))
    pad = vocab(vocab.pad)  # Set loss weight for 'pad' symbol to 0
    loss = NLLLoss(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    # Parameters to train
    params = vqg.params_to_train()
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Train the Models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, (images, questions, answers) in enumerate(data_loader):

            # Set mini-batch dataset
            images = Variable(images)
            questions = Variable(questions)
            answers = Variable(answers)
            if torch.cuda.is_available():
                images = images.cuda()
                questions = questions.cuda()
                answers = answers.cuda()

            # Forward, Backward and Optimize
            vqg.zero_grad()
            outputs, hiddens, other = vqg(images,
                                          questions,
                                          teacher_forcing_ratio=1.0)

            # Get loss
            loss.reset()
            for step, step_output in enumerate(outputs):
                batch_size = questions.size(0)
                loss.eval_batch(step_output.contiguous().view(batch_size, -1),
                                questions[:, step + 1])
            loss.backward()
            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                logger.info(
                    'Epoch [%d/%d], Step [%d/%d], Loss: %.4f' %
                    (epoch, args.num_epochs, i, total_step, loss.get_loss()))

            # Save the models
            if (i + 1) % args.save_step == 0:
                torch.save(
                    vqg.state_dict(),
                    os.path.join(args.model_path,
                                 'vqg-%d-%d.pkl' % (epoch + 1, i + 1)))
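
A minimal sketch of an argparse entry point that could drive main() above. The argument names mirror those referenced inside the function; the default values are illustrative assumptions rather than the project's actual settings.

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # Paths (illustrative defaults)
    parser.add_argument('--model_path', type=str, default='./models/')
    parser.add_argument('--vocab_path', type=str, default='./data/vocab.pkl')
    parser.add_argument('--image_dir', type=str, default='./data/images/')
    parser.add_argument('--caption_path', type=str, default='./data/annotations.json')
    # Model and training hyperparameters referenced in main()
    parser.add_argument('--crop_size', type=int, default=224)
    parser.add_argument('--max_length', type=int, default=20)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--rnn_cell', type=str, default='lstm')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--save_step', type=int, default=1000)
    main(parser.parse_args())
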
Example #3
                                                        fields=[('src', src),
                                                                ('tgt', tgt)],
                                                        filter_pred=len_filter)

src.build_vocab(train, max_size=50000)
tgt.build_vocab(train, max_size=50000)
input_vocab = src.vocab
output_vocab = tgt.vocab

# Prepare loss
weight = torch.ones(len(tgt.vocab))
pad = tgt.vocab.stoi[tgt.pad_token]

loss = NLLLoss(weight, pad)
if torch.cuda.is_available():
    loss.cuda()
else:
    print("*********** no cuda **************")

seq2seq_m = None

# Initialize model
hidden_size = 512
bidirectional = True
num_epochs = 500

# Initialize models
encoder = EncoderRNN(len(input_vocab),
                     max_len,
                     hidden_size,
                     bidirectional=True,