def main():
    """Train a Seq2Seq model.

    Parses CLI args (``--params``, ``--training_params``), sets up the
    experiment folder and file logging, loads the dataset, optionally
    initializes a translation-memory search engine, and runs the Trainer.
    Writes TensorBoard scalars to ``./all_scalars.json`` on completion.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--params')
    parser.add_argument('--training_params')
    args = parser.parse_args()

    # todo logdir
    # todo load from logdir if possible, load params
    training_params = Trainer.get_default_hparams()
    if args.training_params:
        with open(args.training_params, encoding='utf-8') as fin:
            training_params = training_params.parse_dict(json.load(fin))

    # Resolve the experiment folder: reuse the latest run unless
    # force_override requests a fresh one.
    logger_config = LOGGING_BASE
    model_name = os.path.join(training_params.prefix, training_params.model_name)
    model_folder = os.path.join(TRAINED_MODELS_FOLDER, model_name)
    latest_folder = (find_latest_experiment(model_folder)
                     if os.path.exists(model_folder) else None)
    latest_folder = None if training_params.force_override else latest_folder
    new_folder = create_new_experiment(model_folder, latest_folder)

    logger_config['handlers']['debug']['filename'] = os.path.join(new_folder, 'debug_logs')
    logger_config['handlers']['stdout']['filename'] = os.path.join(new_folder, 'stdout_logs')
    logging.config.dictConfig(logger_config)

    # BUGFIX: both CLI paths are optional; copy2(None, ...) raises TypeError.
    if args.training_params:
        copy2(args.training_params, new_folder)
    if args.params:
        copy2(args.params, new_folder)

    logger.info("Using python binary at {}".format(sys.executable))
    os.environ['CUDA_VISIBLE_DEVICES'] = str(training_params.cuda_visible_devices)
    if torch.cuda.is_available() and training_params.use_cuda:
        logger.info('GPU found, running on device {}'.format(
            torch.cuda.current_device()))
    elif training_params.use_cuda:
        logger.warning('GPU not found, running on CPU. Overriding use_cuda to False.')
        training_params.set('use_cuda', False)
    elif torch.cuda.is_available():
        # BUGFIX: the old bare `else` logged "GPU found" even when no GPU
        # was present; only log this hint when a GPU actually exists.
        logger.debug('GPU found, but use_cuda=False, consider using GPU.')

    log_experiment_info(model_name, new_folder, latest_folder)
    hps = get_hps(new_folder, args)

    # Read the full problem once to size the training fraction, then re-read
    # the (possibly reduced) dataset actually used for training.
    dataset_path = os.path.join(DATASET_DIR, hps.dataset)
    full_dataset, src, tgt = read_problem(dataset_path, n_sents=None)
    dataset, src, tgt = read_problem(
        dataset_path, n_sents=len(full_dataset["train"][0]) // hps.fraction)

    training_params.set('logdir', new_folder)
    log_parameters_info(hps, training_params)
    log_dataset_info(hps, full_dataset, dataset)
    batch_sampler = BatchSampler(dataset, src_lang=src, tgt_lang=tgt,
                                 batch_size=training_params.batch_size)

    # Optional translation-memory search engine; the dictionary always comes
    # from the FULL dataset, even when training on a fraction.
    searchengine = None
    if hps.tm_init:
        logger.info("Using translation memory.")
        if hps.tm_overfitted:
            logger.info("Using overfitted search engine.")
            searchengine = OverfittedSearchEngine()
        else:
            logger.info("Using normal search engine.")
            searchengine = SearchEngine()
        searchengine.load(hps.tm_bin_path)
        searchengine.set_dictionary(full_dataset)
        if hps.tm_50_50:
            searchengine.remove_train_set(dataset["train"][0])

    writer = SummaryWriter(log_dir=training_params.logdir)
    model = s2s.Seq2Seq(src, tgt, hps, training_params,
                        writer=writer, searchengine=searchengine)

    # Persist the model's repr for later inspection of the architecture.
    with open(os.path.join(new_folder, "model.meta"), "w", encoding='utf-8') as fout:
        fout.write(repr(model))
    translate_to_all_loggers(repr(model))

    trainer = Trainer(model, batch_sampler, hps, training_params, writer, searchengine)
    trainer.train()

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
def main():
    """Translate stdin with a trained Seq2Seq model.

    Loads hyper-parameters and a saved model state from CLI args, reads one
    whitespace-tokenized sentence per line from stdin, and prints each
    translation produced by ``run_translation`` to stdout. Progress messages
    go to stderr so stdout stays clean for the translations.
    """
    sys.stderr.write(sys.executable + "\n")
    # TODO(review): hard-coded GPU index — make this configurable via CLI
    # or environment instead of pinning device 6.
    os.environ['CUDA_VISIBLE_DEVICES'] = "6"
    if torch.cuda.is_available():
        sys.stderr.write("Running on device {}\n".format(
            torch.cuda.current_device()))

    parser = argparse.ArgumentParser()
    parser.add_argument('--params')
    parser.add_argument('--training_params')
    parser.add_argument('--model_state')
    parser.add_argument('--dataset')
    args = parser.parse_args()

    hps = s2s.Seq2Seq.get_default_hparams()
    if args.params:
        with open(args.params, encoding='utf-8') as fin:
            hps = hps.parse_dict(json.load(fin))
    training_params = Trainer.get_default_hparams()
    if args.training_params:
        with open(args.training_params, encoding='utf-8') as fin:
            training_params = training_params.parse_dict(json.load(fin))

    # The full dataset supplies the source/target vocabularies and the
    # search-engine dictionary.
    full_dataset, src, tgt = read_problem(args.dataset, n_sents=None)

    searchengine = None
    if hps.tm_init:
        if hps.tm_overfitted:
            searchengine = OverfittedSearchEngine()
        else:
            searchengine = SearchEngine()
        searchengine.load(hps.tm_bin_path)
        searchengine.set_dictionary(full_dataset)
        sys.stderr.write("Using searchengine: {}\n".format(
            searchengine.__class__))

    # NOTE(review): unconditionally overrides any batch_size supplied via
    # --training_params; confirm this is intended for inference.
    training_params = training_params.parse_dict({'batch_size': 32})
    writer = WriterMock()
    model = s2s.Seq2Seq(src, tgt, hps, training_params,
                        writer=writer, searchengine=searchengine)
    if training_params.use_cuda:
        model = model.cuda()
    sys.stderr.write("Loading the model state\n")
    # NOTE(review): torch.load without map_location requires the checkpoint's
    # original device to be available — consider map_location='cpu'.
    state_dict = torch.load(args.model_state)
    model.load_state_dict(state_dict)
    model.eval()
    sys.stderr.write("Ready!\n")

    sents = []
    for line in sys.stdin:
        sents.append(line.strip().split())
    # NOTE(review): ragged token lists yield a dtype=object array; newer
    # numpy versions require dtype=object explicitly for ragged input —
    # verify against run_translation's expectations.
    sents = np.array(sents)
    for sent in run_translation(src, model, sents, training_params):
        print(sent)