def cli(args):
  logging.basicConfig(
      format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
      datefmt='%m/%d/%Y %H:%M:%S',
      level=logging.INFO)
  utils.report_gpu_info()

  config = cli_config_and_parse_args(args)
  print(config.__dict__)
  print("Continue? Starts in 10s. [Ctrl-C] to stop.")

  # let's save some seconds
  select.select([sys.stdin], [], [], 3)

  utils.seed_modules(
      config,
      numpy_seed=10000,
      torch_seed=20000,
      torcu_cuda_seed_all=2192,
      cuda_deterministic=True,
      kgegrok_base_seed=30000,
      cuda_benchmark=config.cudnn_benchmark)

  triple_source = data.TripleSource(config.data_dir, config.triple_order,
                                    config.triple_delimiter)
  model_class = utils.load_class_from_module(config.model, 'kgegrok.models',
                                             'kgegrok.text_models')
  evaluator = evaluation.ParallelEvaluator(config, triple_source)

  # maybe roughly 10s now
  select.select([sys.stdin], [], [], 4)

  if config.mode == 'train':
    optimizer_class = utils.load_class_from_module(config.optimizer,
                                                   'torch.optim')
    cli_train(triple_source, config, model_class, optimizer_class)
  elif config.mode == 'train_validate':
    optimizer_class = utils.load_class_from_module(config.optimizer,
                                                   'torch.optim')
    cli_train_and_validate(triple_source, config, model_class, optimizer_class,
                           evaluator)
  elif config.mode == 'test':
    cli_test(triple_source, config, model_class, evaluator)
  elif config.mode == 'demo_prediction':
    cli_demo_prediction(triple_source, config, model_class)
  elif config.mode == 'profile':
    optimizer_class = utils.load_class_from_module(config.optimizer,
                                                   'torch.optim')
    cli_profile(triple_source, config, model_class, optimizer_class)
  else:
    raise RuntimeError("Wrong mode {} selected.".format(config.mode))
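# `cli` above resolves model and optimizer classes from their string names via
# utils.load_class_from_module. The following is a minimal, hypothetical sketch
# of how such a lookup can be written with importlib; it is an illustration of
# the technique, not kgegrok's actual implementation.
import importlib


def _load_class_from_module(class_name, *module_names):
  """Return the first attribute named `class_name` found in the listed modules."""
  for module_name in module_names:
    module = importlib.import_module(module_name)
    if hasattr(module, class_name):
      return getattr(module, class_name)
  raise ValueError("{} not found in {}".format(class_name, module_names))


# Example: resolves torch.optim.SGD from the string "SGD".
# optimizer_class = _load_class_from_module("SGD", "torch.optim")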
    'triple_order': 'htr',
    'delimiter': ' ',
    'batch_size': 10000,
    'num_workers': 10,
    'num_evaluation_workers': 10,
    'model': "TransE",
    'optimizer': "SGD",
    'margin': 1.0,
    'epochs': 1000,
    'lambda_': 0.001,
    'evaluation_load_factor': 0.01
}
config = utils.build_config_with_dict(default_args)
print(config.__dict__)

triple_source = data.TripleSource(config.data_dir, config.triple_order,
                                  config.triple_delimiter)
model_class = utils.load_class_from_module(config.model, 'kgegrok.models',
                                           'kgegrok.text_models')
evaluator = evaluation.ParallelEvaluator(config, triple_source)

for changed_config in ParameterGrid(grid):
  d = {}
  d.update(config.__dict__)
  changed_config['name'] = "TransE-FB15k237-neg_e_{}-neg_r_{}-ent_dim_{}-alpha_{}".format(
      changed_config['negative_entity'], changed_config['negative_relation'],
      changed_config['entity_embedding_dimension'], changed_config['alpha'])
  d.update(changed_config)
  if os.path.exists(
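# The `grid` iterated above is defined earlier in the script and ParameterGrid
# is presumably sklearn.model_selection.ParameterGrid, which expands a dict of
# value lists into every combination. A standalone illustration (the key names
# and values below are made up for the example):
from sklearn.model_selection import ParameterGrid

example_grid = {
    'negative_entity': [1, 5],
    'negative_relation': [0, 1],
    'entity_embedding_dimension': [50, 100],
    'alpha': [0.01, 0.1],
}

for params in ParameterGrid(example_grid):
  # Each `params` is a plain dict, e.g.
  # {'alpha': 0.01, 'entity_embedding_dimension': 50, 'negative_entity': 1, ...}
  print(params)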
def source(config):
  triple_dir = 'kgegrok/tests/fixtures/triples'
  return data.TripleSource(triple_dir, 'hrt', ' ')
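# `source` above reads like a pytest fixture for the bundled fixture triples
# (any @pytest.fixture decorator sits outside this excerpt). A hypothetical
# test consuming it could look like this; it assumes nothing about
# data.TripleSource beyond its type:
def test_source_returns_triple_source(source):
  assert isinstance(source, data.TripleSource)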