def train(config, graph):
    """Run one task of distributed GraphSage training.

    The cluster is described by the comma-separated host lists in
    ``FLAGS.ps_hosts`` and ``FLAGS.worker_hosts``; this process's role is
    given by ``FLAGS.job_name`` / ``FLAGS.task_index``. A ``'worker'`` task
    calls ``trainer.train_and_evaluate()``; any other job name (with this
    two-job cluster layout, presumably ``'ps'``) calls ``trainer.join()``.

    Args:
      config: Mapping of hyperparameters read by key (model sizes, batch
        sizes, sampling/aggregation options, epoch and optimizer settings).
      graph: Graph object forwarded unchanged to the ``GraphSage`` model.
    """

    def build_model():
        # Config values are read only when this factory is invoked
        # (presumably by the trainer), not when train() is entered.
        positional = (config['class_num'],
                      config['features_num'],
                      config['batch_size'])
        options = dict(
            val_batch_size=config['val_batch_size'],
            test_batch_size=config['test_batch_size'],
            categorical_attrs_desc=config['categorical_attrs_desc'],
            hidden_dim=config['hidden_dim'],
            in_drop_rate=config['in_drop_rate'],
            neighs_num=config['neighs_num'],
            agg_type=config['agg_type'],
            full_graph_mode=config['full_graph_mode'],
        )
        return GraphSage(graph, *positional, **options)

    cluster_spec = tf.train.ClusterSpec({
        "ps": FLAGS.ps_hosts.split(","),
        "worker": FLAGS.worker_hosts.split(","),
    })
    job_name = FLAGS.job_name
    task_index = FLAGS.task_index
    epochs = config['epoch']
    optimizer = gl.get_tf_optimizer(config['learning_algo'],
                                    config['learning_rate'],
                                    config['weight_decay'])
    trainer = gl.DistTFTrainer(build_model,
                               cluster_spec=cluster_spec,
                               task_name=job_name,
                               task_index=task_index,
                               epoch=epochs,
                               optimizer=optimizer)

    if job_name != 'worker':
        # Non-worker tasks only join the cluster.
        trainer.join()
        return
    # The worker is also the graph-learn client in this example.
    trainer.train_and_evaluate()
def train(config, graph):
    """Train a BipartiteGraphSage model locally and save both embedding sets.

    Builds the model lazily inside ``model_fn`` and trains it with
    ``gl.LocalTFTrainer``. The trainer receives ``config['epoch']`` and an
    optimizer from ``gl.get_tf_optimizer``. After training, the embeddings
    for node types ``"u"`` and ``"i"`` (presumably users and items) are
    written with ``np.save``.

    Args:
      config: Dict of hyperparameters read by key. Two optional keys set
        the output paths:
          * ``'u_emb_save_dir'``: path for the ``"u"`` embeddings
            (default ``'u_emb'``).
          * ``'i_emb_save_dir'``: path for the ``"i"`` embeddings
            (default ``'i_emb'``).
        The defaults are the previously hard-coded file names.
        ``np.save`` appends ``.npy`` when the path has no such suffix.
      graph: Graph object forwarded unchanged to ``BipartiteGraphSage``.
    """

    def model_fn():
        return BipartiteGraphSage(
            graph,
            config['batch_size'],
            config['hidden_dim'],
            config['output_dim'],
            config['hops_num'],
            config['u_neighs_num'],
            config['i_neighs_num'],
            u_features_num=config['u_features_num'],
            u_categorical_attrs_desc=config['u_categorical_attrs_desc'],
            i_features_num=config['i_features_num'],
            i_categorical_attrs_desc=config['i_categorical_attrs_desc'],
            neg_num=config['neg_num'],
            use_input_bn=config['use_input_bn'],
            act=config['act'],
            agg_type=config['agg_type'],
            need_dense=config['need_dense'],
            # NOTE: the config key here is 'drop_out', not 'in_drop_rate'.
            in_drop_rate=config['drop_out'],
            ps_hosts=config['ps_hosts'])

    trainer = gl.LocalTFTrainer(model_fn,
                                epoch=config['epoch'],
                                optimizer=gl.get_tf_optimizer(
                                    config['learning_algo'],
                                    config['learning_rate'],
                                    config['weight_decay']))
    trainer.train()

    # Output paths are configurable; the defaults preserve the old
    # hard-coded names so existing configs behave exactly as before.
    u_embs = trainer.get_node_embedding("u")
    np.save(config.get('u_emb_save_dir', 'u_emb'), u_embs)
    i_embs = trainer.get_node_embedding("i")
    np.save(config.get('i_emb_save_dir', 'i_emb'), i_embs)
def train(config, graph):
    """Train a LINE model locally and save the resulting node embeddings.

    Args:
      config: Mapping of hyperparameters read by key. The model arguments,
        ``'epoch'``, ``'learning_algo'`` and ``'learning_rate'`` are used,
        and ``'emb_save_dir'`` is the path passed to ``np.save``.
      graph: Graph object forwarded unchanged to ``LINE``.
    """
    # LINE's positional arguments, in the order its constructor expects.
    line_arg_keys = ('node_count', 'hidden_dim', 'neg_num', 'batch_size',
                     's2h', 'ps_hosts', 'proximity', 'node_type',
                     'edge_type')

    def make_line_model():
        # Config is read lazily, each time the factory is invoked.
        return LINE(graph, *[config[key] for key in line_arg_keys])

    num_epochs = config['epoch']
    optimizer = gl.get_tf_optimizer(config['learning_algo'],
                                    config['learning_rate'])
    trainer = gl.LocalTFTrainer(make_line_model,
                                epoch=num_epochs,
                                optimizer=optimizer)
    trainer.train()

    embeddings = trainer.get_node_embedding()
    np.save(config['emb_save_dir'], embeddings)
