示例#1
0
def cli(args):
  """Parse command-line arguments, prepare the run, and dispatch on the mode.

  Sets up logging, reports GPU info, builds the config from `args`, seeds
  the random generators, loads the triple data, resolves the model class and
  builds a `ParallelEvaluator`, then hands off to the handler for
  `config.mode`:

    * 'train'           -> cli_train (needs an optimizer class)
    * 'train_validate'  -> cli_train_and_validate (optimizer + evaluator)
    * 'test'            -> cli_test (evaluator)
    * 'demo_prediction' -> cli_demo_prediction
    * 'profile'         -> cli_profile (needs an optimizer class)

  Args:
    args: raw command-line arguments, passed through unchanged to
      `cli_config_and_parse_args`.

  Raises:
    RuntimeError: if `config.mode` is not one of the modes above. This is
      checked right after parsing, before seeding, data loading and
      evaluator construction, so a mistyped mode fails immediately instead
      of after the whole setup.
  """
  logging.basicConfig(
      format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
      datefmt='%m/%d/%Y %H:%M:%S',
      level=logging.INFO)
  utils.report_gpu_info()

  config = cli_config_and_parse_args(args)
  print(config.__dict__)

  # Modes whose handlers take an optimizer class resolved from torch.optim.
  optimizer_modes = ('train', 'train_validate', 'profile')
  known_modes = optimizer_modes + ('test', 'demo_prediction')
  if config.mode not in known_modes:
    raise RuntimeError("Wrong mode {} selected.".format(config.mode))

  print("Continue? Starts in 10s. [Ctrl-C] to stop.")

  # Pause so the user can abort with Ctrl-C; select() returns early as soon
  # as stdin becomes readable. The two pauses (3s here, 4s below) plus the
  # setup work in between appear meant to add up to roughly the advertised
  # 10s.
  select.select([sys.stdin], [], [], 3)

  utils.seed_modules(
      config,
      numpy_seed=10000,
      torch_seed=20000,
      # NOTE(review): 'torcu' looks like a typo for 'torch'; kept as-is
      # because it must match the keyword name utils.seed_modules accepts.
      torcu_cuda_seed_all=2192,
      cuda_deterministic=True,
      kgegrok_base_seed=30000,
      cuda_benchmark=config.cudnn_benchmark)

  triple_source = data.TripleSource(config.data_dir, config.triple_order,
                                    config.triple_delimiter)
  model_class = utils.load_class_from_module(config.model, 'kgegrok.models',
                                             'kgegrok.text_models')

  # Built for every mode, although only 'train_validate' and 'test' pass it
  # on to their handler.
  evaluator = evaluation.ParallelEvaluator(config, triple_source)
  # Second part of the start-up pause (see above).
  select.select([sys.stdin], [], [], 4)

  optimizer_class = None
  if config.mode in optimizer_modes:
    optimizer_class = utils.load_class_from_module(config.optimizer,
                                                   'torch.optim')

  if config.mode == 'train':
    cli_train(triple_source, config, model_class, optimizer_class)
  elif config.mode == 'train_validate':
    cli_train_and_validate(triple_source, config, model_class,
                           optimizer_class, evaluator)
  elif config.mode == 'test':
    cli_test(triple_source, config, model_class, evaluator)
  elif config.mode == 'demo_prediction':
    cli_demo_prediction(triple_source, config, model_class)
  else:
    # Only 'profile' remains after the validation above.
    cli_profile(triple_source, config, model_class, optimizer_class)
        'triple_order': 'htr',
        'delimiter': ' ',
        'batch_size': 10000,
        'num_workers': 10,
        'num_evaluation_workers': 10,
        'model': "TransE",
        'optimizer': "SGD",
        'margin': 1.0,
        'epochs': 1000,
        'lambda_': 0.001,
        'evaluation_load_factor': 0.01
    }
    config = utils.build_config_with_dict(default_args)
    print(config.__dict__)

    triple_source = data.TripleSource(config.data_dir, config.triple_order,
                                      config.triple_delimiter)
    model_class = utils.load_class_from_module(config.model, 'kgegrok.models',
                                               'kgegrok.text_models')

    evaluator = evaluation.ParallelEvaluator(config, triple_source)
    for changed_config in ParameterGrid(grid):
        d = {}
        d.update(config.__dict__)
        changed_config[
            'name'] = "TransE-FB15k237-neg_e_{}-neg_r_{}-ent_dim_{}-alpha_{}".format(
                changed_config['negative_entity'],
                changed_config['negative_relation'],
                changed_config['entity_embedding_dimension'],
                changed_config['alpha'])
        d.update(changed_config)
        if os.path.exists(
示例#3
0
def source(config):
    """Build a TripleSource over the bundled test fixture triples.

    The fixture directory is read with triple order 'hrt' and a single-space
    delimiter. `config` is not used here; presumably it is accepted so a
    test harness can order this after a config fixture — TODO confirm.
    """
    fixture_path = 'kgegrok/tests/fixtures/triples'
    order, delimiter = 'hrt', ' '
    return data.TripleSource(fixture_path, order, delimiter)