def main():
    global device
    args = parse()
    if args.gpu_id >= 0 and torch.cuda.is_available():
        device = torch.device('cuda:' + str(args.gpu_id))
    s_vocab = pickle.load(open(args.s_vocab, 'rb'))
    t_vocab = pickle.load(open(args.t_vocab, 'rb'))
    vs, es, hs = args.vocab_size, args.embed_size, args.hidden_size
    if args.model_type == 'EncDec':
        model = models.EncoderDecoder(s_vocab_size=vs, t_vocab_size=vs,
                                      embed_size=es, hidden_size=hs,
                                      weight_decay=1e-5).to(device)
    elif args.model_type == 'Attn':
        model = models.AttentionSeq2Seq(s_vocab_size=vs, t_vocab_size=vs,
                                        embed_size=es, hidden_size=hs,
                                        num_s_layers=2, bidirectional=True,
                                        weight_decay=1e-5).to(device)
    else:
        sys.stderr.write('%s is not found. Model type is `EncDec` or `Attn`.\n'
                         % args.model_type)
        sys.exit(1)  # abort: `model` would be unbound below otherwise
    model.load_state_dict(torch.load(args.model_prefix + '.model'))
    translate(args.src, model, s_vocab, t_vocab, args.output, device, 100,
              reverse=args.reverse)
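# The entry point above calls a parse() helper that is not shown in this file.
# Below is a minimal sketch of what it presumably looks like, inferred only
# from the attributes main() reads; all flag names, types, and defaults are
# assumptions, not the project's actual definitions.
def parse():
    import argparse
    parser = argparse.ArgumentParser(description='Translate with a trained model')
    parser.add_argument('--gpu_id', type=int, default=-1,
                        help='GPU id; negative values mean CPU')
    parser.add_argument('--s_vocab', default='s_vocab.pkl',
                        help='pickled source vocabulary')
    parser.add_argument('--t_vocab', default='t_vocab.pkl',
                        help='pickled target vocabulary')
    parser.add_argument('--vocab_size', type=int, default=30000)
    parser.add_argument('--embed_size', type=int, default=256)
    parser.add_argument('--hidden_size', type=int, default=256)
    parser.add_argument('--model_type', default='EncDec',
                        help='`EncDec` or `Attn`')
    parser.add_argument('--model_prefix', default='seq2seq',
                        help='prefix of the saved .model file')
    parser.add_argument('--src', help='source file to translate')
    parser.add_argument('--output', help='file to write translations to')
    parser.add_argument('--reverse', action='store_true')
    return parser.parse_args()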
def command_line2():
    import argparse
    parser = argparse.ArgumentParser(
        description="Use a RNNSearch model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("lattice_fn")
    parser.add_argument("source_sentence_fn")
    parser.add_argument("training_config", help="config file of the trained model")
    parser.add_argument("trained_model", help="prefix of the trained model")
    parser.add_argument("--gpu", type=int, help="specify gpu number to use, if any")
    parser.add_argument("--skip_in_src", type=int, default=0)
    args = parser.parse_args()

    config_training_fn = args.training_config  # args.model_prefix + ".train.config"

    log.info("loading model config from %s" % config_training_fn)
    config_training = json.load(open(config_training_fn))

    voc_fn = config_training["voc"]
    log.info("loading voc from %s" % voc_fn)
    src_voc, tgt_voc = json.load(open(voc_fn))

    src_indexer = Indexer.make_from_serializable(src_voc)
    tgt_indexer = Indexer.make_from_serializable(tgt_voc)
    tgt_voc = None
    src_voc = None

    Vi = len(src_indexer)  # + UNK
    Vo = len(tgt_indexer)  # + UNK

    print config_training

    Ei = config_training["command_line"]["Ei"]
    Hi = config_training["command_line"]["Hi"]
    Eo = config_training["command_line"]["Eo"]
    Ho = config_training["command_line"]["Ho"]
    Ha = config_training["command_line"]["Ha"]
    Hl = config_training["command_line"]["Hl"]

    eos_idx = Vo
    encdec = models.EncoderDecoder(Vi, Ei, Hi, Vo + 1, Eo, Ho, Ha, Hl)

    log.info("loading model from %s" % args.trained_model)
    serializers.load_npz(args.trained_model, encdec)

    if args.gpu is not None:
        encdec = encdec.to_gpu(args.gpu)

    src_sent_f = codecs.open(args.source_sentence_fn, encoding="utf8")
    for _ in xrange(args.skip_in_src):
        src_sent_f.readline()
    src_sentence = src_sent_f.readline().strip().split(" ")
    log.info("translating sentence %s" % (" ".join(src_sentence)))
    src_seq = src_indexer.convert(src_sentence)
    log.info("src seq: %r" % src_seq)

    log.info("loading lattice %s" % args.lattice_fn)
    lattice_f = codecs.open(args.lattice_fn, "r", encoding="utf8")
    all_edges = parse_lattice_file(lattice_f)
    log.info("loaded")

    lattice_map = [None] * len(all_edges)
    for num_lattice, edge_list in enumerate(all_edges):
        lattice_map[num_lattice] = Lattice(edge_list)
        top_lattice_id = num_lattice
    log.info("built lattices")

    log.info("removing epsilons")
    log.info("nb edges before %i" % sum(
        len(edge_list) for lattice in lattice_map
        for edge_list in lattice.outgoing.itervalues()))
    remove_all_epsilons(lattice_map)
    log.info("nb edges after %i" % sum(
        len(edge_list) for lattice in lattice_map
        for edge_list in lattice.outgoing.itervalues()))

    if args.gpu is not None:
        seq_as_batch = [
            Variable(cuda.to_gpu(np.array([x], dtype=np.int32), args.gpu),
                     volatile="on") for x in src_seq]
    else:
        seq_as_batch = [
            Variable(np.array([x], dtype=np.int32), volatile="on")
            for x in src_seq]

    predictor = encdec.get_predictor(seq_as_batch, [])

    global_memoizer = {}
    global_count_memoizer = {}
    initial_node = Node(top_lattice_id)
    initial_node.add_elem(PosElem(Lattice.kInitial))
    current_path = initial_node

    selected_seq = []
    while True:
        print "#node current_path", current_path.count_distincts_subnodes()
        current_path.assert_is_reduced_and_consistent()
        next_words_set = current_path.get_next_w(lattice_map, global_memoizer,
                                                 global_count_memoizer)
        for w in next_words_set:
            next_words_set[w] = sum(next_words_set[w].itervalues())
        has_eos = Lattice.EOS in next_words_set
        next_words_list = sorted(
            list(w for w in next_words_set if w != Lattice.EOS))
        print "next_words_set", next_words_set

        voc_choice = tgt_indexer.convert(next_words_list)
        if has_eos:
            voc_choice.append(eos_idx)
        chosen = predictor(voc_choice)
        if chosen != eos_idx and tgt_indexer.is_unk_idx(chosen):
            print "warning: unk chosen"
            unk_list = []
            for ix, t_idx in enumerate(voc_choice):
                if tgt_indexer.is_unk_idx(t_idx):
                    unk_list.append((next_words_set[next_words_list[ix]],
                                     next_words_list[ix]))
            unk_list.sort(reverse=True)
            print "UNK:", unk_list
            selected_w = unk_list[0][1]
        else:
            # TODO: better handling when several tgt candidates map to UNK
            idx_chosen = voc_choice.index(chosen)
            selected_w = (next_words_list + [Lattice.EOS])[idx_chosen]

        # interactive selection (disabled):
        # for num_word, word in enumerate(next_words_list):
        #     print num_word, word
        # print "selected_seq", selected_seq
        # i = int(raw_input("choice\n"))
        # selected_w = next_words_list[i]
        selected_seq.append(selected_w)
        print "selected_seq", selected_seq

        current_path.update_better(selected_w, lattice_map, global_memoizer)
        current_path.reduce()
        if current_path.is_empty_node():
            print "DONE"
            break
    print "final seq:", selected_seq
def create_encdec_from_config(config_training):
    voc_fn = config_training["voc"]
    log.info("loading voc from %s" % voc_fn)
    src_voc, tgt_voc = json.load(open(voc_fn))

    src_indexer = Indexer.make_from_serializable(src_voc)
    tgt_indexer = Indexer.make_from_serializable(tgt_voc)
    tgt_voc = None
    src_voc = None

    Vi = len(src_indexer)  # + UNK
    Vo = len(tgt_indexer)  # + UNK

    print config_training

    Ei = config_training["command_line"]["Ei"]
    Hi = config_training["command_line"]["Hi"]
    Eo = config_training["command_line"]["Eo"]
    Ho = config_training["command_line"]["Ho"]
    Ha = config_training["command_line"]["Ha"]
    Hl = config_training["command_line"]["Hl"]

    is_multitarget = config_training["is_multitarget"]
    if is_multitarget:
        print "Last state of backward encoder RNN is first state of decoder RNN."

    encoder_cell_type = config_training["command_line"].get("encoder_cell_type", "gru")
    decoder_cell_type = config_training["command_line"].get("decoder_cell_type", "gru")
    use_bn_length = config_training["command_line"].get("use_bn_length", None)

    import gzip
    if ("lexical_probability_dictionary" in config_training["command_line"] and
            config_training["command_line"]["lexical_probability_dictionary"] is not None):
        log.info("opening lexical_probability_dictionary %s" %
                 config_training["command_line"]["lexical_probability_dictionary"])
        lexical_probability_dictionary_all = json.load(gzip.open(
            config_training["command_line"]["lexical_probability_dictionary"], "rb"))
        log.info("computing lexical_probability_dictionary_indexed")
        lexical_probability_dictionary_indexed = {}
        for ws in lexical_probability_dictionary_all:
            ws_idx = src_indexer.convert([ws])[0]
            if ws_idx in lexical_probability_dictionary_indexed:
                assert src_indexer.is_unk_idx(ws_idx)
            else:
                lexical_probability_dictionary_indexed[ws_idx] = {}
            for wt in lexical_probability_dictionary_all[ws]:
                wt_idx = tgt_indexer.convert([wt])[0]
                if wt_idx in lexical_probability_dictionary_indexed[ws_idx]:
                    assert (src_indexer.is_unk_idx(ws_idx) or
                            tgt_indexer.is_unk_idx(wt_idx))
                    lexical_probability_dictionary_indexed[ws_idx][wt_idx] += \
                        lexical_probability_dictionary_all[ws][wt]
                else:
                    lexical_probability_dictionary_indexed[ws_idx][wt_idx] = \
                        lexical_probability_dictionary_all[ws][wt]
        lexical_probability_dictionary = lexical_probability_dictionary_indexed
    else:
        lexical_probability_dictionary = None

    eos_idx = Vo
    encdec = models.EncoderDecoder(
        Vi, Ei, Hi, Vo + 1, Eo, Ho, Ha, Hl,
        use_bn_length=use_bn_length,
        encoder_cell_type=rnn_cells.create_cell_model_from_string(encoder_cell_type),
        decoder_cell_type=rnn_cells.create_cell_model_from_string(decoder_cell_type),
        lexical_probability_dictionary=lexical_probability_dictionary,
        lex_epsilon=config_training["command_line"].get("lexicon_prob_epsilon", 0.001),
        is_multitarget=is_multitarget)
    return encdec, eos_idx, src_indexer, tgt_indexer
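# create_encdec_from_config factors out the model-reconstruction logic that
# command_line2 performs inline. A minimal sketch of how it could be used to
# restore a trained model, following the same loading pattern as
# command_line2 (the file names are placeholders):
#
#   config_training = json.load(open("model.train.config"))
#   encdec, eos_idx, src_indexer, tgt_indexer = \
#       create_encdec_from_config(config_training)
#   serializers.load_npz("model.model.best.npz", encdec)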
    return all_result, all_loss.mean().item()


if args.encode_savepath != '':
    start_time = time.time()
    cp = torch.load(args.load_state)
    c_args = cp['args']
    c_args.load_state = args.load_state
    c_args.encode_savepath = args.encode_savepath
    c_args.test_data = args.test_data
    args = c_args
    print('checkpoint arguments loaded')
    print(args)
    print(time.time() - start_time)

    testdata = readdata.readfile(args.test_data, args.batch, args.max_length,
                                 'cut', False)
    model = cuda(models.EncoderDecoder(2, args.hidden_size, args.layers,
                                       args.dropout, False), args.cuda)
    model.load_state_dict(cp['state_dict'])
    loss = MSE
    print(time.time() - start_time)

    all_test_result, all_test_loss = eval_data(testdata, 100)
    all_test_result = np.array(all_test_result)
    pickle.dump(all_test_result, open(args.encode_savepath, 'wb'))
    print(time.time() - start_time)
    exit()

print(args)
if not os.path.exists(args.checkpoint):
    os.mkdir(args.checkpoint)
if not os.path.isdir(args.checkpoint):
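# The script above assigns `loss = MSE`, but MSE itself is defined elsewhere.
# A minimal sketch of what it presumably is, assuming a PyTorch mean-squared
# error over model outputs and targets; the name and signature are
# assumptions, not the project's actual definition.
def MSE(prediction, target):
    # element-wise squared error, averaged over all elements
    return torch.nn.functional.mse_loss(prediction, target)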
def command_line(arguments=None):
    import argparse
    parser = argparse.ArgumentParser(
        description="Train a RNNSearch model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("data_prefix",
                        help="prefix of the training data created by make_data.py")
    parser.add_argument("save_prefix",
                        help="prefix to be added to all files created during the training")
    parser.add_argument("--gpu", type=int, nargs="+", default=None,
                        help="specify gpu number(s) to use, if any")
    parser.add_argument("--load_model",
                        help="load the parameters of a previously trained model")
    parser.add_argument("--load_optimizer_state",
                        help="load previously saved optimizer states")
    parser.add_argument("--Ei", type=int, default=620, help="Source words embedding size.")
    parser.add_argument("--Eo", type=int, default=620, help="Target words embedding size.")
    parser.add_argument("--Hi", type=int, default=1000, help="Source encoding layer size.")
    parser.add_argument("--Ho", type=int, default=1000, help="Target hidden layer size.")
    parser.add_argument("--Ha", type=int, default=1000, help="Attention module hidden layer size.")
    parser.add_argument("--Hl", type=int, default=500, help="Maxout output size.")
    parser.add_argument("--mb_size", type=int, default=80, help="Minibatch size")
    parser.add_argument("--nb_batch_to_sort", type=int, default=20,
                        help="Sort this many batches by size.")
    parser.add_argument("--noise_on_prev_word", default=False, action="store_true")
    parser.add_argument("--use_memory_optimization", default=False, action="store_true",
                        help="Experimental option that could strongly reduce memory used.")
    parser.add_argument("--max_nb_iters", type=int, default=None,
                        help="maximum number of iterations")
    parser.add_argument("--max_src_tgt_length", type=int,
                        help="Limit length of training sentences")
    parser.add_argument("--l2_gradient_clipping", type=float, default=1,
                        help="L2 gradient clipping. 0 for None")
    parser.add_argument("--hard_gradient_clipping", type=float, nargs=2,
                        help="hard gradient clipping.")
    parser.add_argument("--weight_decay", type=float, help="Weight decay value.")
    parser.add_argument("--optimizer",
                        choices=["sgd", "rmsprop", "rmspropgraves", "momentum",
                                 "nesterov", "adam", "adagrad", "adadelta"],
                        default="adam", help="Optimizer type.")
    parser.add_argument("--learning_rate", type=float, default=0.01, help="Learning Rate")
    parser.add_argument("--momentum", type=float, default=0.9, help="Momentum term")
    parser.add_argument("--report_every", type=int, default=200,
                        help="report every x iterations")
    parser.add_argument("--randomized_data", default=False, action="store_true")
    parser.add_argument("--use_accumulated_attn", default=False, action="store_true")
    parser.add_argument("--use_deep_attn", default=False, action="store_true")
    parser.add_argument("--no_shuffle_of_training_data", default=False, action="store_true")
    parser.add_argument("--no_resume", default=False, action="store_true")
    parser.add_argument("--init_orth", default=False, action="store_true")
    parser.add_argument("--reverse_src", default=False, action="store_true")
    parser.add_argument("--reverse_tgt", default=False, action="store_true")
    parser.add_argument("--curiculum_training", default=False, action="store_true")
    parser.add_argument("--use_bn_length", default=0, type=int)
    parser.add_argument("--use_previous_prediction", default=0, type=float)
    parser.add_argument("--no_report_or_save", default=False, action="store_true")
    parser.add_argument("--lexical_probability_dictionary",
                        help="lexical translation probabilities in zipped JSON format. "
                             "Used to implement https://arxiv.org/abs/1606.02006")
    parser.add_argument("--lexicon_prob_epsilon", default=1e-3, type=float,
                        help="epsilon value for combining the lexical probabilities")
    parser.add_argument("--encoder_cell_type", default="lstm",
                        help="cell type of encoder. format: type,param1:val1,param2:val2,... "
                             "where type is in [%s]" % (" ".join(rnn_cells.cell_dict.keys())))
    parser.add_argument("--decoder_cell_type", default="lstm",
                        help="cell type of decoder. format same as for encoder")
    parser.add_argument("--sample_every", default=200, type=int)
    parser.add_argument("--save_ckpt_every", default=4000, type=int)
    parser.add_argument("--use_reinf", default=False, action="store_true")
    parser.add_argument("--is_multitarget", default=False, action="store_true")
    parser.add_argument("--postprocess", default=False, action="store_true",
                        help="This flag indicates whether the translations should be "
                             "postprocessed. For now it simply indicates that the BPE "
                             "segmentation should be undone.")
    args = parser.parse_args(args=arguments)

    output_files_dict = {}
    output_files_dict["train_config"] = args.save_prefix + ".train.config"
    output_files_dict["model_ckpt"] = args.save_prefix + ".model.ckpt.npz"
    output_files_dict["model_final"] = args.save_prefix + ".model.final.npz"
    output_files_dict["model_best"] = args.save_prefix + ".model.best.npz"
    output_files_dict["model_best_loss"] = args.save_prefix + ".model.best_loss.npz"
    output_files_dict["test_translation_output"] = args.save_prefix + ".test.out"
    output_files_dict["test_src_output"] = args.save_prefix + ".test.src.out"
    output_files_dict["dev_translation_output"] = args.save_prefix + ".dev.out"
    output_files_dict["dev_src_output"] = args.save_prefix + ".dev.src.out"
    output_files_dict["valid_translation_output"] = args.save_prefix + ".valid.out"
    output_files_dict["valid_src_output"] = args.save_prefix + ".valid.src.out"
    output_files_dict["sqlite_db"] = args.save_prefix + ".result.sqlite"
    output_files_dict["optimizer_ckpt"] = args.save_prefix + ".optimizer.ckpt.npz"
    output_files_dict["optimizer_final"] = args.save_prefix + ".optimizer.final.npz"

    save_prefix_dir, save_prefix_fn = os.path.split(args.save_prefix)
    ensure_path(save_prefix_dir)

    already_existing_files = []
    for key_info, filename in output_files_dict.iteritems():
        if os.path.exists(filename):
            already_existing_files.append(filename)
    if len(already_existing_files) > 0:
        print "Warning: existing files are going to be replaced / updated: ", already_existing_files

    config_fn = args.data_prefix + ".data.config"
    voc_fn = args.data_prefix + ".voc"
    data_fn = args.data_prefix + ".data.json.gz"

    log.info("loading training data from %s" % data_fn)
    training_data_all = json.load(gzip.open(data_fn, "rb"))
    training_data = training_data_all["train"]
    log.info("loaded %i sentences as training data" % len(training_data))

    if "test" in training_data_all:
        test_data = training_data_all["test"]
        log.info("Found test data: %i sentences" % len(test_data))
    else:
        test_data = None
        log.info("No test data found")

    if "dev" in training_data_all:
        dev_data = training_data_all["dev"]
        log.info("Found dev data: %i sentences" % len(dev_data))
    else:
        dev_data = None
        log.info("No dev data found")

    if "valid" in training_data_all:
        valid_data = training_data_all["valid"]
        log.info("Found valid data: %i sentences" % len(valid_data))
    else:
        valid_data = None
        log.info("No valid data found")

    log.info("loading voc from %s" % voc_fn)
    src_voc, tgt_voc = json.load(open(voc_fn))
    src_indexer = Indexer.make_from_serializable(src_voc)
    tgt_indexer = Indexer.make_from_serializable(tgt_voc)
    tgt_voc = None
    src_voc = None

    Vi = len(src_indexer)  # + UNK
    Vo = len(tgt_indexer)  # + UNK

    if args.lexical_probability_dictionary is not None:
        log.info("opening lexical_probability_dictionary %s" %
                 args.lexical_probability_dictionary)
        lexical_probability_dictionary_all = json.load(
            gzip.open(args.lexical_probability_dictionary, "rb"))
        log.info("computing lexical_probability_dictionary_indexed")
        lexical_probability_dictionary_indexed = {}
        for ws in lexical_probability_dictionary_all:
            ws_idx = src_indexer.convert([ws])[0]
            if ws_idx in lexical_probability_dictionary_indexed:
                assert src_indexer.is_unk_idx(ws_idx)
            else:
                lexical_probability_dictionary_indexed[ws_idx] = {}
            for wt in lexical_probability_dictionary_all[ws]:
                wt_idx = tgt_indexer.convert([wt])[0]
                if wt_idx in lexical_probability_dictionary_indexed[ws_idx]:
                    assert (src_indexer.is_unk_idx(ws_idx) or
                            tgt_indexer.is_unk_idx(wt_idx))
                    lexical_probability_dictionary_indexed[ws_idx][wt_idx] += \
                        lexical_probability_dictionary_all[ws][wt]
                else:
                    lexical_probability_dictionary_indexed[ws_idx][wt_idx] = \
                        lexical_probability_dictionary_all[ws][wt]
        lexical_probability_dictionary = lexical_probability_dictionary_indexed
    else:
        lexical_probability_dictionary = None

    if args.max_src_tgt_length is not None:
        log.info("filtering sentences of length larger than %i" % args.max_src_tgt_length)
        filtered_training_data = []
        nb_filtered = 0
        for src, tgt in training_data:
            if len(src) <= args.max_src_tgt_length and len(tgt) <= args.max_src_tgt_length:
                filtered_training_data.append((src, tgt))
            else:
                nb_filtered += 1
        log.info("filtered %i sentences of length larger than %i" %
                 (nb_filtered, args.max_src_tgt_length))
        training_data = filtered_training_data

    if not args.no_shuffle_of_training_data:
        log.info("shuffling")
        import random
        random.shuffle(training_data)
        log.info("done")

    is_multitarget = args.is_multitarget
    config_training = {
        "command_line": args.__dict__,
        "Vi": Vi,
        "Vo": Vo,
        "voc": voc_fn,
        "data": data_fn,
        "is_multitarget": is_multitarget
    }
    save_train_config_fn = output_files_dict["train_config"]
    log.info("Saving training config to %s" % save_train_config_fn)
    with io.open(save_train_config_fn, 'w', encoding="utf-8") as outfile:
        outfile.write(unicode(json.dumps(config_training, ensure_ascii=False)))

    eos_idx = Vo

    # Selecting attention type
    attn_cls = models.AttentionModule
    if args.use_accumulated_attn:
        raise NotImplementedError
    if args.use_deep_attn:
        attn_cls = models.DeepAttentionModule

    # Creating encoder/decoder
    encdec = models.EncoderDecoder(
        Vi, args.Ei, args.Hi, Vo + 1, args.Eo, args.Ho, args.Ha, args.Hl,
        init_orth=args.init_orth,
        use_bn_length=args.use_bn_length,
        attn_cls=attn_cls,
        encoder_cell_type=args.encoder_cell_type,
        decoder_cell_type=args.decoder_cell_type,
        lexical_probability_dictionary=lexical_probability_dictionary,
        lex_epsilon=args.lexicon_prob_epsilon,
        is_multitarget=is_multitarget)

    if args.load_model is not None:
        serializers.load_npz(args.load_model, encdec)

    models_list = [encdec]  # defined unconditionally: train_on_data needs it even in CPU mode
    if args.gpu is not None:
        import copy
        for i in range(len(args.gpu) - 1):
            log.info("Creating copy #%d of model for data parallel computation." % (i + 1))
            models_list.append(copy.deepcopy(encdec))
        for i in range(len(args.gpu)):
            models_list[i] = models_list[i].to_gpu(args.gpu[i])
        assert models_list[0] == encdec

    # args.gpu is a list (nargs="+"); the optimizer lives on the first device
    first_gpu = args.gpu[0] if args.gpu is not None else None

    if args.optimizer == "adadelta":
        optimizer = optimizers.AdaDelta()
    elif args.optimizer == "adam":
        optimizer = optimizers.Adam()
    elif args.optimizer == "adagrad":
        optimizer = optimizers.AdaGrad(lr=args.learning_rate)
    elif args.optimizer == "sgd":
        optimizer = optimizers.SGD(lr=args.learning_rate)
    elif args.optimizer == "momentum":
        optimizer = optimizers.MomentumSGD(lr=args.learning_rate,
                                           momentum=args.momentum)
    elif args.optimizer == "nesterov":
        optimizer = optimizers.NesterovAG(lr=args.learning_rate,
                                          momentum=args.momentum)
    elif args.optimizer == "rmsprop":
        optimizer = optimizers.RMSprop(lr=args.learning_rate)
    elif args.optimizer == "rmspropgraves":
        optimizer = optimizers.RMSpropGraves(lr=args.learning_rate,
                                             momentum=args.momentum)
    else:
        raise NotImplementedError

    with cuda.get_device(first_gpu):
        optimizer.setup(encdec)

    if args.l2_gradient_clipping is not None and args.l2_gradient_clipping > 0:
        optimizer.add_hook(
            chainer.optimizer.GradientClipping(args.l2_gradient_clipping))
    if args.hard_gradient_clipping is not None:  # a pair of floats, not a scalar
        optimizer.add_hook(
            chainer.optimizer.GradientHardClipping(*args.hard_gradient_clipping))
    if args.weight_decay is not None:
        optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))

    if args.load_optimizer_state is not None:
        with cuda.get_device(first_gpu):
            serializers.load_npz(args.load_optimizer_state, optimizer)

    nb_gpus = len(args.gpu) if args.gpu is not None else 1
    with cuda.get_device(first_gpu):
        train_on_data(encdec, optimizer, training_data, output_files_dict,
                      src_indexer, tgt_indexer,
                      eos_idx=eos_idx,
                      mb_size=args.mb_size,
                      nb_of_batch_to_sort=args.nb_batch_to_sort * nb_gpus,
                      test_data=test_data,
                      dev_data=dev_data,
                      valid_data=valid_data,
                      gpu=args.gpu,
                      report_every=args.report_every,
                      randomized=args.randomized_data,
                      reverse_src=args.reverse_src,
                      reverse_tgt=args.reverse_tgt,
                      max_nb_iters=args.max_nb_iters,
                      do_not_save_data_for_resuming=args.no_resume,
                      noise_on_prev_word=args.noise_on_prev_word,
                      curiculum_training=args.curiculum_training,
                      use_previous_prediction=args.use_previous_prediction,
                      no_report_or_save=args.no_report_or_save,
                      use_memory_optimization=args.use_memory_optimization,
                      sample_every=args.sample_every,
                      use_reinf=args.use_reinf,
                      save_ckpt_every=args.save_ckpt_every,
                      postprocess=args.postprocess,
                      models_list=models_list)

    import sys
    sys.exit(0)

    # Unreachable legacy path kept for reference: the chainer-trainer based
    # training loop that train_on_data replaced.
    import training_chainer
    with cuda.get_device(first_gpu):
        training_chainer.train_on_data_chainer(
            encdec, optimizer, training_data, output_files_dict,
            src_indexer, tgt_indexer,
            eos_idx=eos_idx,
            output_dir=args.save_prefix,
            stop_trigger=None,
            mb_size=args.mb_size,
            nb_of_batch_to_sort=args.nb_batch_to_sort,
            test_data=test_data,
            dev_data=dev_data,
            valid_data=valid_data,
            gpu=args.gpu,
            report_every=args.report_every,
            randomized=args.randomized_data,
            reverse_src=args.reverse_src,
            reverse_tgt=args.reverse_tgt,
            max_nb_iters=args.max_nb_iters,
            do_not_save_data_for_resuming=args.no_resume,
            noise_on_prev_word=args.noise_on_prev_word,
            curiculum_training=args.curiculum_training,
            use_previous_prediction=args.use_previous_prediction,
            no_report_or_save=args.no_report_or_save,
            use_memory_optimization=args.use_memory_optimization,
            sample_every=args.sample_every,
            use_reinf=args.use_reinf,
            save_ckpt_every=args.save_ckpt_every,
            postprocess=args.postprocess)
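# Example invocation of the training entry point (paths are placeholders),
# using a data prefix produced by make_data.py:
#
#   command_line(["processed/train", "models/rnnsearch",
#                 "--gpu", "0", "--mb_size", "64", "--optimizer", "adam"])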
def main():
    global device
    args = parse()
    s_vocab = utils.make_vocab(args.train_src, args.vocab_size)
    t_vocab = utils.make_vocab(args.train_tgt, args.vocab_size)

    if args.gpu_id is not None and torch.cuda.is_available():
        device = torch.device('cuda:%d' % args.gpu_id[0])

    def to_id_seqs(path, vocab):
        # Convert a tokenized file into sequences of vocabulary ids,
        # mapping out-of-vocabulary tokens to <UNK>
        seqs = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
                seqs.append([vocab[t] if t in vocab else vocab['<UNK>']
                             for t in line.strip().split(' ')])
        return seqs

    train_source_seqs = to_id_seqs(args.train_src, s_vocab)
    train_target_seqs = to_id_seqs(args.train_tgt, t_vocab)
    valid_source_seqs = to_id_seqs(args.valid_src, s_vocab)
    valid_target_seqs = to_id_seqs(args.valid_tgt, t_vocab)

    if args.model_type == 'EncDec':
        model = models.EncoderDecoder(s_vocab_size=args.vocab_size,
                                      t_vocab_size=args.vocab_size,
                                      embed_size=args.embed_size,
                                      hidden_size=args.hidden_size,
                                      weight_decay=1e-5).to(device)
    elif args.model_type == 'Attn':
        model = models.AttentionSeq2Seq(s_vocab_size=args.vocab_size,
                                        t_vocab_size=args.vocab_size,
                                        embed_size=args.embed_size,
                                        hidden_size=args.hidden_size,
                                        num_s_layers=2,
                                        bidirectional=True,
                                        weight_decay=1e-5).to(device)
    else:
        sys.stderr.write('%s is not found. Model type is `EncDec` or `Attn`.\n'
                         % args.model_type)
        sys.exit(1)  # abort: `model` would be unbound below otherwise

    if args.gpu_id is not None and len(args.gpu_id) > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu_id)

    train_losses, valid_losses = train(train_source_seqs, train_target_seqs,
                                       valid_source_seqs, valid_target_seqs,
                                       model, s_vocab, t_vocab, args.epochs,
                                       args.batch_size, device, args.reverse)

    # Save everything needed later to translate the test data
    pickle.dump(s_vocab, open('s_vocab.pkl', 'wb'))
    pickle.dump(t_vocab, open('t_vocab.pkl', 'wb'))
    torch.save(model.state_dict(), args.model_prefix + '.model')

    plt.plot(np.arange(1, len(train_losses) + 1), train_losses, label='train loss')
    plt.plot(np.arange(1, len(valid_losses) + 1), valid_losses, label='valid loss')
    plt.xlabel('Epochs')
    plt.ylabel('loss')
    plt.legend()
    plt.tight_layout()
    plt.savefig('loss_curve.pdf')
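# The training script defines main() but no entry-point guard is shown; a
# conventional one (an assumption about how the script is meant to be run):
if __name__ == '__main__':
    main()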