class NMT(object):
    def __init__(self, embed_size, hidden_size, vocab, dropout_rate=0.2, keep_train=False):
        super(NMT, self).__init__()
        self.nvocab_src = len(vocab.src)
        self.nvocab_tgt = len(vocab.tgt)
        self.vocab = vocab
        self.encoder = Encoder(self.nvocab_src, hidden_size, embed_size,
                               input_dropout=dropout_rate, n_layers=2)
        self.decoder = Decoder(self.nvocab_tgt, 2 * hidden_size, embed_size,
                               output_dropout=dropout_rate, n_layers=2, tf_rate=1.0)
        if keep_train:
            self.load('model')

        LAS_params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        self.optimizer = optim.Adam(LAS_params, lr=0.001)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                         step_size=1, gamma=0.5)

        weight = torch.ones(self.nvocab_tgt)
        self.loss = NLLLoss(weight=weight, mask=0, size_average=False)
        # TODO: Perplexity or NLLLoss
        # TODO: pass in mask to loss function
        # self.loss = Perplexity(weight, 0)

        if torch.cuda.is_available():
            # Move the encoder, decoder and loss to the GPU
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.loss.cuda()

    def __call__(self, src_sents, tgt_sents):
        """
        Take a mini-batch of source and target sentences, compute the log-likelihood
        of the target sentences.

        Args:
            src_sents: list of source sentence tokens
            tgt_sents: list of target sentence tokens, wrapped by `<s>` and `</s>`

        Returns:
            scores: a variable/tensor of shape (batch_size, ) representing the
                log-likelihood of generating the gold-standard target sentence for
                each example in the input batch
        """
        src_sents = self.vocab.src.words2indices(src_sents)
        tgt_sents = self.vocab.tgt.words2indices(tgt_sents)
        src_sents, src_len, y_input, y_tgt, tgt_len = sent_padding(src_sents, tgt_sents)
        src_encodings, decoder_init_state = self.encode(src_sents, src_len)
        scores, symbols = self.decode(src_encodings, decoder_init_state,
                                      [y_input, y_tgt], stage="train")
        return scores

    def encode(self, src_sents, input_lengths):
        """
        Use a GRU/LSTM to encode source sentences into hidden states.

        Args:
            src_sents: list of source sentence tokens

        Returns:
            src_encodings: hidden states of tokens in source sentences; this could be
                a variable with shape (batch_size, source_sentence_length, encoding_dim),
                or in other formats
            decoder_init_state: decoder GRU/LSTM's initial state, computed from the
                source encodings
        """
        encoder_outputs, encoder_hidden = self.encoder(src_sents, input_lengths)
        return encoder_outputs, encoder_hidden

    def decode(self, src_encodings, decoder_init_state, tgt_sents, stage="train"):
        """
        Given source encodings, compute the log-likelihood of predicting the
        gold-standard target sentence tokens. This method also performs the backward
        pass and an optimizer step, so it is only used during training.

        Args:
            src_encodings: hidden states of tokens in source sentences
            decoder_init_state: decoder GRU/LSTM's initial state
            tgt_sents: list of gold-standard target sentences, wrapped by `<s>` and `</s>`

        Returns:
            scores: could be a variable of shape (batch_size, ) representing the
                log-likelihood of generating the gold-standard target sentence for
                each example in the input batch
        """
        tgt_input, tgt_target = tgt_sents
        loss = self.loss

        decoder_outputs, decoder_hidden, symbols = self.decoder(
            tgt_input, decoder_init_state, src_encodings)

        loss.reset()
        for step, step_output in enumerate(decoder_outputs):
            batch_size = tgt_input.size(0)
            loss.eval_batch(step_output.contiguous().view(batch_size, -1),
                            tgt_target[:, step])

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), 5.0)
        torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), 5.0)
        self.optimizer.step()

        scores = loss.get_loss()
        return scores, symbols
    def decode_without_bp(self, src_encodings, decoder_init_state, tgt_sents):
        """
        Given source encodings, compute the log-likelihood of predicting the
        gold-standard target sentence tokens, without backpropagation (used for
        validation).

        Args:
            src_encodings: hidden states of tokens in source sentences
            decoder_init_state: decoder GRU/LSTM's initial state
            tgt_sents: list of gold-standard target sentences, wrapped by `<s>` and `</s>`

        Returns:
            scores: could be a variable of shape (batch_size, ) representing the
                log-likelihood of generating the gold-standard target sentence for
                each example in the input batch
        """
        tgt_input, tgt_target = tgt_sents
        loss = self.loss

        decoder_outputs, decoder_hidden, symbols = self.decoder(
            tgt_input, decoder_init_state, src_encodings, stage="valid")

        loss.reset()
        for step, step_output in enumerate(decoder_outputs):
            batch_size = tgt_input.size(0)
            loss.eval_batch(step_output.contiguous().view(batch_size, -1),
                            tgt_target[:, step])

        scores = loss.get_loss()
        return scores, symbols

    # TODO: sent_padding for only src
    # def beam_search(self, src_sent: List[str], beam_size: int=5,
    #                 max_decoding_time_step: int=70) -> List[Hypothesis]:
    def beam_search(self, src_sent, beam_size, max_decoding_time_step):
        """
        Given a single source sentence, perform beam search.

        Args:
            src_sent: a single tokenized source sentence
            beam_size: beam size
            max_decoding_time_step: maximum number of time steps to unroll the
                decoding RNN

        Returns:
            hypotheses: a list of hypotheses, each hypothesis has two fields:
                value: List[str]: the decoded target sentence, represented as a list of words
                score: float: the log-likelihood of the target sentence
        """
        # TODO: beam search is not implemented yet; this is a placeholder return value
        hypotheses = 0
        return hypotheses

    # def evaluate_ppl(self, dev_data: List[Any], batch_size: int=32):
    def evaluate_ppl(self, dev_data, batch_size):
        """
        Evaluate perplexity on dev sentences.

        Args:
            dev_data: a list of dev sentences
            batch_size: batch size

        Returns:
            ppl: the perplexity on dev sentences
        """
        ref_corpus = []
        hyp_corpus = []
        cum_loss = 0
        count = 0
        hyp_corpus_ordered = []

        with torch.no_grad():
            for src_sents, tgt_sents, orig_indices in batch_iter(dev_data, batch_size):
                ref_corpus.extend(tgt_sents)
                actual_size = len(src_sents)
                src_sents = self.vocab.src.words2indices(src_sents)
                tgt_sents = self.vocab.tgt.words2indices(tgt_sents)
                src_sents, src_len, y_input, y_tgt, tgt_len = sent_padding(
                    src_sents, tgt_sents)
                src_encodings, decoder_init_state = self.encode(src_sents, src_len)
                scores, symbols = self.decode_without_bp(
                    src_encodings, decoder_init_state, [y_input, y_tgt])
                # sents = np.zeros((len(symbols), actual_size))
                # for i, symbol in enumerate(symbols):
                #     sents[i, :] = symbol.data.cpu().numpy()
                #     print(sents.T)

                # Convert the decoded symbol ids back to words and restore the
                # original order of the batch (given by orig_indices).
                index = 0
                batch_hyp_ordered = [None] * symbols.size(0)
                for sent in symbols:
                    word_seq = []
                    for idx in sent:
                        if idx == 2:  # stop at the end-of-sentence token
                            break
                        word_seq.append(self.vocab.tgt.id2word[idx.item()])
                    hyp_corpus.append(word_seq)
                    batch_hyp_ordered[orig_indices[index]] = word_seq
                    index += 1
                hyp_corpus_ordered.extend(batch_hyp_ordered)

                cum_loss += scores
                count += 1

        with open('decode.txt', 'a') as f:
            for r, h in zip(ref_corpus, hyp_corpus_ordered):
                f.write(" ".join(h) + '\n')

        bleu = compute_corpus_level_bleu_score(ref_corpus, hyp_corpus)
        print('bleu score: ', bleu)
        # NOTE: this is the average batch loss, not an exponentiated perplexity.
        return cum_loss / count

    # @staticmethod
    def load(self, model_path):
        self.encoder.load_state_dict(torch.load(model_path + '-encoder'))
        self.decoder.load_state_dict(torch.load(model_path + '-decoder'))
        # self.encoder.eval()
        # self.decoder.eval()

    def save(self, model_save_path):
        """ Save current model to file """
        torch.save(self.encoder.state_dict(), model_save_path + '-encoder')
        torch.save(self.decoder.state_dict(), model_save_path + '-decoder')
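
# The helper `sent_padding` used by NMT.__call__ and NMT.evaluate_ppl is defined
# elsewhere in this repo. The function below is only a minimal sketch of what such a
# helper might look like, inferred from its call sites: it pads a batch of source and
# target index lists into rectangular LongTensors and splits each target sentence
# (wrapped by `<s>` and `</s>`) into decoder inputs (`<s> w1 ... wn`) and decoder
# targets (`w1 ... wn </s>`). The pad id 0 matches the `mask=0` passed to NLLLoss in
# NMT.__init__; everything else (name, cuda transfer, no length sorting) is an
# assumption, not the repo's actual implementation.
def sent_padding_sketch(src_sents, tgt_sents, pad_id=0):
    import torch  # assumed to be imported at module level in the real file

    src_len = [len(s) for s in src_sents]
    tgt_len = [len(t) - 1 for t in tgt_sents]  # input/target are each one token shorter

    max_src = max(src_len)
    max_tgt = max(tgt_len)

    src_padded = torch.full((len(src_sents), max_src), pad_id, dtype=torch.long)
    y_input = torch.full((len(tgt_sents), max_tgt), pad_id, dtype=torch.long)
    y_tgt = torch.full((len(tgt_sents), max_tgt), pad_id, dtype=torch.long)

    for i, s in enumerate(src_sents):
        src_padded[i, :len(s)] = torch.tensor(s, dtype=torch.long)
    for i, t in enumerate(tgt_sents):
        t = torch.tensor(t, dtype=torch.long)
        y_input[i, :len(t) - 1] = t[:-1]  # <s> w1 ... wn
        y_tgt[i, :len(t) - 1] = t[1:]     # w1 ... wn </s>

    if torch.cuda.is_available():
        src_padded, y_input, y_tgt = src_padded.cuda(), y_input.cuda(), y_tgt.cuda()

    return src_padded, src_len, y_input, y_tgt, tgt_len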
def main(args):
    # Config logging
    logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
    logger = logging.getLogger()

    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load vocabulary wrapper
    vocab = load_vocab(args.vocab_path)

    # Build data loader
    logger.info("Building data loader...")
    data_loader = get_loader(args.image_dir, args.caption_path, vocab, transform,
                             args.batch_size, shuffle=True,
                             num_workers=args.num_workers)
    logger.info("Done")

    # Build the models
    logger.info("Building image captioning models...")
    vqg = VQGModel(len(vocab), args.max_length, args.hidden_size,
                   vocab(vocab.sos), vocab(vocab.eos), rnn_cell=args.rnn_cell)
    logger.info("Done")

    if torch.cuda.is_available():
        vqg.cuda()

    # Loss and optimizer
    weight = torch.ones(len(vocab))
    pad = vocab(vocab.pad)
    # Set loss weight for 'pad' symbol to 0
    loss = NLLLoss(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    # Parameters to train
    params = vqg.params_to_train()
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Train the models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, (images, questions, answers) in enumerate(data_loader):
            # Set mini-batch dataset
            images = Variable(images)
            questions = Variable(questions)
            answers = Variable(answers)
            if torch.cuda.is_available():
                images = images.cuda()
                questions = questions.cuda()
                answers = answers.cuda()

            # Forward, backward and optimize
            vqg.zero_grad()
            outputs, hiddens, other = vqg(images, questions, teacher_forcing_ratio=1.0)

            # Get loss
            loss.reset()
            for step, step_output in enumerate(outputs):
                batch_size = questions.size(0)
                loss.eval_batch(step_output.contiguous().view(batch_size, -1),
                                questions[:, step + 1])
            loss.backward()
            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                logger.info('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                            % (epoch, args.num_epochs, i, total_step, loss.get_loss()))

            # Save the models
            if (i + 1) % args.save_step == 0:
                torch.save(vqg.state_dict(),
                           os.path.join(args.model_path,
                                        'vqg-%d-%d.pkl' % (epoch + 1, i + 1)))
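
# `main(args)` above reads a number of attributes off an `args` namespace, but the
# argument parser itself is not part of this section. The block below is a minimal
# sketch of a CLI entry point that supplies exactly the flags main() uses; the
# `required` choices and default values are placeholders/assumptions, not the
# project's actual configuration.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # Paths (no sensible defaults, so required in this sketch)
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--vocab_path', type=str, required=True)
    parser.add_argument('--image_dir', type=str, required=True)
    parser.add_argument('--caption_path', type=str, required=True)
    # Model / data hyperparameters (placeholder defaults)
    parser.add_argument('--crop_size', type=int, default=224)
    parser.add_argument('--max_length', type=int, default=20)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--rnn_cell', type=str, default='lstm')
    # Training hyperparameters (placeholder defaults)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--num_epochs', type=int, default=10)
    # Logging / checkpointing (placeholder defaults)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--save_step', type=int, default=1000)

    main(parser.parse_args())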
    fields=[('src', src), ('tgt', tgt)],
    filter_pred=len_filter)
src.build_vocab(train, max_size=50000)
tgt.build_vocab(train, max_size=50000)
input_vocab = src.vocab
output_vocab = tgt.vocab

# Prepare loss
weight = torch.ones(len(tgt.vocab))
pad = tgt.vocab.stoi[tgt.pad_token]
loss = NLLLoss(weight, pad)
if torch.cuda.is_available():
    loss.cuda()
else:
    print("*********** no cuda **************")

seq2seq_m = None

# Initialize model
hidden_size = 512
bidirectional = True
num_epochs = 500

# Initialize models
encoder = EncoderRNN(len(input_vocab), max_len, hidden_size,
                     bidirectional=True,