예제 #1
0
def train(config, graph):
  """Train a BipartiteGraphSage model locally and save both embedding sets.

  Args:
    config: Mapping indexed by string keys. Supplies the model arguments,
      the optimizer settings ('learning_algo', 'learning_rate',
      'weight_decay') and 'epoch'.
    graph: Object passed unchanged as the first argument of
      ``BipartiteGraphSage``.

  Side effects:
    After training, the "u" node embeddings are written with
    ``np.save('u_emb', ...)`` and the "i" node embeddings with
    ``np.save('i_emb', ...)``.
  """
  def build_model():
    # Config values are looked up when this factory is called, not when
    # train() starts.
    leading = [config[key] for key in (
        'batch_size', 'hidden_dim', 'output_dim',
        'hops_num', 'u_neighs_num', 'i_neighs_num')]
    named = {
        'u_features_num': config['u_features_num'],
        'u_categorical_attrs_desc': config['u_categorical_attrs_desc'],
        'i_features_num': config['i_features_num'],
        'i_categorical_attrs_desc': config['i_categorical_attrs_desc'],
        'neg_num': config['neg_num'],
        'use_input_bn': config['use_input_bn'],
        'act': config['act'],
        'agg_type': config['agg_type'],
        'need_dense': config['need_dense'],
        # The model's keyword is in_drop_rate; the config key is 'drop_out'.
        'in_drop_rate': config['drop_out'],
        'ps_hosts': config['ps_hosts'],
    }
    return BipartiteGraphSage(graph, *leading, **named)

  num_epochs = config['epoch']
  tf_optimizer = gl.get_tf_optimizer(config['learning_algo'],
                                     config['learning_rate'],
                                     config['weight_decay'])
  runner = gl.LocalTFTrainer(build_model,
                             epoch=num_epochs,
                             optimizer=tf_optimizer)
  runner.train()

  # Fetch and persist one node type at a time: "u" first, then "i".
  for node_type, out_name in (("u", 'u_emb'), ("i", 'i_emb')):
    node_embs = runner.get_node_embedding(node_type)
    np.save(out_name, node_embs)
예제 #2
0
파일: train.py 프로젝트: yuinm/graph-learn
def train(config, graph):
    """Train a LINE model locally and save its node embeddings.

    Args:
        config: Mapping indexed by string keys. Supplies the LINE
            constructor arguments, 'epoch', 'learning_algo',
            'learning_rate' and the output location 'emb_save_dir'.
        graph: Object passed unchanged as the first argument of ``LINE``.

    Side effects:
        Writes the result of ``trainer.get_node_embedding()`` with
        ``np.save(config['emb_save_dir'], ...)``.
    """
    # Config keys in the exact positional order LINE receives them.
    line_arg_keys = (
        'node_count', 'hidden_dim', 'neg_num', 'batch_size', 's2h',
        'ps_hosts', 'proximity', 'node_type', 'edge_type',
    )

    def make_line():
        # Values are read from config each time the factory is called.
        return LINE(graph, *[config[key] for key in line_arg_keys])

    num_epochs = config['epoch']
    tf_optimizer = gl.get_tf_optimizer(config['learning_algo'],
                                       config['learning_rate'])
    runner = gl.LocalTFTrainer(make_line,
                               epoch=num_epochs,
                               optimizer=tf_optimizer)
    runner.train()

    node_embs = runner.get_node_embedding()
    np.save(config['emb_save_dir'], node_embs)
예제 #3
0
파일: train.py 프로젝트: zymale/graph-learn
def train(config, graph):
  """Train a DeepWalk model locally and save its node embeddings.

  Args:
    config: Mapping indexed by string keys. Supplies the DeepWalk
      constructor arguments, 'epoch', 'learning_algo', 'learning_rate'
      and the output location 'emb_save_dir'.
    graph: Object passed unchanged as the first argument of ``DeepWalk``.

  Side effects:
    Writes the result of ``trainer.get_node_embedding()`` with
    ``np.save(config['emb_save_dir'], ...)``.
  """
  def make_deepwalk():
    # Read the settings in the order the constructor takes them.
    walk_len = config['walk_len']
    window_size = config['window_size']
    node_count = config['node_count']
    hidden_dim = config['hidden_dim']
    neg_num = config['neg_num']
    batch_size = config['batch_size']
    s2h = config['s2h']
    ps_hosts = config['ps_hosts']
    temperature = config['temperature']
    return DeepWalk(graph, walk_len, window_size, node_count, hidden_dim,
                    neg_num, batch_size,
                    s2h=s2h, ps_hosts=ps_hosts, temperature=temperature)

  num_epochs = config['epoch']
  tf_optimizer = gl.get_tf_optimizer(config['learning_algo'],
                                     config['learning_rate'])
  runner = gl.LocalTFTrainer(make_deepwalk,
                             epoch=num_epochs,
                             optimizer=tf_optimizer)
  runner.train()

  node_embs = runner.get_node_embedding()
  np.save(config['emb_save_dir'], node_embs)
