def run_training(opt, default_data_dir, num_epochs=100):
    """Train (or resume) a seq2seq model, evaluate it, then serve an
    interactive prediction loop on stdin until Ctrl-C.

    Args:
        opt: parsed CLI options; reads ``load_checkpoint``, ``expt_dir``,
            ``train_path``, ``no_dev`` and ``resume``.
        default_data_dir: base directory that contains ``opt.train_path``.
        num_epochs: number of training epochs; also drives the
            checkpoint/print intervals.
    """
    if opt.load_checkpoint is not None:
        # Restore model and vocabularies from a saved checkpoint.
        checkpoint_path = os.path.join(opt.expt_dir,
                                       Checkpoint.CHECKPOINT_DIR_NAME,
                                       opt.load_checkpoint)
        logging.info("loading checkpoint from %s", checkpoint_path)
        checkpoint = Checkpoint.load(checkpoint_path)
        seq2seq = checkpoint.model
        input_vocab = checkpoint.input_vocab
        output_vocab = checkpoint.output_vocab
    else:
        # Prepare dataset
        src = SourceField()
        tgt = TargetField()
        max_len = 50
        data_file = os.path.join(default_data_dir, opt.train_path, 'data.txt')
        logging.info("Starting new Training session on %s", data_file)

        def len_filter(example):
            # Keep only non-empty examples short enough for max_len.
            return (0 < len(example.src) <= max_len and
                    0 < len(example.tgt) <= max_len)

        train = torchtext.data.TabularDataset(
            path=data_file, format='json',
            fields={'src': ('src', src), 'tgt': ('tgt', tgt)},
            filter_pred=len_filter
        )
        dev = None
        if opt.no_dev is False:
            dev_data_file = os.path.join(default_data_dir, opt.train_path,
                                         'dev-data.txt')
            dev = torchtext.data.TabularDataset(
                path=dev_data_file, format='json',
                fields={'src': ('src', src), 'tgt': ('tgt', tgt)},
                filter_pred=len_filter
            )

        src.build_vocab(train, max_size=50000)
        tgt.build_vocab(train, max_size=50000)
        input_vocab = src.vocab
        output_vocab = tgt.vocab
        # NOTE: If the source field name and the target field name
        # are different from 'src' and 'tgt' respectively, they have
        # to be set explicitly before any training or inference
        # seq2seq.src_field_name = 'src'
        # seq2seq.tgt_field_name = 'tgt'

    # Prepare loss.
    # NOTE(review): ``tgt`` (and ``train``/``dev`` below) are only bound in
    # the fresh-training branch; the checkpoint branch raises NameError
    # here -- confirm whether resuming is meant to rely solely on
    # SupervisedTrainer's own resume logic.
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        logging.info("Yayyy We got CUDA!!!")
        loss.cuda()
    else:
        logging.info("No cuda available device found running on cpu")

    seq2seq = None
    optimizer = None
    if not opt.resume:
        hidden_size = 128
        # Decoder hidden is doubled because the encoder is bidirectional.
        decoder_hidden_size = hidden_size * 2
        logging.info("EncoderRNN Hidden Size: %s", hidden_size)
        logging.info("DecoderRNN Hidden Size: %s", decoder_hidden_size)
        bidirectional = True
        encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                             bidirectional=bidirectional, rnn_cell='lstm',
                             variable_lengths=True)
        decoder = DecoderRNN(len(tgt.vocab), max_len, decoder_hidden_size,
                             dropout_p=0, use_attention=True,
                             bidirectional=bidirectional, rnn_cell='lstm',
                             eos_id=tgt.eos_id, sos_id=tgt.sos_id)
        seq2seq = Seq2seq(encoder, decoder)
        if torch.cuda.is_available():
            seq2seq.cuda()
        for param in seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)

        # Optimizer and learning rate scheduler can be customized by
        # explicitly constructing the objects and pass to the trainer.
        optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()),
                              max_grad_norm=5)
        scheduler = StepLR(optimizer.optimizer, 1)
        optimizer.set_scheduler(scheduler)

    # train
    batch_size = 32
    # BUG FIX: true division produced float intervals; use integer floor
    # division and clamp to at least 1 so small epoch counts still work.
    checkpoint_every = max(1, num_epochs // 10)
    print_every = max(1, num_epochs // 100)

    properties = dict(batch_size=batch_size,
                      checkpoint_every=checkpoint_every,
                      print_every=print_every, expt_dir=opt.expt_dir,
                      num_epochs=num_epochs, teacher_forcing_ratio=0.5,
                      resume=opt.resume)
    logging.info("Starting training with the following Properties %s",
                 json.dumps(properties, indent=2))
    # BUG FIX: trainer was constructed with batch_size=num_epochs.
    t = SupervisedTrainer(loss=loss, batch_size=batch_size,
                          checkpoint_every=checkpoint_every,
                          print_every=print_every, expt_dir=opt.expt_dir)
    seq2seq = t.train(seq2seq, train,
                      num_epochs=num_epochs, dev_data=dev,
                      optimizer=optimizer, teacher_forcing_ratio=0.5,
                      resume=opt.resume)

    evaluator = Evaluator(loss=loss, batch_size=batch_size)
    if opt.no_dev is False:
        dev_loss, accuracy = evaluator.evaluate(seq2seq, dev)
        logging.info("Dev Loss: %s", dev_loss)
        # BUG FIX: accuracy line previously logged dev_loss again.
        logging.info("Accuracy: %s", accuracy)

    # Wrap the trained decoder in a width-4 beam search for prediction.
    beam_search = Seq2seq(seq2seq.encoder, TopKDecoder(seq2seq.decoder, 4))
    predictor = Predictor(beam_search, input_vocab, output_vocab)

    # Interactive prediction loop; Ctrl-C exits.
    while True:
        try:
            # BUG FIX: raw_input is Python 2 only (file uses Python 3 syntax).
            seq_str = input("Type in a source sequence:")
            seq = seq_str.strip().split()
            results = predictor.predict_n(seq, n=3)
            for i, res in enumerate(results):
                # BUG FIX: print() does not interpolate %-placeholders from
                # extra positional args; format explicitly.
                print('option %s: %s\n' % (i + 1, res))
        except KeyboardInterrupt:
            logging.info("Bye Bye")
            exit(0)
opt.load_checkpoint = os.path.join(opt.model_dir, last_checkpoint) opt.skip_steps = int(last_checkpoint.strip('.pt').split('/')[-1]) if opt.load_checkpoint: seq2seq.load_state_dict(torch.load(opt.load_checkpoint)) opt.skip_steps = int(opt.load_checkpoint.strip('.pt').split('/')[-1]) if not multi_gpu or hvd.rank() == 0: logger.info(f"\nLoad from {opt.load_checkpoint}\n") else: for param in seq2seq.parameters(): param.data.uniform_(-opt.init_weight, opt.init_weight) if opt.beam_width > 1 and opt.phase == "infer": if not multi_gpu or hvd.rank() == 0: logger.info(f"Beam Width {opt.beam_width}") seq2seq.decoder = TopKDecoder(seq2seq.decoder, opt.beam_width) if opt.phase == "train": # Prepare Train Data trans_data = TranslateData(pad_id) train_set = DialogDataset(opt.train_path, trans_data.translate_data, src_vocab, tgt_vocab, max_src_length=opt.max_src_length, max_tgt_length=opt.max_tgt_length) train_sampler = dist.DistributedSampler(train_set, num_replicas=hvd.size(), rank=hvd.rank()) \ if multi_gpu else None train = DataLoader(train_set, batch_size=opt.batch_size, shuffle=False if multi_gpu else True,
# Train the model briefly, check dev perplexity, then verify that a
# beam-search predictor reverses the source token sequence.
trainer = SupervisedTrainer(
    loss=loss,
    batch_size=32,
    checkpoint_every=50,
    print_every=10,
    expt_dir=opt.expt_dir,
)
seq2seq = trainer.train(
    seq2seq,
    train,
    num_epochs=6,
    dev_data=dev,
    optimizer=optimizer,
    teacher_forcing_ratio=0.5,
    resume=opt.resume,
)

# The trained model must reach a low loss on the dev set.
dev_loss, accuracy = Evaluator(loss=loss, batch_size=32).evaluate(seq2seq, dev)
assert dev_loss < 1.5

# Width-3 beam search around the trained encoder/decoder.
beam_search = Seq2seq(seq2seq.encoder, TopKDecoder(seq2seq.decoder, 3))
predictor = Predictor(beam_search, input_vocab, output_vocab)
inp_seq = "1 3 5 7 9"
seq = predictor.predict(inp_seq.split())
# The task is sequence reversal; seq[:-1] drops the trailing EOS symbol.
assert " ".join(seq[:-1]) == inp_seq[::-1]
def test_k_greater_than_1(self):
    """ Implement beam search manually and compare results from topk decoder.

    Runs 10 trials with randomly initialized decoders: each trial executes
    TopKDecoder, re-derives the beams with an explicit queue-based beam
    search over decoder.forward_step, and asserts that scores, lengths and
    predicted symbols agree.
    """
    max_len = 50
    beam_size = 3
    batch_size = 1
    hidden_size = 8
    sos = 0
    eos = 1
    for _ in range(10):
        # Fresh random decoder each trial so the comparison is not tied to
        # one particular weight configuration.
        decoder = DecoderRNN(self.vocab_size, max_len, hidden_size, sos, eos)
        for param in decoder.parameters():
            param.data.uniform_(-1, 1)
        topk_decoder = TopKDecoder(decoder, beam_size)
        encoder_hidden = torch.autograd.Variable(
            torch.randn(1, batch_size, hidden_size))
        # Output under test: scores/lengths/sequences live in other_topk.
        _, hidden_topk, other_topk = topk_decoder(
            None, encoder_hidden=encoder_hidden)
        # Queue state:
        # 1. time step
        # 2. symbol
        # 3. hidden state
        # 4. accumulated log likelihood
        # 5. beam number
        # Seed each batch element's queue with a pseudo step -1 holding SOS
        # and the encoder hidden state; beam number None marks the root.
        batch_queue = [[(-1, sos, encoder_hidden[:, b, :].unsqueeze(1),
                         0, None)]
                       for b in range(batch_size)]
        time_batch_queue = [batch_queue]
        batch_finished_seqs = [list() for _ in range(batch_size)]
        for t in range(max_len):
            new_batch_queue = []
            for b in range(batch_size):
                new_queue = []
                for k in range(min(len(time_batch_queue[t][b]), beam_size)):
                    _, inputs, hidden, seq_score, _ = time_batch_queue[t][
                        b][k]
                    if inputs == eos:
                        # Beam already emitted EOS: retire it rather than
                        # extending it further.
                        batch_finished_seqs[b].append(
                            time_batch_queue[t][b][k])
                        continue
                    inputs = torch.autograd.Variable(
                        torch.LongTensor([[inputs]]))
                    # One decoding step for this beam.
                    context, hidden, attn = decoder.forward_step(
                        inputs, hidden, None)
                    decoder_outputs, symbols = decoder.decoder(
                        context, attn, None, None)
                    # Work in log space so beam scores add per step.
                    decoder_outputs = decoder_outputs.log()
                    topk_score, topk = decoder_outputs[0].data.topk(
                        beam_size)
                    for score, sym in zip(topk_score.tolist()[0],
                                          topk.tolist()[0]):
                        new_queue.append(
                            (t, sym, hidden, score + seq_score, k))
                # Keep only the beam_size best candidates across all beams,
                # ranked by accumulated log likelihood (tuple index 3).
                new_queue = sorted(new_queue, key=lambda x: x[3],
                                   reverse=True)[:beam_size]
                new_batch_queue.append(new_queue)
            time_batch_queue.append(new_batch_queue)
        # finished beams
        finalist = [l[:beam_size] for l in batch_finished_seqs]
        # unfinished beams
        for b in range(batch_size):
            if len(finalist[b]) < beam_size:
                # Pad with the best still-running beams from the last step.
                last_step = sorted(time_batch_queue[-1][b],
                                   key=lambda x: x[3], reverse=True)
                finalist[b] += last_step[:beam_size - len(finalist[b])]
        # back track
        # Reconstruct each finalist's full path by following the stored
        # parent-beam indices (tuple index 4) back to the root (None).
        topk = []
        for b in range(batch_size):
            batch_topk = []
            for k in range(beam_size):
                seq = [finalist[b][k]]
                prev_k = seq[-1][4]
                prev_t = seq[-1][0]
                while prev_k is not None:
                    seq.append(time_batch_queue[prev_t][b][prev_k])
                    prev_k = seq[-1][4]
                    prev_t = seq[-1][0]
                batch_topk.append([s for s in reversed(seq)])
            topk.append(batch_topk)
        for b in range(batch_size):
            # Order hypotheses by final accumulated score, best first.
            topk[b] = sorted(topk[b], key=lambda s: s[-1][3], reverse=True)
        topk_scores = other_topk['score']
        topk_lengths = other_topk['topk_length']
        topk_pred_symbols = other_topk['topk_sequence']
        for b in range(batch_size):
            # Skip comparison when two adjacent scores are (nearly) equal:
            # tie-breaking order between implementations is undefined.
            precision_error = False
            for k in range(beam_size - 1):
                if np.isclose(topk_scores[b][k], topk_scores[b][k + 1]):
                    precision_error = True
                    break
            if precision_error:
                break
            for k in range(beam_size):
                # Length excludes the seeded SOS entry, hence the -1.
                self.assertEqual(topk_lengths[b][k], len(topk[b][k]) - 1)
                self.assertTrue(
                    np.isclose(topk_scores[b][k], topk[b][k][-1][3]))
                total_steps = topk_lengths[b][k]
                for t in range(total_steps):
                    self.assertEqual(topk_pred_symbols[t][b, k].data[0],
                                     topk[b][k][t + 1][1])  # topk includes SOS
def test_init(self):
    """Smoke test: TopKDecoder wraps a plain DecoderRNN without error."""
    base_decoder = DecoderRNN(self.vocab_size, 50, 16, 0, 1,
                              input_dropout_p=0)
    TopKDecoder(base_decoder, 3)
# Model/inference options (``parser`` is created earlier in the file).
parser.add_argument('--num_layer', type=int, default=1)
parser.add_argument('--num_class', type=int, default=3)


def _str2bool(value):
    """Parse a boolean CLI flag from its textual form.

    BUG FIX: argparse's ``type=bool`` returns True for ANY non-empty string
    (including "False"), so ``--use_cuda False`` could never disable CUDA.
    """
    return str(value).lower() in ('true', '1', 'yes', 'y')


parser.add_argument('--use_cuda', type=_str2bool, default=True)
parser.add_argument('--use_type', type=str, default='elmo')
parser.add_argument('--class_batch_size', type=int, default=1)
parser.add_argument('--seed', type=int, default=42)
opt = parser.parse_args()

# Seed every RNG in play so inference is reproducible.
random.seed(opt.seed)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)

# Restore the multi-task model from the experiment checkpoint and wrap its
# decoder in a width-3 beam search for prediction.
checkpoint = Checkpoint().load(opt.expt_dir)
model = checkpoint.model
beam_search = Multi_Task(model.embedding_layer, model.encoder,
                         TopKDecoder(model.decoder, 3),
                         model.classification, model.class_encoder,
                         model.norm_encoder, opt=opt)
if torch.cuda.is_available():
    beam_search = beam_search.cuda()

input_vocab = load_from_pickle(opt.src_vocab_path)
output_vocab = load_from_pickle(opt.tgt_vocab_path)
predictor = Predictor(beam_search, input_vocab, output_vocab)
# Example inputs kept for manual experimentation:
# inp_seq = ["This was largely accounted for by seed under 9 years old , about 90% of which is viable .",
#            "MENTION MENTION weddings in the summer in Aruba ofc u guys r my bridesmaids"]
# inp_seq = "MENTION MENTION weddings in the summer in Aruba ofc u guys r my bridesmaids"