Example 1
def main():
    global device
    args = parse()
    device = torch.device('cpu')  # fall back to CPU when no usable GPU (assumed default)
    if args.gpu_id >= 0 and torch.cuda.is_available():
        device = torch.device('cuda:' + str(args.gpu_id))
    s_vocab = pickle.load(open(args.s_vocab, 'rb'))
    t_vocab = pickle.load(open(args.t_vocab, 'rb'))
    vs, es, hs = args.vocab_size, args.embed_size, args.hidden_size
    if args.model_type == 'EncDec':
        model = models.EncoderDecoder(
            s_vocab_size=vs, t_vocab_size=vs, hidden_size=hs, embed_size=es, weight_decay=1e-5
        ).to(device)
    elif args.model_type == 'Attn':
        model = models.AttentionSeq2Seq(
            s_vocab_size=vs, t_vocab_size=vs, embed_size=es, hidden_size=hs,
            num_s_layers=2, bidirectional=True, weight_decay=1e-5
        ).to(device)
    else:
        sys.stderr.write('Unknown model type %s. Model type must be `EncDec` or `Attn`.\n' % args.model_type)
        sys.exit(1)  # `model` would be undefined past this point
    model.load_state_dict(torch.load(args.model_prefix + '.model'))
    translate(args.src, model, s_vocab, t_vocab, args.output, device, 100, reverse=args.reverse)
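The snippet relies on a parse() helper that is not shown. Below is a minimal sketch of what it presumably returns, reconstructed only from the attributes main() reads; all flag names, defaults, and help strings here are inferred, not taken from the original project.

import argparse

def parse():
    # Hypothetical reconstruction: only the attributes used in main() above.
    p = argparse.ArgumentParser(description='Translate with a trained seq2seq model')
    p.add_argument('--gpu_id', type=int, default=-1, help='GPU id, -1 for CPU')
    p.add_argument('--s_vocab', required=True, help='pickled source vocabulary')
    p.add_argument('--t_vocab', required=True, help='pickled target vocabulary')
    p.add_argument('--vocab_size', type=int, default=30000)
    p.add_argument('--embed_size', type=int, default=256)
    p.add_argument('--hidden_size', type=int, default=256)
    p.add_argument('--model_type', choices=['EncDec', 'Attn'], default='EncDec')
    p.add_argument('--model_prefix', default='model', help='prefix of the saved .model file')
    p.add_argument('--src', required=True, help='source sentences to translate')
    p.add_argument('--output', default='output.txt')
    p.add_argument('--reverse', action='store_true')
    return p.parse_args()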
Example 2
def command_line2():
    import argparse
    parser = argparse.ArgumentParser(
        description="Use a RNNSearch model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("lattice_fn")
    parser.add_argument("source_sentence_fn")
    parser.add_argument("training_config", help="prefix of the trained model")
    parser.add_argument("trained_model", help="prefix of the trained model")

    parser.add_argument("--gpu",
                        type=int,
                        help="specify gpu number to use, if any")
    parser.add_argument("--skip_in_src", type=int, default=0)
    args = parser.parse_args()

    config_training_fn = args.training_config  # args.model_prefix + ".train.config"

    log.info("loading model config from %s" % config_training_fn)
    config_training = json.load(open(config_training_fn))

    voc_fn = config_training["voc"]
    log.info("loading voc from %s" % voc_fn)
    src_voc, tgt_voc = json.load(open(voc_fn))

    src_indexer = Indexer.make_from_serializable(src_voc)
    tgt_indexer = Indexer.make_from_serializable(tgt_voc)
    tgt_voc = None
    src_voc = None

    #     Vi = len(src_voc) + 1 # + UNK
    #     Vo = len(tgt_voc) + 1 # + UNK

    Vi = len(src_indexer)  # + UNK
    Vo = len(tgt_indexer)  # + UNK

    print config_training

    Ei = config_training["command_line"]["Ei"]
    Hi = config_training["command_line"]["Hi"]
    Eo = config_training["command_line"]["Eo"]
    Ho = config_training["command_line"]["Ho"]
    Ha = config_training["command_line"]["Ha"]
    Hl = config_training["command_line"]["Hl"]

    eos_idx = Vo
    encdec = models.EncoderDecoder(Vi, Ei, Hi, Vo + 1, Eo, Ho, Ha, Hl)

    log.info("loading model from %s" % args.trained_model)
    serializers.load_npz(args.trained_model, encdec)

    if args.gpu is not None:
        encdec = encdec.to_gpu(args.gpu)

    src_sent_f = codecs.open(args.source_sentence_fn, encoding="utf8")
    for _ in xrange(args.skip_in_src):
        src_sent_f.readline()
    src_sentence = src_sent_f.readline().strip().split(" ")
    log.info("translating sentence %s" % (" ".join(src_sentence)))
    src_seq = src_indexer.convert(src_sentence)
    log.info("src seq: %r" % src_seq)

    log.info("loading lattice %s" % args.lattice_fn)
    lattice_f = codecs.open(args.lattice_fn, "r", encoding="utf8")
    all_edges = parse_lattice_file(lattice_f)
    log.info("loaded")

    lattice_map = [None] * len(all_edges)
    for num_lattice, edge_list in enumerate(all_edges):
        lattice_map[num_lattice] = Lattice(edge_list)
        top_lattice_id = num_lattice  # the last lattice parsed is treated as the top-level one

    log.info("built lattices")

    log.info("removing epsilons")
    log.info("nb edges before %i" % sum(
        len(edge_list) for lattice in lattice_map
        for edge_list in lattice.outgoing.itervalues()))
    remove_all_epsilons(lattice_map)
    log.info("nb edges before %i" % sum(
        len(edge_list) for lattice in lattice_map
        for edge_list in lattice.outgoing.itervalues()))

    if args.gpu is not None:
        seq_as_batch = [
            Variable(cuda.to_gpu(np.array([x], dtype=np.int32), args.gpu),
                     volatile="on") for x in src_seq
        ]
    else:
        seq_as_batch = [
            Variable(np.array([x], dtype=np.int32), volatile="on")
            for x in src_seq
        ]
    predictor = encdec.get_predictor(seq_as_batch, [])

    global_memoizer = {}
    global_count_memoizer = {}
    initial_node = Node(top_lattice_id)
    initial_node.add_elem(PosElem(Lattice.kInitial))
    current_path = initial_node
    selected_seq = []
    while True:
        print "#node current_path", current_path.count_distincts_subnodes()
        current_path.assert_is_reduced_and_consistent()
        next_words_set = current_path.get_next_w(lattice_map, global_memoizer,
                                                 global_count_memoizer)
        for w in next_words_set:
            next_words_set[w] = sum(next_words_set[w].itervalues())
        has_eos = Lattice.EOS in next_words_set
        next_words_list = sorted(
            list(w for w in next_words_set if w != Lattice.EOS))
        print "next_words_set", next_words_set
        voc_choice = tgt_indexer.convert(next_words_list)
        if has_eos:
            voc_choice.append(eos_idx)
        chosen = predictor(voc_choice)

        if chosen != eos_idx and tgt_indexer.is_unk_idx(chosen):
            print "warning: unk chosen"
            unk_list = []
            for ix, t_idx in enumerate(voc_choice):
                if tgt_indexer.is_unk_idx(t_idx):
                    unk_list.append((next_words_set[next_words_list[ix]],
                                     next_words_list[ix]))
            unk_list.sort(reverse=True)
            print "UNK:", unk_list
            selected_w = unk_list[0][1]
        else:
            idx_chosen = voc_choice.index(
                chosen
            )  # TODO: better handling when several tgt candidates map to UNK

            selected_w = (next_words_list + [Lattice.EOS])[idx_chosen]


#         for num_word, word in enumerate(next_words_list):
#             print num_word, word
#         print "selected_seq", selected_seq
#         i = int(raw_input("choice\n"))
#         selected_w = next_words_list[i]
#

        selected_seq.append(selected_w)
        print "selected_seq", selected_seq

        current_path.update_better(selected_w, lattice_map, global_memoizer)
        current_path.reduce()
        if current_path.is_empty_node():
            print "DONE"
            break
    print "final seq:", selected_seq
Example 3
def create_encdec_from_config(config_training):

    voc_fn = config_training["voc"]
    log.info("loading voc from %s"% voc_fn)
    src_voc, tgt_voc = json.load(open(voc_fn))
    
    src_indexer = Indexer.make_from_serializable(src_voc)
    tgt_indexer = Indexer.make_from_serializable(tgt_voc)
    tgt_voc = None
    src_voc = None
    
    
#     Vi = len(src_voc) + 1 # + UNK
#     Vo = len(tgt_voc) + 1 # + UNK
    
    Vi = len(src_indexer) # + UNK
    Vo = len(tgt_indexer) # + UNK
    
    print config_training
    
    Ei = config_training["command_line"]["Ei"]
    Hi = config_training["command_line"]["Hi"]
    Eo = config_training["command_line"]["Eo"]
    Ho = config_training["command_line"]["Ho"]
    Ha = config_training["command_line"]["Ha"]
    Hl = config_training["command_line"]["Hl"]
    
    is_multitarget = config_training["is_multitarget"]

    if is_multitarget:
        print "Last state of backward encoder RNN is first state of decoder RNN."

    encoder_cell_type = config_training["command_line"].get("encoder_cell_type", "gru")
    decoder_cell_type = config_training["command_line"].get("decoder_cell_type", "gru")
    
    use_bn_length = config_training["command_line"].get("use_bn_length", None)
    
    import gzip
    
    if "lexical_probability_dictionary" in config_training["command_line"] and config_training["command_line"]["lexical_probability_dictionary"] is not None:
        log.info("opening lexical_probability_dictionary %s" % config_training["command_line"]["lexical_probability_dictionary"])
        lexical_probability_dictionary_all = json.load(gzip.open(config_training["command_line"]["lexical_probability_dictionary"], "rb"))
        log.info("computing lexical_probability_dictionary_indexed")
        lexical_probability_dictionary_indexed = {}
        for ws in lexical_probability_dictionary_all:
            ws_idx = src_indexer.convert([ws])[0]
            if ws_idx in lexical_probability_dictionary_indexed:
                assert src_indexer.is_unk_idx(ws_idx)
            else:
                lexical_probability_dictionary_indexed[ws_idx] = {}
            for wt in lexical_probability_dictionary_all[ws]:
                wt_idx = tgt_indexer.convert([wt])[0]
                if wt_idx in lexical_probability_dictionary_indexed[ws_idx]:
                    assert src_indexer.is_unk_idx(ws_idx) or tgt_indexer.is_unk_idx(wt_idx)
                    lexical_probability_dictionary_indexed[ws_idx][wt_idx] += lexical_probability_dictionary_all[ws][wt]
                else:
                    lexical_probability_dictionary_indexed[ws_idx][wt_idx] = lexical_probability_dictionary_all[ws][wt]
        lexical_probability_dictionary = lexical_probability_dictionary_indexed
    else:
        lexical_probability_dictionary = None
    
    eos_idx = Vo
    encdec = models.EncoderDecoder(Vi, Ei, Hi, Vo + 1, Eo, Ho, Ha, Hl,
                                   use_bn_length=use_bn_length,
                                   encoder_cell_type=rnn_cells.create_cell_model_from_string(encoder_cell_type),
                                   decoder_cell_type=rnn_cells.create_cell_model_from_string(decoder_cell_type),
                                   lexical_probability_dictionary=lexical_probability_dictionary,
                                   lex_epsilon=config_training["command_line"].get("lexicon_prob_epsilon", 0.001),
                                   is_multitarget=is_multitarget)
    
    return encdec, eos_idx, src_indexer, tgt_indexer
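A minimal usage sketch, assuming a training config previously saved as JSON by the training script (the file name is illustrative):

import json

config_training = json.load(open("experiment.train.config"))  # hypothetical path
encdec, eos_idx, src_indexer, tgt_indexer = create_encdec_from_config(config_training)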
Example 4
    return all_result, all_loss.mean().item()  # presumably the tail of the eval_data definition used below

if args.encode_savepath != '':
    start_time = time.time()
    cp = torch.load(args.load_state)
    c_args = cp['args']
    c_args.load_state = args.load_state
    c_args.encode_savepath = args.encode_savepath
    c_args.test_data = args.test_data
    args = c_args
    print('checkpoint arguments loaded')
    print(args)
    print(time.time() - start_time)
    #start_time = time.time()
    testdata = readdata.readfile(args.test_data, args.batch, args.max_length, 'cut', False)
    model = cuda(models.EncoderDecoder(2, args.hidden_size, args.layers, args.dropout, False), args.cuda)
    model.load_state_dict(cp['state_dict'])
    loss = MSE
    print(time.time() - start_time)
    #start_time = time.time()
    all_test_result, all_test_loss = eval_data(testdata, 100)
    all_test_result = np.array(all_test_result)
    pickle.dump(all_test_result, open(args.encode_savepath, 'wb'))
    print(time.time() - start_time)
    exit()

print(args)

if not os.path.exists(args.checkpoint):
    os.mkdir(args.checkpoint)
if not os.path.isdir(args.checkpoint):
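The excerpt calls a cuda(module, flag) helper that is not shown; a plausible sketch, assuming args.cuda is a boolean (both the helper and that assumption are inferred from the call site above):

import torch

def cuda(module, use_cuda):
    # Move the module to GPU only when requested and a GPU is available.
    if use_cuda and torch.cuda.is_available():
        return module.cuda()
    return module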
Example 5
def command_line(arguments=None):
    import argparse
    parser = argparse.ArgumentParser(
        description="Train a RNNSearch model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "data_prefix",
        help="prefix of the training data created by make_data.py")
    parser.add_argument(
        "save_prefix",
        help="prefix to be added to all files created during the training")
    parser.add_argument("--gpu",
                        type=int,
                        nargs="+",
                        default=None,
                        help="specify gpu number to use, if any")
    #parser.add_argument("--gpulist", type = int, nargs = "+", default = None, help = "specify gpu number to use, if any")
    parser.add_argument(
        "--load_model",
        help="load the parameters of a previously trained model")
    parser.add_argument("--load_optimizer_state",
                        help="load previously saved optimizer states")
    parser.add_argument("--Ei",
                        type=int,
                        default=620,
                        help="Source words embedding size.")
    parser.add_argument("--Eo",
                        type=int,
                        default=620,
                        help="Target words embedding size.")
    parser.add_argument("--Hi",
                        type=int,
                        default=1000,
                        help="Source encoding layer size.")
    parser.add_argument("--Ho",
                        type=int,
                        default=1000,
                        help="Target hidden layer size.")
    parser.add_argument("--Ha",
                        type=int,
                        default=1000,
                        help="Attention Module Hidden layer size.")
    parser.add_argument("--Hl",
                        type=int,
                        default=500,
                        help="Maxout output size.")
    parser.add_argument("--mb_size",
                        type=int,
                        default=80,
                        help="Minibatch size")
    parser.add_argument("--nb_batch_to_sort",
                        type=int,
                        default=20,
                        help="Sort this many batches by size.")
    parser.add_argument("--noise_on_prev_word",
                        default=False,
                        action="store_true")

    parser.add_argument(
        "--use_memory_optimization",
        default=False,
        action="store_true",
        help="Experimental option that could strongly reduce memory used.")

    parser.add_argument("--max_nb_iters",
                        type=int,
                        default=None,
                        help="maximum number of iterations")

    parser.add_argument("--max_src_tgt_length",
                        type=int,
                        help="Limit length of training sentences")

    parser.add_argument("--l2_gradient_clipping",
                        type=float,
                        default=1,
                        help="L2 gradient clipping. 0 for None")

    parser.add_argument("--hard_gradient_clipping",
                        type=float,
                        nargs=2,
                        help="hard gradient clipping.")

    parser.add_argument("--weight_decay",
                        type=float,
                        help="Weight decay value. ")

    parser.add_argument("--optimizer",
                        choices=[
                            "sgd", "rmsprop", "rmspropgraves", "momentum",
                            "nesterov", "adam", "adagrad", "adadelta"
                        ],
                        default="adam",
                        help="Optimizer type.")
    parser.add_argument("--learning_rate",
                        type=float,
                        default=0.01,
                        help="Learning Rate")
    parser.add_argument("--momentum",
                        type=float,
                        default=0.9,
                        help="Momentum term")
    parser.add_argument("--report_every",
                        type=int,
                        default=200,
                        help="report every x iterations")
    parser.add_argument("--randomized_data",
                        default=False,
                        action="store_true")
    parser.add_argument("--use_accumulated_attn",
                        default=False,
                        action="store_true")

    parser.add_argument("--use_deep_attn", default=False, action="store_true")

    parser.add_argument("--no_shuffle_of_training_data",
                        default=False,
                        action="store_true")
    parser.add_argument("--no_resume", default=False, action="store_true")

    parser.add_argument("--init_orth", default=False, action="store_true")

    parser.add_argument("--reverse_src", default=False, action="store_true")
    parser.add_argument("--reverse_tgt", default=False, action="store_true")

    parser.add_argument("--curiculum_training",
                        default=False,
                        action="store_true")

    parser.add_argument("--use_bn_length", default=0, type=int)
    parser.add_argument("--use_previous_prediction", default=0, type=float)

    parser.add_argument("--no_report_or_save",
                        default=False,
                        action="store_true")

    parser.add_argument(
        "--lexical_probability_dictionary",
        help=
        "lexical translation probabilities in zipped JSON format. Used to implement https://arxiv.org/abs/1606.02006"
    )
    parser.add_argument(
        "--lexicon_prob_epsilon",
        default=1e-3,
        type=float,
        help="epsilon value for combining the lexical probabilities")

    parser.add_argument(
        "--encoder_cell_type",
        default="lstm",
        help=
        "cell type of encoder. format: type,param1:val1,param2:val2,... where type is in [%s]"
        % (" ".join(rnn_cells.cell_dict.keys())))
    parser.add_argument(
        "--decoder_cell_type",
        default="lstm",
        help="cell type of decoder. format same as for encoder")

    parser.add_argument("--sample_every", default=200, type=int)

    parser.add_argument("--save_ckpt_every", default=4000, type=int)

    parser.add_argument("--use_reinf", default=False, action="store_true")

    parser.add_argument("--is_multitarget", default=False, action="store_true")
    parser.add_argument(
        "--postprocess",
        default=False,
        action="store_true",
        help=
        "This flag indicates whether the translations should be postprocessed or not. For now it simply indicates that the BPE segmentation should be undone."
    )

    args = parser.parse_args(args=arguments)

    output_files_dict = {}
    output_files_dict["train_config"] = args.save_prefix + ".train.config"
    output_files_dict["model_ckpt"] = args.save_prefix + ".model.ckpt.npz"
    output_files_dict["model_final"] = args.save_prefix + ".model.final.npz"
    output_files_dict["model_best"] = args.save_prefix + ".model.best.npz"
    output_files_dict["model_best_loss"] = args.save_prefix + ".model.best_loss.npz"

    output_files_dict["test_translation_output"] = args.save_prefix + ".test.out"
    output_files_dict["test_src_output"] = args.save_prefix + ".test.src.out"
    output_files_dict["dev_translation_output"] = args.save_prefix + ".dev.out"
    output_files_dict["dev_src_output"] = args.save_prefix + ".dev.src.out"
    output_files_dict["valid_translation_output"] = args.save_prefix + ".valid.out"
    output_files_dict["valid_src_output"] = args.save_prefix + ".valid.src.out"
    output_files_dict["sqlite_db"] = args.save_prefix + ".result.sqlite"
    output_files_dict["optimizer_ckpt"] = args.save_prefix + ".optimizer.ckpt.npz"
    output_files_dict["optimizer_final"] = args.save_prefix + ".optimizer.final.npz"

    save_prefix_dir, save_prefix_fn = os.path.split(args.save_prefix)
    ensure_path(save_prefix_dir)

    already_existing_files = []
    for key_info, filename in output_files_dict.iteritems():
        if os.path.exists(filename):
            already_existing_files.append(filename)
    if len(already_existing_files) > 0:
        print "Warning: existing files are going to be replaced / updated: ", already_existing_files
        #raw_input("Press Enter to Continue")

    config_fn = args.data_prefix + ".data.config"
    voc_fn = args.data_prefix + ".voc"
    data_fn = args.data_prefix + ".data.json.gz"

    log.info("loading training data from %s" % data_fn)
    training_data_all = json.load(gzip.open(data_fn, "rb"))

    training_data = training_data_all["train"]

    log.info("loaded %i sentences as training data" % len(training_data))

    if "test" in training_data_all:
        test_data = training_data_all["test"]
        log.info("Found test data: %i sentences" % len(test_data))
    else:
        test_data = None
        log.info("No test data found")

    if "dev" in training_data_all:
        dev_data = training_data_all["dev"]
        log.info("Found dev data: %i sentences" % len(dev_data))
    else:
        dev_data = None
        log.info("No dev data found")

    if "valid" in training_data_all:
        valid_data = training_data_all["valid"]
        log.info("Found valid data: %i sentences" % len(valid_data))
    else:
        valid_data = None
        log.info("No valid data found")

    log.info("loading voc from %s" % voc_fn)
    src_voc, tgt_voc = json.load(open(voc_fn))

    src_indexer = Indexer.make_from_serializable(src_voc)
    tgt_indexer = Indexer.make_from_serializable(tgt_voc)
    tgt_voc = None
    src_voc = None

    #     Vi = len(src_voc) + 1 # + UNK
    #     Vo = len(tgt_voc) + 1 # + UNK

    Vi = len(src_indexer)  # + UNK
    Vo = len(tgt_indexer)  # + UNK

    if args.lexical_probability_dictionary is not None:
        log.info("opening lexical_probability_dictionary %s" %
                 args.lexical_probability_dictionary)
        lexical_probability_dictionary_all = json.load(
            gzip.open(args.lexical_probability_dictionary, "rb"))
        log.info("computing lexical_probability_dictionary_indexed")
        lexical_probability_dictionary_indexed = {}
        for ws in lexical_probability_dictionary_all:
            ws_idx = src_indexer.convert([ws])[0]
            if ws_idx in lexical_probability_dictionary_indexed:
                assert src_indexer.is_unk_idx(ws_idx)
            else:
                lexical_probability_dictionary_indexed[ws_idx] = {}
            for wt in lexical_probability_dictionary_all[ws]:
                wt_idx = tgt_indexer.convert([wt])[0]
                if wt_idx in lexical_probability_dictionary_indexed[ws_idx]:
                    assert src_indexer.is_unk_idx(
                        ws_idx) or tgt_indexer.is_unk_idx(wt_idx)
                    lexical_probability_dictionary_indexed[ws_idx][
                        wt_idx] += lexical_probability_dictionary_all[ws][wt]
                else:
                    lexical_probability_dictionary_indexed[ws_idx][
                        wt_idx] = lexical_probability_dictionary_all[ws][wt]
        lexical_probability_dictionary = lexical_probability_dictionary_indexed
    else:
        lexical_probability_dictionary = None

    if args.max_src_tgt_length is not None:
        log.info("filtering sentences of length larger than %i" %
                 (args.max_src_tgt_length))
        filtered_training_data = []
        nb_filtered = 0
        for src, tgt in training_data:
            if len(src) <= args.max_src_tgt_length and len(
                    tgt) <= args.max_src_tgt_length:
                filtered_training_data.append((src, tgt))
            else:
                nb_filtered += 1
        log.info("filtered %i sentences of length larger than %i" %
                 (nb_filtered, args.max_src_tgt_length))
        training_data = filtered_training_data

    if not args.no_shuffle_of_training_data:
        log.info("shuffling")
        import random
        random.shuffle(training_data)
        log.info("done")

#
#     Vi = len(src_voc) + 1 # + UNK
#     Vo = len(tgt_voc) + 1 # + UNK

    is_multitarget = args.is_multitarget

    config_training = {
        "command_line": args.__dict__,
        "Vi": Vi,
        "Vo": Vo,
        "voc": voc_fn,
        "data": data_fn,
        "is_multitarget": is_multitarget
    }
    save_train_config_fn = output_files_dict["train_config"]
    log.info("Saving training config to %s" % save_train_config_fn)
    with io.open(save_train_config_fn, 'w', encoding="utf-8") as outfile:
        outfile.write(unicode(json.dumps(config_training, ensure_ascii=False)))
    #json.dump(config_training, open(save_train_config_fn, "w"), indent=2, separators=(',', ': '))

    eos_idx = Vo

    # Selecting Attention type
    attn_cls = models.AttentionModule
    if args.use_accumulated_attn:
        raise NotImplementedError
#         encdec = models.EncoderDecoder(Vi, args.Ei, args.Hi, Vo + 1, args.Eo, args.Ho, args.Ha, args.Hl,
#                                        attn_cls= models.AttentionModuleAcumulated,
#                                        init_orth = args.init_orth)
    if args.use_deep_attn:
        attn_cls = models.DeepAttentionModule

    # Creating encoder/decoder
    encdec = models.EncoderDecoder(
        Vi,
        args.Ei,
        args.Hi,
        Vo + 1,
        args.Eo,
        args.Ho,
        args.Ha,
        args.Hl,
        init_orth=args.init_orth,
        use_bn_length=args.use_bn_length,
        attn_cls=attn_cls,
        encoder_cell_type=args.encoder_cell_type,
        decoder_cell_type=args.decoder_cell_type,
        lexical_probability_dictionary=lexical_probability_dictionary,
        lex_epsilon=args.lexicon_prob_epsilon,
        is_multitarget=is_multitarget)

    if args.load_model is not None:
        serializers.load_npz(args.load_model, encdec)

    models_list = [encdec]  # keep defined even when training on CPU
    if args.gpu is not None:
        import copy
        for i in range(len(args.gpu) - 1):
            log.info(
                "Creating copy #%d of model for data parallel computation." %
                (i + 1))
            encdec_copy = copy.deepcopy(encdec)
            models_list.append(encdec_copy)
        for i in range(len(args.gpu)):
            models_list[i] = models_list[i].to_gpu(args.gpu[i])
        assert models_list[0] == encdec

    #print len(models_list)

    if args.optimizer == "adadelta":
        optimizer = optimizers.AdaDelta()
    elif args.optimizer == "adam":
        optimizer = optimizers.Adam()
    elif args.optimizer == "adagrad":
        optimizer = optimizers.AdaGrad(lr=args.learning_rate)
    elif args.optimizer == "sgd":
        optimizer = optimizers.SGD(lr=args.learning_rate)
    elif args.optimizer == "momentum":
        optimizer = optimizers.MomentumSGD(lr=args.learning_rate,
                                           momentum=args.momentum)
    elif args.optimizer == "nesterov":
        optimizer = optimizers.NesterovAG(lr=args.learning_rate,
                                          momentum=args.momentum)
    elif args.optimizer == "rmsprop":
        optimizer = optimizers.RMSprop(lr=args.learning_rate)
    elif args.optimizer == "rmspropgraves":
        optimizer = optimizers.RMSpropGraves(lr=args.learning_rate,
                                             momentum=args.momentum)
    else:
        raise NotImplementedError
    with cuda.get_device(args.gpu[0] if args.gpu is not None else None):
        optimizer.setup(encdec)

    if args.l2_gradient_clipping is not None and args.l2_gradient_clipping > 0:
        optimizer.add_hook(
            chainer.optimizer.GradientClipping(args.l2_gradient_clipping))

    if args.hard_gradient_clipping is not None:
        optimizer.add_hook(
            chainer.optimizer.GradientHardClipping(
                *args.hard_gradient_clipping))

    if args.weight_decay is not None:
        optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))

    if args.load_optimizer_state is not None:
        with cuda.get_device(args.gpu[0] if args.gpu is not None else None):
            serializers.load_npz(args.load_optimizer_state, optimizer)

    with cuda.get_device(args.gpu[0] if args.gpu is not None else None):
        #         with MyTimerHook() as timer:
        #             try:
        train_on_data(
            encdec,
            optimizer,
            training_data,
            output_files_dict,
            src_indexer,
            tgt_indexer,
            eos_idx=eos_idx,
            mb_size=args.mb_size,
            nb_of_batch_to_sort=args.nb_batch_to_sort * (len(args.gpu) if args.gpu is not None else 1),
            test_data=test_data,
            dev_data=dev_data,
            valid_data=valid_data,
            gpu=args.gpu,
            report_every=args.report_every,
            randomized=args.randomized_data,
            reverse_src=args.reverse_src,
            reverse_tgt=args.reverse_tgt,
            max_nb_iters=args.max_nb_iters,
            do_not_save_data_for_resuming=args.no_resume,
            noise_on_prev_word=args.noise_on_prev_word,
            curiculum_training=args.curiculum_training,
            use_previous_prediction=args.use_previous_prediction,
            no_report_or_save=args.no_report_or_save,
            use_memory_optimization=args.use_memory_optimization,
            sample_every=args.sample_every,
            use_reinf=args.use_reinf,
            save_ckpt_every=args.save_ckpt_every,
            postprocess=args.postprocess,
            models_list=models_list
            #                     lexical_probability_dictionary = lexical_probability_dictionary,
            #                     V_tgt = Vo + 1,
            #                     lexicon_prob_epsilon = args.lexicon_prob_epsilon
        )


#             finally:
#                 print timer
#                 timer.print_sorted()
#                 print "total time:"
#                 print(timer.total_time())

    import sys
    sys.exit(0)  # everything below this point is unreachable (alternative trainer-based path)

    import training_chainer
    with cuda.get_device(args.gpu):
        training_chainer.train_on_data_chainer(
            encdec,
            optimizer,
            training_data,
            output_files_dict,
            src_indexer,
            tgt_indexer,
            eos_idx=eos_idx,
            output_dir=args.save_prefix,
            stop_trigger=None,
            mb_size=args.mb_size,
            nb_of_batch_to_sort=args.nb_batch_to_sort,
            test_data=test_data,
            dev_data=dev_data,
            valid_data=valid_data,
            gpu=args.gpu,
            report_every=args.report_every,
            randomized=args.randomized_data,
            reverse_src=args.reverse_src,
            reverse_tgt=args.reverse_tgt,
            max_nb_iters=args.max_nb_iters,
            do_not_save_data_for_resuming=args.no_resume,
            noise_on_prev_word=args.noise_on_prev_word,
            curiculum_training=args.curiculum_training,
            use_previous_prediction=args.use_previous_prediction,
            no_report_or_save=args.no_report_or_save,
            use_memory_optimization=args.use_memory_optimization,
            sample_every=args.sample_every,
            use_reinf=args.use_reinf,
            save_ckpt_every=args.save_ckpt_every,
            postprocess=args.postprocess
            #                     lexical_probability_dictionary = lexical_probability_dictionary,
            #                     V_tgt = Vo + 1,
            #                     lexicon_prob_epsilon = args.lexicon_prob_epsilon
        )
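Since parse_args is called with args=arguments, this entry point can be driven programmatically as well as from the shell. A minimal sketch (all paths and flag values here are illustrative):

# Equivalent to: python train.py data/prepared exp/run1 --gpu 0 --mb_size 64
command_line(["data/prepared", "exp/run1", "--gpu", "0", "--mb_size", "64"])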
Example 6
def main():
    global device
    args = parse()
    s_vocab = utils.make_vocab(args.train_src, args.vocab_size)
    t_vocab = utils.make_vocab(args.train_tgt, args.vocab_size)
    train_source_seqs, train_target_seqs = [], []
    valid_source_seqs, valid_target_seqs = [], []

    device = torch.device('cpu')  # fall back to CPU when no GPU is requested or available
    if args.gpu_id is not None and torch.cuda.is_available():
        device = torch.device('cuda:' + str(args.gpu_id[0]))

    # Convert every file into sequences of token IDs
    with open(args.train_src, encoding='utf-8') as fin:
        for line in fin:
            train_source_seqs.append([
                s_vocab[t] if t in s_vocab else s_vocab['<UNK>']
                for t in line.strip().split(' ')
            ])
    with open(args.train_tgt, encoding='utf-8') as fin:
        for line in fin:
            train_target_seqs.append([
                t_vocab[t] if t in t_vocab else t_vocab['<UNK>']
                for t in line.strip().split(' ')
            ])
    with open(args.valid_src, encoding='utf-8') as fin:
        for line in fin:
            valid_source_seqs.append([
                s_vocab[t] if t in s_vocab else s_vocab['<UNK>']
                for t in line.strip().split(' ')
            ])
    with open(args.valid_tgt, encoding='utf-8') as fin:
        for line in fin:
            valid_target_seqs.append([
                t_vocab[t] if t in t_vocab else t_vocab['<UNK>']
                for t in line.strip().split(' ')
            ])

    if args.model_type == 'EncDec':
        model = models.EncoderDecoder(s_vocab_size=args.vocab_size,
                                      t_vocab_size=args.vocab_size,
                                      embed_size=args.embed_size,
                                      hidden_size=args.hidden_size,
                                      weight_decay=1e-5).to(device)
    elif args.model_type == 'Attn':
        model = models.AttentionSeq2Seq(s_vocab_size=args.vocab_size,
                                        t_vocab_size=args.vocab_size,
                                        embed_size=args.embed_size,
                                        hidden_size=args.hidden_size,
                                        num_s_layers=2,
                                        bidirectional=True,
                                        weight_decay=1e-5).to(device)
    else:
        sys.stderr.write('Unknown model type %s. Model type must be `EncDec` or `Attn`.\n' %
                         args.model_type)
        sys.exit(1)  # `model` would be undefined past this point

    if args.gpu_id is not None and len(args.gpu_id) > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu_id)

    train_losses, valid_losses = train(train_source_seqs, train_target_seqs,
                                       valid_source_seqs, valid_target_seqs,
                                       model, s_vocab, t_vocab, args.epochs,
                                       args.batch_size, device, args.reverse)

    # Save everything needed later to translate the test data
    pickle.dump(s_vocab, open('s_vocab.pkl', 'wb'))
    pickle.dump(t_vocab, open('t_vocab.pkl', 'wb'))
    torch.save(model.state_dict(), args.model_prefix + '.model')

    plt.plot(np.arange(1, len(train_losses) + 1), train_losses, label='train loss')
    plt.plot(np.arange(1, len(valid_losses) + 1), valid_losses, label='valid loss')
    plt.xlabel('Epochs')
    plt.ylabel('loss')
    plt.legend()
    plt.tight_layout()
    plt.savefig('loss_curve.pdf')
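utils.make_vocab is not shown; below is a minimal sketch of the behaviour the snippet assumes, namely a token-to-id dict capped at vocab_size and containing '<UNK>'. The exact special tokens and their ids are guesses; only '<UNK>' is confirmed by the snippet.

from collections import Counter

def make_vocab(path, vocab_size):
    # Count whitespace-separated tokens over the whole file.
    counter = Counter()
    with open(path, encoding='utf-8') as fin:
        for line in fin:
            counter.update(line.strip().split(' '))
    # Reserve special tokens, then fill with the most frequent words.
    vocab = {'<PAD>': 0, '<UNK>': 1, '<BOS>': 2, '<EOS>': 3}
    for token, _ in counter.most_common(vocab_size - len(vocab)):
        vocab[token] = len(vocab)
    return vocab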