def train(config, graph):
    """Train a DeepWalk model locally and save the resulting node embeddings.

    Args:
      config: Mapping of hyperparameters read by key. The model arguments,
        ``'epoch'``, ``'learning_algo'`` and ``'learning_rate'`` are used,
        and ``'emb_save_dir'`` is the path passed to ``np.save``.
      graph: Graph object forwarded unchanged to ``DeepWalk``.
    """
    # Positional constructor arguments, in order.
    positional_keys = ('walk_len', 'window_size', 'node_count',
                       'hidden_dim', 'neg_num', 'batch_size')
    # Keyword constructor arguments; each keyword equals its config key.
    keyword_keys = ('s2h', 'ps_hosts', 'temperature')

    def deepwalk_factory():
        args = [config[key] for key in positional_keys]
        kwargs = {key: config[key] for key in keyword_keys}
        return DeepWalk(graph, *args, **kwargs)

    epoch_count = config['epoch']
    tf_optimizer = gl.get_tf_optimizer(config['learning_algo'],
                                       config['learning_rate'])
    trainer = gl.LocalTFTrainer(deepwalk_factory,
                                epoch=epoch_count,
                                optimizer=tf_optimizer)
    trainer.train()

    node_embs = trainer.get_node_embedding()
    np.save(config['emb_save_dir'], node_embs)
def train(config, graph):
    """Train a TransE model locally and return its embeddings.

    Args:
      config: Mapping of hyperparameters read by key (model arguments,
        ``'epoch'``, ``'learning_algo'``, ``'learning_rate'``).
      graph: Graph object forwarded unchanged to ``TransE``.

    Returns:
      A 2-tuple whose items come from ``trainer.get_node_embedding('entity')``
      and ``trainer.get_node_embedding('relation')``, in that order.
    """

    def transe_factory():
        return TransE(
            graph,
            config['neg_num'],
            config['batch_size'],
            config['margin'],
            config['entity_num'],
            config['relation_num'],
            config['hidden_dim'],
            s2h=config['s2h'],
            ps_hosts=config['ps_hosts'],
        )

    epochs = config['epoch']
    opt = gl.get_tf_optimizer(config['learning_algo'], config['learning_rate'])
    trainer = gl.LocalTFTrainer(transe_factory, epoch=epochs, optimizer=opt)
    trainer.train()

    # Fetch entity embeddings first, then relation embeddings.
    return tuple(trainer.get_node_embedding(kind)
                 for kind in ('entity', 'relation'))
def train(config, graph):
    """Train and evaluate a GCN model with a local graph-learn trainer.

    Args:
      config: Mapping of hyperparameters read by key (model arguments,
        ``'epoch'``, ``'learning_algo'``, ``'learning_rate'``,
        ``'weight_decay'``).
      graph: Graph object forwarded unchanged to ``GCN``.
    """
    # GCN keyword arguments; each keyword has the same name as its config key.
    gcn_option_keys = ('val_batch_size', 'test_batch_size',
                       'categorical_attrs_desc', 'hidden_dim',
                       'in_drop_rate', 'hops_num', 'neighs_num',
                       'full_graph_mode')

    def gcn_factory():
        leading = (config['class_num'],
                   config['features_num'],
                   config['batch_size'])
        options = {key: config[key] for key in gcn_option_keys}
        return GCN(graph, *leading, **options)

    epoch_setting = config['epoch']
    optimizer = gl.get_tf_optimizer(config['learning_algo'],
                                    config['learning_rate'],
                                    config['weight_decay'])
    local_trainer = gl.LocalTFTrainer(gcn_factory,
                                      epoch=epoch_setting,
                                      optimizer=optimizer)
    local_trainer.train_and_evaluate()
def train(config, graph):
    """Train a GraphSage model locally and save the resulting node embeddings.

    ``config['unsupervised']`` and ``config['neg_num']`` are forwarded to
    the model, so this entry point can run the unsupervised variant.

    Args:
      config: Mapping of hyperparameters read by key. The model arguments,
        ``'epoch'``, ``'learning_algo'``, ``'learning_rate'`` and
        ``'weight_decay'`` are used, and ``'emb_save_dir'`` is the path
        passed to ``np.save``.
      graph: Graph object forwarded unchanged to ``GraphSage``.
    """

    def sage_factory():
        return GraphSage(
            graph,
            config['class_num'],
            config['features_num'],
            config['batch_size'],
            categorical_attrs_desc=config['categorical_attrs_desc'],
            hidden_dim=config['hidden_dim'],
            in_drop_rate=config['in_drop_rate'],
            neighs_num=config['neighs_num'],
            full_graph_mode=config['full_graph_mode'],
            unsupervised=config['unsupervised'],
            neg_num=config['neg_num'],
            agg_type=config['agg_type'],
        )

    n_epochs = config['epoch']
    sgd_like = gl.get_tf_optimizer(config['learning_algo'],
                                   config['learning_rate'],
                                   config['weight_decay'])
    trainer = gl.LocalTFTrainer(sage_factory, epoch=n_epochs, optimizer=sgd_like)
    trainer.train()

    node_embeddings = trainer.get_node_embedding()
    np.save(config['emb_save_dir'], node_embeddings)