Example #1
    def sample_extension(trainer):
        encdec = trainer.updater.get_optimizer("main").target
        iterator = trainer.updater.get_iterator("main")
        mb_raw = iterator.peek()

        def s_unk_tag(num, utag):
            return "S_UNK_%i" % utag

        def t_unk_tag(num, utag):
            return "T_UNK_%i" % utag

        try:
            if encdec.encdec_type() == "ff":
                src_seqs, tgt_seqs = list(six.moves.zip(*mb_raw))
                sample_once_ff(encdec, src_seqs, tgt_seqs, src_indexer, tgt_indexer, max_nb=20,
                    s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)
            else:
                src_batch, tgt_batch, src_mask = make_batch_src_tgt(mb_raw, eos_idx=eos_idx, padding_idx=0, gpu=gpu,
                                                                    need_arg_sort=False, use_chainerx=use_chainerx)
                sample_once(encdec, src_batch, tgt_batch, src_mask, src_indexer, tgt_indexer, eos_idx,
                            max_nb=20,
                            s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)
        except CudaException:
            log.warn("CUDARuntimeError during sample. Skipping sample")
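
This extension is meant to be called periodically by a Chainer trainer: it peeks at the next minibatch and prints a few sample translations without consuming the batch. A minimal registration sketch, assuming a `chainer.training.Trainer` named `trainer` built elsewhere whose "main" iterator provides the custom `peek()` method used above (the trigger interval is illustrative):

    # Run sample_extension every 200 iterations; any callable taking the
    # trainer can be registered as an extension.
    trainer.extend(sample_extension, trigger=(200, "iteration"))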
Example #2
    def sample_extension(trainer):
        encdec = trainer.updater.get_optimizer("main").target
        iterator = trainer.updater.get_iterator("main")
        mb_raw = iterator.peek()

        src_batch, tgt_batch, src_mask = make_batch_src_tgt(
            mb_raw,
            eos_idx=eos_idx,
            padding_idx=0,
            gpu=gpu,
            volatile="on",
            need_arg_sort=False)

        def s_unk_tag(num, utag):
            return "S_UNK_%i" % utag

        def t_unk_tag(num, utag):
            return "T_UNK_%i" % utag

        sample_once(encdec,
                    src_batch,
                    tgt_batch,
                    src_mask,
                    src_indexer,
                    tgt_indexer,
                    eos_idx,
                    max_nb=20,
                    s_unk_tag=s_unk_tag,
                    t_unk_tag=t_unk_tag)
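
Note that `volatile="on"` in this older variant is Chainer v1 API: it keeps the sampling forward pass from building a computational graph. Under Chainer v2 and later the equivalent (a sketch, not code from this project) is to wrap the call in `chainer.no_backprop_mode()`:

    import chainer

    # Forward pass without recording the graph, since sampling needs no backprop.
    with chainer.no_backprop_mode():
        sample_once(encdec, src_batch, tgt_batch, src_mask,
                    src_indexer, tgt_indexer, eos_idx,
                    max_nb=20, s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)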
Example #3
File: training.py    Project: mayutwmu/knmt
def train_on_data(encdec, optimizer, training_data, output_files_dict,
                  src_indexer, tgt_indexer, eos_idx, mb_size=80,
                  nb_of_batch_to_sort=20,
                  test_data=None, dev_data=None, valid_data=None,
                  gpu=None, report_every=200, randomized=False,
                  reverse_src=False, reverse_tgt=False, max_nb_iters=None, do_not_save_data_for_resuming=False,
                  noise_on_prev_word=False, curiculum_training=False,
                  use_previous_prediction=0, no_report_or_save=False,
                  use_memory_optimization=False, sample_every=200,
                  use_reinf=False,
                  save_ckpt_every=2000):
    #     ,
    #                   lexical_probability_dictionary = None,
    #                   V_tgt = None,
    #                   lexicon_prob_epsilon = 1e-3):

    if curiculum_training:
        log.info("Sorting training data by complexity")
        training_data_sorted_by_complexity = sorted(training_data, key=example_complexity)
        log.info("done")

        for s, t in training_data_sorted_by_complexity[:400]:
            print example_complexity((s, t))
            print src_indexer.deconvert(s)
            print tgt_indexer.deconvert(t)
            print

        mb_provider = minibatch_provider_curiculum(training_data_sorted_by_complexity, eos_idx, mb_size, nb_of_batch_to_sort, gpu=gpu,
                                                   randomized=randomized, sort_key=lambda x: len(x[0]),
                                                   reverse_src=reverse_src, reverse_tgt=reverse_tgt)
    else:
        mb_provider = minibatch_provider(training_data, eos_idx, mb_size, nb_of_batch_to_sort, gpu=gpu,
                                         randomized=randomized, sort_key=lambda x: len(x[0]),
                                         reverse_src=reverse_src, reverse_tgt=reverse_tgt)
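
    # Both providers yield (src_batch, tgt_batch, src_mask) minibatches,
    # grouping nb_of_batch_to_sort batches at a time and sorting them by
    # source length (sort_key) so that padding within a batch stays small.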

#     mb_provider = minibatch_provider(training_data, eos_idx, mb_size, nb_of_batch_to_sort, gpu = gpu,
#                                      randomized = randomized, sort_key = lambda x:len(x[1]))

    def s_unk_tag(num, utag):
        return "S_UNK_%i" % utag

    def t_unk_tag(num, utag):
        return "T_UNK_%i" % utag

    def save_model(suffix):
        if suffix == "final":
            fn_save = output_files_dict["model_final"]
        elif suffix == "ckpt":
            fn_save = output_files_dict["model_ckpt"]
        elif suffix == "best":
            fn_save = output_files_dict["model_best"]
        elif suffix == "best_loss":
            fn_save = output_files_dict["model_best_loss"]
        else:
            assert False
        log.info("saving model to %s" % fn_save)
        serializers.save_npz(fn_save, encdec)
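
    # A sketch (assumed usage, not part of this file) of restoring one of the
    # snapshots written above: rebuild an encdec with the same architecture,
    # then reload its parameters with
    #     serializers.load_npz(output_files_dict["model_best"], encdec)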

    def train_once(src_batch, tgt_batch, src_mask):  # , lexicon_matrix = None):
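        # One standard update: zero the gradients, run the forward pass
        # (raw_loss_info=True makes encdec return the summed loss and the
        # number of target predictions), backpropagate the per-token loss,
        # and apply the optimizer. Returns (summed loss, nb predictions) so
        # the reporting loop can average the loss per token.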
        t0 = time.clock()
        encdec.zerograds()
        t1 = time.clock()
        (total_loss, total_nb_predictions), attn = encdec(src_batch, tgt_batch, src_mask, raw_loss_info=True,
                                                          noise_on_prev_word=noise_on_prev_word,
                                                          use_previous_prediction=use_previous_prediction,
                                                          mode="train")
#         ,
#                                                           lexicon_probability_matrix = lexicon_matrix,
#                                                           lex_epsilon = lexicon_prob_epsilon)
        loss = total_loss / total_nb_predictions
        t2 = time.clock()
        loss.backward()
        t3 = time.clock()
        optimizer.update()
        t4 = time.clock()
        print "loss:", loss.data,
        print " time %f zgrad:%f fwd:%f bwd:%f upd:%f" % (t4 - t0, t1 - t0, t2 - t1, t3 - t2, t4 - t3)
        return float(total_loss.data), total_nb_predictions

    def train_once_optim(src_batch, tgt_batch, src_mask):
        t0 = time.clock()
        encdec.zerograds()
        t1 = time.clock()
        loss, total_nb_predictions = encdec.compute_loss_and_backward(
            src_batch, tgt_batch, src_mask)
        t2 = time.clock()
        print "loss:", loss,
        t3 = time.clock()
        optimizer.update()
        t4 = time.clock()
        print " time %f zgrad:%f fwd:%f bwd:%f upd:%f" % (t4 - t0, t1 - t0, t2 - t1, t3 - t2, t4 - t3)
        return float(loss) * total_nb_predictions, total_nb_predictions

    def train_once_reinf(src_batch, tgt_batch, src_mask):  # , lexicon_matrix = None):
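        # REINFORCE-style update: samples nb_samples candidate translations of
        # up to nb_steps tokens from the current model and backpropagates the
        # reinforcement loss computed by get_reinf_loss against the references.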
        t0 = time.clock()
        encdec.zerograds()
        t1 = time.clock()

        from nmt_chainer.utilities import utils
        test_ref = utils.de_batch(tgt_batch, is_variable=True)

        reinf_loss = encdec.get_reinf_loss(src_batch, src_mask, eos_idx,
                                           test_ref, nb_steps=50, nb_samples=5,
                                           use_best_for_sample=False,
                                           temperature=None,
                                           mode="test")

        t2 = time.clock()
        reinf_loss.backward()
        t3 = time.clock()
        optimizer.update()
        t4 = time.clock()
        print "reinf loss:", reinf_loss.data, reinf_loss.data / len(src_batch)
        print " time %f zgrad:%f fwd:%f bwd:%f upd:%f" % (t4 - t0, t1 - t0, t2 - t1, t3 - t2, t4 - t3)
        return float(reinf_loss.data), len(src_batch)

    if test_data is not None:
        test_src_data = [x for x, y in test_data]
        test_references = [y for x, y in test_data]

        def translate_test():
            translations_fn = output_files_dict["test_translation_output"]  # save_prefix + ".test.out"
            control_src_fn = output_files_dict["test_src_output"]  # save_prefix + ".test.src.out"
            return translate_to_file(encdec, eos_idx, test_src_data, mb_size, tgt_indexer,
                                     translations_fn, test_references=test_references, control_src_fn=control_src_fn,
                                     src_indexer=src_indexer, gpu=gpu, nb_steps=50, reverse_src=reverse_src, reverse_tgt=reverse_tgt,
                                     s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)

        def compute_test_loss():
            log.info("computing test loss")
            test_loss = compute_loss_all(encdec, test_data, eos_idx, mb_size, gpu=gpu,
                                         reverse_src=reverse_src, reverse_tgt=reverse_tgt)
            log.info("test loss: %f" % test_loss)
            return test_loss
    else:
        def translate_test():
            log.info("translate_test: No test data given")
            return None

        def compute_test_loss():
            log.info("compute_test_loss: No test data given")
            return None

    if dev_data is not None:
        dev_src_data = [x for x, y in dev_data]
        dev_references = [y for x, y in dev_data]

        def translate_dev():
            translations_fn = output_files_dict["dev_translation_output"]  # save_prefix + ".test.out"
            control_src_fn = output_files_dict["dev_src_output"]  # save_prefix + ".test.src.out"
            return translate_to_file(encdec, eos_idx, dev_src_data, mb_size, tgt_indexer,
                                     translations_fn, test_references=dev_references, control_src_fn=control_src_fn,
                                     src_indexer=src_indexer, gpu=gpu, nb_steps=50, reverse_src=reverse_src, reverse_tgt=reverse_tgt,
                                     s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)

        def compute_dev_loss():
            log.info("computing dev loss")
            dev_loss = compute_loss_all(encdec, dev_data, eos_idx, mb_size, gpu=gpu,
                                        reverse_src=reverse_src, reverse_tgt=reverse_tgt)
            log.info("dev loss: %f" % dev_loss)
            return dev_loss
    else:
        def translate_dev():
            log.info("translate_dev: No dev data given")
            return None

        def compute_dev_loss():
            log.info("compute_dev_loss: No dev data given")
            return None

    if valid_data is not None:
        valid_src_data = [x for x, y in valid_data]
        valid_references = [y for x, y in valid_data]

        def translate_valid():
            translations_fn = output_files_dict["valid_translation_output"]  # save_prefix + ".test.out"
            control_src_fn = output_files_dict["valid_src_output"]  # save_prefix + ".test.src.out"
            return translate_to_file(encdec, eos_idx, valid_src_data, mb_size, tgt_indexer,
                                     translations_fn, test_references=valid_references, control_src_fn=control_src_fn,
                                     src_indexer=src_indexer, gpu=gpu, nb_steps=50, reverse_src=reverse_src, reverse_tgt=reverse_tgt,
                                     s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)

        def compute_valid_loss():
            log.info("computing valid loss")
            valid_loss = compute_loss_all(encdec, valid_data, eos_idx, mb_size, gpu=gpu,
                                          reverse_src=reverse_src, reverse_tgt=reverse_tgt)
            log.info("valid loss: %f" % valid_loss)
            return valid_loss
    else:
        def translate_valid():
            log.info("translate_valid: No valid data given")
            return None

        def compute_valid_loss():
            log.info("compute_valid_loss: No valid data given")
            return None

    try:
        best_dev_bleu = 0
        best_dev_loss = None
        prev_time = time.clock()
        prev_i = None
        total_loss_this_interval = 0
        total_nb_predictions_this_interval = 0
        for i in xrange(sys.maxsize):
            if max_nb_iters is not None and max_nb_iters <= i:
                break
            print i,
            src_batch, tgt_batch, src_mask = mb_provider.next()
            if src_batch[0].data.shape[0] != mb_size:
                log.warn("got minibatch of size %i instead of %i" % (src_batch[0].data.shape[0], mb_size))

#             if lexical_probability_dictionary is not None:
#                 lexicon_matrix = utils.compute_lexicon_matrix(src_batch, lexical_probability_dictionary)
#                 if gpu is not None:
#                     lexicon_matrix = cuda.to_gpu(lexicon_matrix, gpu)
#             else:
#                 lexicon_matrix = None

#             if i%100 == 0:
#                 print "valid",
#                 compute_valid()
            if not no_report_or_save:
                if i % sample_every == 0:
                    for v in src_batch + tgt_batch:
                        v.volatile = "on"
                    sample_once(encdec, src_batch, tgt_batch, src_mask, src_indexer, tgt_indexer, eos_idx,
                                max_nb=20,
                                s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)
                    for v in src_batch + tgt_batch:
                        v.volatile = "off"
                if i % report_every == 0:
                    current_time = time.clock()
                    if prev_i is not None:
                        iteration_interval = i - prev_i
                        avg_time = (current_time - prev_time) / (iteration_interval)
                        avg_training_loss = total_loss_this_interval / total_nb_predictions_this_interval
                        avg_sentence_size = float(total_nb_predictions_this_interval) / (iteration_interval * mb_size)

                    else:
                        avg_time = 0
                        avg_training_loss = 0
                        avg_sentence_size = 0
                    prev_i = i
                    total_loss_this_interval = 0
                    total_nb_predictions_this_interval = 0

                    print "avg time:", avg_time
                    print "avg training loss:", avg_training_loss
                    print "avg sentence size", avg_sentence_size

                    bc_test = translate_test()
                    test_loss = compute_test_loss()
                    bc_dev = translate_dev()
                    dev_loss = compute_dev_loss()
                    bc_valid = translate_valid()
                    valid_loss = compute_valid_loss()

                    if dev_loss is not None and (best_dev_loss is None or dev_loss <= best_dev_loss):
                        best_dev_loss = dev_loss
                        log.info("saving best loss model %f" % best_dev_loss)
                        save_model("best_loss")

                    if bc_test is not None:

                        assert test_loss is not None
                        import sqlite3
                        import datetime
                        db_path = output_files_dict["sqlite_db"]
                        log.info("saving test results to %s" % (db_path))
                        db_connection = sqlite3.connect(db_path)
                        db_cursor = db_connection.cursor()
                        db_cursor.execute('''CREATE TABLE IF NOT EXISTS exp_data
                            (date text, bleu_info text, iteration real,
                             loss real, bleu real,
                             dev_loss real, dev_bleu real,
                             valid_loss real, valid_bleu real,
                             avg_time real, avg_training_loss real)''')
                        infos = (datetime.datetime.now().strftime("%I:%M%p %B %d, %Y"),
                                 repr(bc_test), i, float(test_loss), bc_test.bleu(),
                                 float(dev_loss) if dev_loss is not None else None, bc_dev.bleu() if bc_dev is not None else None,
                                 float(valid_loss) if valid_loss is not None else None, bc_valid.bleu() if bc_valid is not None else None,
                                 avg_time, avg_training_loss)
                        db_cursor.execute("INSERT INTO exp_data VALUES (?,?,?,?,?,?,?,?,?,?,?)", infos)
                        db_connection.commit()
                        db_connection.close()
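
                        # A sketch (assumed, not from this file) of reading the
                        # logged results back later, e.g. for plotting:
                        #     conn = sqlite3.connect(db_path)
                        #     for row in conn.execute(
                        #             "SELECT iteration, loss, bleu, dev_loss,"
                        #             " dev_bleu FROM exp_data ORDER BY iteration"):
                        #         print row
                        #     conn.close()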

                        if bc_dev is not None and bc_dev.bleu() > best_dev_bleu:
                            best_dev_bleu = bc_dev.bleu()
                            log.info("saving best model %f" % best_dev_bleu)
                            save_model("best")
                    prev_time = time.clock()
                if i % save_ckpt_every == 0:
                    save_model("ckpt")
                    fn_save_optimizer = output_files_dict["optimizer_ckpt"]
                    log.info(
                        "saving optimizer parameters to %s" %
                        fn_save_optimizer)
                    serializers.save_npz(fn_save_optimizer, optimizer)

            if use_memory_optimization:
                #                 if lexicon_matrix is not None:
                #                     raise NotImplemented
                total_loss, total_nb_predictions = train_once_optim(src_batch, tgt_batch, src_mask)
            elif use_reinf:
                total_loss, total_nb_predictions = train_once_reinf(src_batch, tgt_batch, src_mask)
            else:
                total_loss, total_nb_predictions = train_once(src_batch, tgt_batch, src_mask)
#                 ,
#                                                               lexicon_matrix = lexicon_matrix)

            total_loss_this_interval += total_loss
            total_nb_predictions_this_interval += total_nb_predictions
    finally:
        if not do_not_save_data_for_resuming and not no_report_or_save:
            save_model("final")
            fn_save_optimizer = output_files_dict["optimizer_final"]
            log.info("saving optimizer parameters to %s" % fn_save_optimizer)
            serializers.save_npz(fn_save_optimizer, optimizer)
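
A hypothetical invocation sketch (all names and values below are illustrative, not from the repository). `training_data` is assumed to be a list of `(source_indices, target_indices)` pairs, as implied by the `deconvert` calls above, and `output_files_dict` must provide at least the keys that the enabled code paths read:

    output_files_dict = {
        "model_final": "model.final.npz",
        "model_ckpt": "model.ckpt.npz",
        "model_best": "model.best.npz",
        "model_best_loss": "model.best_loss.npz",
        "optimizer_ckpt": "optimizer.ckpt.npz",
        "optimizer_final": "optimizer.final.npz",
        "test_translation_output": "test.out",
        "test_src_output": "test.src.out",
        "sqlite_db": "results.db",
    }

    train_on_data(encdec, optimizer, training_data, output_files_dict,
                  src_indexer, tgt_indexer, eos_idx,
                  mb_size=64, gpu=None, max_nb_iters=10000)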