Ejemplo n.º 1
0
def run_training(opt, default_data_dir, num_epochs=100):
    """Load or train a seq2seq model, then serve an interactive beam-search prompt.

    When ``opt.load_checkpoint`` is set, the model and both vocabularies are
    read from that checkpoint and no training happens. Otherwise the training
    (and optionally dev) data are read from JSON-lines files, a bidirectional
    LSTM encoder / attention decoder is built (unless ``opt.resume`` is set),
    and the model is trained with ``SupervisedTrainer``.

    In both cases the function ends in an endless prompt loop that prints the
    three best beam-search predictions for each typed sequence; Ctrl-C exits
    the process.

    Args:
        opt: Parsed options. Reads ``load_checkpoint``, ``expt_dir``,
            ``train_path``, ``no_dev`` and ``resume``.
        default_data_dir: Directory under which ``opt.train_path`` is resolved.
        num_epochs: Number of training epochs. The checkpoint and print
            intervals are derived from it.
    """
    if opt.load_checkpoint is not None:
        checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)
        logging.info("loading checkpoint from {}".format(checkpoint_path))
        checkpoint = Checkpoint.load(checkpoint_path)
        seq2seq = checkpoint.model
        input_vocab = checkpoint.input_vocab
        output_vocab = checkpoint.output_vocab
    else:

        # Prepare dataset
        src = SourceField()
        tgt = TargetField()
        max_len = 50

        data_file = os.path.join(default_data_dir, opt.train_path, 'data.txt')

        logging.info("Starting new Training session on %s", data_file)

        def len_filter(example):
            # Keep only examples whose source and target are both non-empty
            # and at most max_len tokens long.
            return (len(example.src) <= max_len) and (len(example.tgt) <= max_len) \
                   and (len(example.src) > 0) and (len(example.tgt) > 0)

        train = torchtext.data.TabularDataset(
            path=data_file, format='json',
            fields={'src': ('src', src), 'tgt': ('tgt', tgt)},
            filter_pred=len_filter
        )

        dev = None
        if opt.no_dev is False:
            dev_data_file = os.path.join(default_data_dir, opt.train_path, 'dev-data.txt')
            dev = torchtext.data.TabularDataset(
                path=dev_data_file, format='json',
                fields={'src': ('src', src), 'tgt': ('tgt', tgt)},
                filter_pred=len_filter
            )

        src.build_vocab(train, max_size=50000)
        tgt.build_vocab(train, max_size=50000)
        input_vocab = src.vocab
        output_vocab = tgt.vocab

        # NOTE: If the source field name and the target field name
        # are different from 'src' and 'tgt' respectively, they have
        # to be set explicitly before any training or inference
        # seq2seq.src_field_name = 'src'
        # seq2seq.tgt_field_name = 'tgt'

        # Prepare loss
        weight = torch.ones(len(tgt.vocab))
        pad = tgt.vocab.stoi[tgt.pad_token]
        loss = Perplexity(weight, pad)
        if torch.cuda.is_available():
            logging.info("Yayyy We got CUDA!!!")
            loss.cuda()
        else:
            logging.info("No cuda available device found running on cpu")

        # On resume both stay None.
        # NOTE(review): presumably SupervisedTrainer.train restores the model
        # and optimizer from the latest checkpoint when resume is set - confirm.
        seq2seq = None
        optimizer = None
        if not opt.resume:
            hidden_size = 128
            # The decoder hidden size is set to twice the encoder's, matching
            # the bidirectional encoder configured below.
            decoder_hidden_size = hidden_size * 2
            logging.info("EncoderRNN Hidden Size: %s", hidden_size)
            logging.info("DecoderRNN Hidden Size: %s", decoder_hidden_size)
            bidirectional = True
            encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                                 bidirectional=bidirectional,
                                 rnn_cell='lstm',
                                 variable_lengths=True)
            decoder = DecoderRNN(len(tgt.vocab), max_len, decoder_hidden_size,
                                 dropout_p=0, use_attention=True,
                                 bidirectional=bidirectional,
                                 rnn_cell='lstm',
                                 eos_id=tgt.eos_id, sos_id=tgt.sos_id)

            seq2seq = Seq2seq(encoder, decoder)
            if torch.cuda.is_available():
                seq2seq.cuda()

            for param in seq2seq.parameters():
                param.data.uniform_(-0.08, 0.08)

            # Optimizer and learning rate scheduler can be customized by
            # explicitly constructing the objects and pass to the trainer.
            # Built only for a fresh model: on resume there is no model here
            # whose parameters could be handed to Adam.
            optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
            scheduler = StepLR(optimizer.optimizer, 1)
            optimizer.set_scheduler(scheduler)

        # train

        batch_size = 32
        # Integer intervals, never below 1, so a small num_epochs cannot
        # produce a zero or fractional interval.
        checkpoint_every = max(1, num_epochs // 10)
        print_every = max(1, num_epochs // 100)

        properties = dict(batch_size=batch_size,
                          checkpoint_every=checkpoint_every,
                          print_every=print_every, expt_dir=opt.expt_dir,
                          num_epochs=num_epochs,
                          teacher_forcing_ratio=0.5,
                          resume=opt.resume)

        logging.info("Starting training with the following Properties %s", json.dumps(properties, indent=2))
        # The trainer gets the same batch size that is logged above and used
        # by the evaluator below.
        t = SupervisedTrainer(loss=loss, batch_size=batch_size,
                              checkpoint_every=checkpoint_every,
                              print_every=print_every, expt_dir=opt.expt_dir)

        seq2seq = t.train(seq2seq, train,
                          num_epochs=num_epochs, dev_data=dev,
                          optimizer=optimizer,
                          teacher_forcing_ratio=0.5,
                          resume=opt.resume)

        evaluator = Evaluator(loss=loss, batch_size=batch_size)

        if opt.no_dev is False:
            dev_loss, accuracy = evaluator.evaluate(seq2seq, dev)
            logging.info("Dev Loss: %s", dev_loss)
            logging.info("Accuracy: %s", accuracy)

    # Reuse the trained encoder and wrap the decoder for beam search (width 4).
    beam_search = Seq2seq(seq2seq.encoder, TopKDecoder(seq2seq.decoder, 4))

    predictor = Predictor(beam_search, input_vocab, output_vocab)

    # raw_input exists only on Python 2; fall back to input on Python 3.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    while True:
        try:
            seq_str = read_line("Type in a source sequence:")
            seq = seq_str.strip().split()
            results = predictor.predict_n(seq, n=3)
            for i, res in enumerate(results):
                print('option %s: %s\n' % (i + 1, res))
        except KeyboardInterrupt:
            logging.info("Bye Bye")
            exit(0)
Ejemplo n.º 2
0
            opt.load_checkpoint = os.path.join(opt.model_dir, last_checkpoint)
            opt.skip_steps = int(last_checkpoint.strip('.pt').split('/')[-1])

    if opt.load_checkpoint:
        seq2seq.load_state_dict(torch.load(opt.load_checkpoint))
        opt.skip_steps = int(opt.load_checkpoint.strip('.pt').split('/')[-1])
        if not multi_gpu or hvd.rank() == 0:
            logger.info(f"\nLoad from {opt.load_checkpoint}\n")
    else:
        for param in seq2seq.parameters():
            param.data.uniform_(-opt.init_weight, opt.init_weight)

    if opt.beam_width > 1 and opt.phase == "infer":
        if not multi_gpu or hvd.rank() == 0:
            logger.info(f"Beam Width {opt.beam_width}")
        seq2seq.decoder = TopKDecoder(seq2seq.decoder, opt.beam_width)

    if opt.phase == "train":
        # Prepare Train Data
        trans_data = TranslateData(pad_id)
        train_set = DialogDataset(opt.train_path,
                                  trans_data.translate_data,
                                  src_vocab,
                                  tgt_vocab,
                                  max_src_length=opt.max_src_length,
                                  max_tgt_length=opt.max_tgt_length)
        train_sampler = dist.DistributedSampler(train_set, num_replicas=hvd.size(), rank=hvd.rank()) \
                            if multi_gpu else None
        train = DataLoader(train_set,
                           batch_size=opt.batch_size,
                           shuffle=False if multi_gpu else True,
Ejemplo n.º 3
0
    # train
    t = SupervisedTrainer(
        loss=loss,
        batch_size=32,
        checkpoint_every=50,
        print_every=10,
        expt_dir=opt.expt_dir,
    )

    seq2seq = t.train(
        seq2seq,
        train,
        num_epochs=6,
        dev_data=dev,
        optimizer=optimizer,
        teacher_forcing_ratio=0.5,
        resume=opt.resume,
    )

# Score the trained model on the dev split; the run only passes when the
# dev loss is below 1.5.
evaluator = Evaluator(batch_size=32, loss=loss)
dev_loss, accuracy = evaluator.evaluate(seq2seq, dev)
assert dev_loss < 1.5

# Pair the trained encoder with a beam-search decoder of width 3.
beam_search = Seq2seq(
    seq2seq.encoder,
    TopKDecoder(seq2seq.decoder, 3),
)
predictor = Predictor(beam_search, input_vocab, output_vocab)

# The prediction, minus its last element, must spell the input backwards.
# inp_seq[::-1] reverses the string character by character; that equals the
# reversed token order here only because every token is a single character.
inp_seq = "1 3 5 7 9"
seq = predictor.predict(inp_seq.split())
assert inp_seq[::-1] == " ".join(seq[:-1])
Ejemplo n.º 4
0
    def test_k_greater_than_1(self):
        """Implement beam search manually and compare results from topk decoder.

        For ten randomly initialised decoders, runs ``TopKDecoder`` with a
        beam of 3 on a random encoder hidden state, then rebuilds the beam
        search step by step with plain Python lists and asserts that the
        sequence lengths, scores and predicted symbols agree.
        """
        max_len = 50
        beam_size = 3
        batch_size = 1
        hidden_size = 8
        sos = 0
        eos = 1

        for _ in range(10):
            # Fresh decoder with parameters drawn uniformly from [-1, 1].
            decoder = DecoderRNN(self.vocab_size, max_len, hidden_size, sos,
                                 eos)
            for param in decoder.parameters():
                param.data.uniform_(-1, 1)
            topk_decoder = TopKDecoder(decoder, beam_size)

            # Result under test: decode with no inputs, only a random
            # encoder hidden state.
            encoder_hidden = torch.autograd.Variable(
                torch.randn(1, batch_size, hidden_size))
            _, hidden_topk, other_topk = topk_decoder(
                None, encoder_hidden=encoder_hidden)

            # Queue state:
            #   1. time step
            #   2. symbol
            #   3. hidden state
            #   4. accumulated log likelihood
            #   5. beam number
            # The initial entry per batch item is SOS at time step -1 with
            # score 0 and no parent beam (None).
            batch_queue = [[(-1, sos, encoder_hidden[:, b, :].unsqueeze(1), 0,
                             None)] for b in range(batch_size)]
            # time_batch_queue[t][b] holds the candidates of batch item b
            # that are expanded at step t.
            time_batch_queue = [batch_queue]
            batch_finished_seqs = [list() for _ in range(batch_size)]
            for t in range(max_len):
                new_batch_queue = []
                for b in range(batch_size):
                    new_queue = []
                    # Expand at most beam_size candidates from the previous step.
                    for k in range(min(len(time_batch_queue[t][b]),
                                       beam_size)):
                        _, inputs, hidden, seq_score, _ = time_batch_queue[t][
                            b][k]
                        # A candidate ending in EOS is recorded as finished
                        # and not expanded further.
                        if inputs == eos:
                            batch_finished_seqs[b].append(
                                time_batch_queue[t][b][k])
                            continue
                        inputs = torch.autograd.Variable(
                            torch.LongTensor([[inputs]]))
                        context, hidden, attn = decoder.forward_step(
                            inputs, hidden, None)
                        decoder_outputs, symbols = decoder.decoder(
                            context, attn, None, None)
                        # NOTE(review): presumably decoder.decoder returns
                        # probabilities, hence the log here - confirm.
                        decoder_outputs = decoder_outputs.log()
                        topk_score, topk = decoder_outputs[0].data.topk(
                            beam_size)
                        # Each child records the step, symbol, new hidden
                        # state, accumulated score and the parent index k.
                        for score, sym in zip(topk_score.tolist()[0],
                                              topk.tolist()[0]):
                            new_queue.append(
                                (t, sym, hidden, score + seq_score, k))
                    # Keep only the beam_size best children by accumulated score.
                    new_queue = sorted(new_queue,
                                       key=lambda x: x[3],
                                       reverse=True)[:beam_size]
                    new_batch_queue.append(new_queue)
                time_batch_queue.append(new_batch_queue)

            # finished beams
            finalist = [l[:beam_size] for l in batch_finished_seqs]
            # unfinished beams
            # If fewer than beam_size sequences finished, fill up with the
            # best-scoring candidates of the last step.
            for b in range(batch_size):
                if len(finalist[b]) < beam_size:
                    last_step = sorted(time_batch_queue[-1][b],
                                       key=lambda x: x[3],
                                       reverse=True)
                    finalist[b] += last_step[:beam_size - len(finalist[b])]

            # back track
            # Follow the (time step, beam number) pointers from each finalist
            # back to the initial SOS entry, whose beam number is None.
            topk = []
            for b in range(batch_size):
                batch_topk = []
                for k in range(beam_size):
                    seq = [finalist[b][k]]
                    prev_k = seq[-1][4]
                    prev_t = seq[-1][0]
                    while prev_k is not None:
                        seq.append(time_batch_queue[prev_t][b][prev_k])
                        prev_k = seq[-1][4]
                        prev_t = seq[-1][0]
                    batch_topk.append([s for s in reversed(seq)])
                topk.append(batch_topk)

            # Order the reconstructed sequences by final accumulated score,
            # best first.
            for b in range(batch_size):
                topk[b] = sorted(topk[b], key=lambda s: s[-1][3], reverse=True)

            topk_scores = other_topk['score']
            topk_lengths = other_topk['topk_length']
            topk_pred_symbols = other_topk['topk_sequence']
            for b in range(batch_size):
                # When two adjacent beam scores are numerically close the
                # ranking is ambiguous, so the comparison is abandoned (the
                # break below leaves the batch loop entirely).
                precision_error = False
                for k in range(beam_size - 1):
                    if np.isclose(topk_scores[b][k], topk_scores[b][k + 1]):
                        precision_error = True
                        break
                if precision_error:
                    break
                for k in range(beam_size):
                    # The manual sequence has one extra entry (the SOS start).
                    self.assertEqual(topk_lengths[b][k], len(topk[b][k]) - 1)
                    self.assertTrue(
                        np.isclose(topk_scores[b][k], topk[b][k][-1][3]))
                    total_steps = topk_lengths[b][k]
                    for t in range(total_steps):
                        self.assertEqual(topk_pred_symbols[t][b, k].data[0],
                                         topk[b][k][t +
                                                    1][1])  # topk includes SOS
Ejemplo n.º 5
0
 def test_init(self):
     """Smoke test: a TopKDecoder with beam width 3 can wrap a DecoderRNN."""
     rnn_decoder = DecoderRNN(
         self.vocab_size, 50, 16, 0, 1,
         input_dropout_p=0,
     )
     TopKDecoder(rnn_decoder, 3)
Ejemplo n.º 6
0
def _str2bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` is unsuitable for argparse because ``bool("False")`` is
    True: every non-empty string enables the flag. This converter accepts
    the usual spellings, case-insensitively. The empty string maps to False,
    as ``bool("")`` did.

    Raises:
        ValueError: If the text is not a recognised boolean spelling.
            argparse reports this as an "invalid value" usage error.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError('expected a boolean value, got %r' % (value,))


parser.add_argument('--num_layer', type=int, default=1)
parser.add_argument('--num_class', type=int, default=3)
# Parsed with _str2bool so that "--use_cuda False" really disables the flag.
parser.add_argument('--use_cuda', type=_str2bool, default=True)
parser.add_argument('--use_type', type=str, default='elmo')
parser.add_argument('--class_batch_size', type=int, default=1)
parser.add_argument('--seed', type=int, default=42)
opt = parser.parse_args()

# Seed the Python, NumPy and torch random number generators with the same
# value before anything is loaded.
random.seed(opt.seed)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)

# Load the checkpoint found at opt.expt_dir and take its model.
checkpoint = Checkpoint().load(opt.expt_dir)
model = checkpoint.model

# Rebuild the multi-task model from the loaded components, with the decoder
# wrapped in a TopKDecoder (second argument 3) for inference.
beam_search = Multi_Task(model.embedding_layer, model.encoder, TopKDecoder(model.decoder, 3),
                         model.classification, model.class_encoder, model.norm_encoder, opt=opt)
# Multi_Task(multi_task.encoder, TopKDecoder(multi_task.decoder, 3), multi_task.classification)


# The model is moved to the GPU whenever CUDA is available; opt.use_cuda is
# not consulted here.
if torch.cuda.is_available():
    beam_search = beam_search.cuda()


# NOTE(review): presumably these are the pickled source/target vocabularies
# saved at training time - confirm against load_from_pickle.
input_vocab = load_from_pickle(opt.src_vocab_path)
output_vocab = load_from_pickle(opt.tgt_vocab_path)

predictor = Predictor(beam_search, input_vocab, output_vocab)
# inp_seq = ["This was largely accounted for by seed under 9 years old , about 90% of which is viable .",
#            "MENTION MENTION weddings in the summer in Aruba ofc u guys r my bridesmaids"]
# inp_seq = "MENTION MENTION weddings in the summer in Aruba ofc u guys r my bridesmaids"