예제 #4
0
파일: train.py 프로젝트: yuinm/graph-learn
def train(config, graph):
    """Train a TransE model locally and return its embeddings.

    Args:
        config: Mapping indexed by string keys. Supplies the TransE
            constructor arguments, 'epoch', 'learning_algo' and
            'learning_rate'.
        graph: Object passed unchanged as the first argument of ``TransE``.

    Returns:
        A 2-tuple ``(entity_embs, relation_embs)`` holding the results of
        ``get_node_embedding('entity')`` and
        ``get_node_embedding('relation')``, in that order.
    """
    def make_transe():
        # Positional settings, in the order TransE receives them.
        positional = [config[key] for key in (
            'neg_num', 'batch_size', 'margin',
            'entity_num', 'relation_num', 'hidden_dim')]
        return TransE(graph, *positional,
                      s2h=config['s2h'],
                      ps_hosts=config['ps_hosts'])

    num_epochs = config['epoch']
    tf_optimizer = gl.get_tf_optimizer(config['learning_algo'],
                                       config['learning_rate'])
    runner = gl.LocalTFTrainer(make_transe,
                               epoch=num_epochs,
                               optimizer=tf_optimizer)
    runner.train()

    # Entity embeddings are fetched before relation embeddings.
    return tuple(runner.get_node_embedding(kind)
                 for kind in ('entity', 'relation'))
예제 #5
0
def train(config, graph):
  """Train and evaluate a GCN model with the local TF trainer.

  Args:
    config: Mapping indexed by string keys. Supplies the GCN constructor
      arguments, 'epoch', 'learning_algo', 'learning_rate' and
      'weight_decay'.
    graph: Object passed unchanged as the first argument of ``GCN``.

  Returns:
    None. The result of ``train_and_evaluate()`` is not propagated.
  """
  def make_gcn():
    # Values are read from config each time the factory is called.
    class_num = config['class_num']
    features_num = config['features_num']
    batch_size = config['batch_size']
    options = {
        'val_batch_size': config['val_batch_size'],
        'test_batch_size': config['test_batch_size'],
        'categorical_attrs_desc': config['categorical_attrs_desc'],
        'hidden_dim': config['hidden_dim'],
        'in_drop_rate': config['in_drop_rate'],
        'hops_num': config['hops_num'],
        'neighs_num': config['neighs_num'],
        'full_graph_mode': config['full_graph_mode'],
    }
    return GCN(graph, class_num, features_num, batch_size, **options)

  num_epochs = config['epoch']
  tf_optimizer = gl.get_tf_optimizer(config['learning_algo'],
                                     config['learning_rate'],
                                     config['weight_decay'])
  runner = gl.LocalTFTrainer(make_gcn,
                             epoch=num_epochs,
                             optimizer=tf_optimizer)
  runner.train_and_evaluate()
예제 #6
0
def train(config, graph):
  """Train a GraphSage model locally and save its node embeddings.

  Args:
    config: Mapping indexed by string keys. Supplies the GraphSage
      constructor arguments, 'epoch', 'learning_algo', 'learning_rate',
      'weight_decay' and the output location 'emb_save_dir'.
    graph: Object passed unchanged as the first argument of ``GraphSage``.

  Side effects:
    Writes the result of ``trainer.get_node_embedding()`` with
    ``np.save(config['emb_save_dir'], ...)``.
  """
  # Every keyword argument shares its name with the config key it reads.
  keyword_names = (
      'categorical_attrs_desc', 'hidden_dim', 'in_drop_rate', 'neighs_num',
      'full_graph_mode', 'unsupervised', 'neg_num', 'agg_type',
  )

  def make_graphsage():
    # Positional values first, then keywords, all read at call time.
    positional = [config[key]
                  for key in ('class_num', 'features_num', 'batch_size')]
    keywords = {name: config[name] for name in keyword_names}
    return GraphSage(graph, *positional, **keywords)

  num_epochs = config['epoch']
  tf_optimizer = gl.get_tf_optimizer(config['learning_algo'],
                                     config['learning_rate'],
                                     config['weight_decay'])
  runner = gl.LocalTFTrainer(make_graphsage,
                             epoch=num_epochs,
                             optimizer=tf_optimizer)
  runner.train()

  node_embs = runner.get_node_embedding()
  np.save(config['emb_save_dir'], node_embs)