Code example #1
# Imports as in the euler project's run_loop.py; the exact module paths are
# assumed here and may differ across forks.
import tensorflow as tf

from tf_euler.python import euler_ops
from tf_euler.python import optimizers
from tf_euler.python.utils import context as utils_context
from tf_euler.python.utils import hooks as utils_hooks


def run_train(model, flags_obj, master, is_chief):
    utils_context.training = True

    batch_size = flags_obj.batch_size // model.batch_size_ratio
    source = euler_ops.sample_node(count=batch_size,
                                   node_type=flags_obj.train_node_type)
    source.set_shape([batch_size])

    sim_outputs, cor_outputs = model(source)
    _, sim_loss, metric_name, sim_metric = sim_outputs
    _, cor_loss, _, cor_metric = cor_outputs
    loss = sim_loss + cor_loss

    optimizer_class = optimizers.get(flags_obj.optimizer)
    optimizer = optimizer_class(flags_obj.learning_rate)
    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step=global_step)

    hooks = []

    tensor_to_log = {
        'step': global_step,
        'loss': loss,
        'sim_loss': sim_loss,
        'cor_loss': cor_loss,
        'sim_metric': sim_metric,
        'cor_metric': cor_metric
    }
    hooks.append(
        tf.train.LoggingTensorHook(tensor_to_log,
                                   every_n_iter=flags_obj.log_steps))

    num_steps = int(
        (flags_obj.max_id + 1) // flags_obj.batch_size * flags_obj.num_epochs)
    hooks.append(tf.train.StopAtStepHook(last_step=num_steps))

    extra_param_name = '_'.join(map(str, flags_obj.fanouts))
    output_dir = ckpt_dir = '{}/{}/{}_{}_{}/'.format(flags_obj.model_dir,
                                                     flags_obj.model,
                                                     extra_param_name,
                                                     flags_obj.dim,
                                                     flags_obj.embedding_dim)
    print("output dir: {}".format(output_dir))

    if len(flags_obj.worker_hosts) == 0 or flags_obj.task_index == 1:
        hooks.append(
            tf.train.ProfilerHook(save_secs=180, output_dir=output_dir))
    if len(flags_obj.worker_hosts):
        hooks.append(utils_hooks.SyncExitHook(len(flags_obj.worker_hosts)))
    if hasattr(model, 'make_session_run_hook'):
        hooks.append(model.make_session_run_hook())

    # `config` is not defined in this snippet; a plain session config is
    # assumed so the example runs standalone.
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.train.MonitoredTrainingSession(master=master,
                                           is_chief=is_chief,
                                           checkpoint_dir=ckpt_dir,
                                           log_step_count_steps=None,
                                           hooks=hooks,
                                           config=config) as sess:
        while not sess.should_stop():
            sess.run(train_op)
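
For reference, a minimal sketch of how this run_train might be driven in a single-process setup. The SimpleNamespace below is a hypothetical stand-in for the project's TF flags object, and model is assumed to be any model instance with the interface the function expects (a batch_size_ratio attribute and the two (embedding, loss, metric_name, metric) output tuples); all field values shown are illustrative, not taken from the project.

from types import SimpleNamespace

# Hypothetical flag values; the real project populates these via tf.app.flags.
flags_obj = SimpleNamespace(
    batch_size=512, train_node_type=0, optimizer='adam', learning_rate=0.01,
    log_steps=20, max_id=100000, num_epochs=10, fanouts=[10, 10],
    model_dir='ckpt', model='graphsage', dim=128, embedding_dim=16,
    worker_hosts=[], task_index=0)

# `model` would be built by the project's model factory before this call.
# Single-process run: empty master string, this process acts as chief.
run_train(model, flags_obj, master='', is_chief=True)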
Code example #2
File: run_loop.py  Project: zqz981/euler
# Imports as in Code example #1 above.
def run_train(model, flags_obj, master, is_chief):
    utils_context.training = True

    batch_size = flags_obj.batch_size // model.batch_size_ratio
    if flags_obj.model == 'line' or flags_obj.model == 'randomwalk':
        source = euler_ops.sample_node(count=batch_size,
                                       node_type=flags_obj.all_node_type)
    else:
        source = euler_ops.sample_node(count=batch_size,
                                       node_type=flags_obj.train_node_type)
    source.set_shape([batch_size])
    # Alternative input pipeline, kept disabled in the original source:
    # dataset = tf.data.TextLineDataset(flags_obj.id_file)
    # dataset = dataset.map(
    #     lambda id_str: tf.string_to_number(id_str, out_type=tf.int64))
    # dataset = dataset.shuffle(buffer_size=20000)
    # dataset = dataset.batch(batch_size)
    # dataset = dataset.repeat(flags_obj.num_epochs)
    # source = dataset.make_one_shot_iterator().get_next()
    _, loss, metric_name, metric = model(source)

    optimizer_class = optimizers.get(flags_obj.optimizer)
    optimizer = optimizer_class(learning_rate=flags_obj.learning_rate)
    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step=global_step)

    hooks = []

    tensor_to_log = {'step': global_step, 'loss': loss, metric_name: metric}
    hooks.append(
        tf.train.LoggingTensorHook(tensor_to_log,
                                   every_n_iter=flags_obj.log_steps))

    num_steps = int(
        (flags_obj.max_id + 1) // batch_size * flags_obj.num_epochs)
    hooks.append(tf.train.StopAtStepHook(last_step=num_steps))

    if len(flags_obj.worker_hosts) == 0 or flags_obj.task_index == 1:
        hooks.append(
            tf.train.ProfilerHook(save_secs=180,
                                  output_dir=flags_obj.model_dir))
    if len(flags_obj.worker_hosts):
        hooks.append(utils_hooks.SyncExitHook(len(flags_obj.worker_hosts)))
    if hasattr(model, 'make_session_run_hook'):
        hooks.append(model.make_session_run_hook())

    with tf.train.MonitoredTrainingSession(master=master,
                                           is_chief=is_chief,
                                           checkpoint_dir=flags_obj.model_dir,
                                           log_step_count_steps=None,
                                           hooks=hooks) as sess:
        while not sess.should_stop():
            sess.run(train_op)
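
As a sketch of where master and is_chief come from in the distributed case: they are typically derived from a tf.train.Server built over the worker/ps host lists, roughly as below. The ps_hosts flag and the chief-is-task-0 convention are assumptions about the surrounding launcher code, not part of run_loop.py.

# Hypothetical launcher wiring (TF1 distributed setup), not from run_loop.py:
cluster = tf.train.ClusterSpec({'ps': flags_obj.ps_hosts,
                                'worker': flags_obj.worker_hosts})
server = tf.train.Server(cluster, job_name='worker',
                         task_index=flags_obj.task_index)
run_train(model, flags_obj,
          master=server.target,
          is_chief=(flags_obj.task_index == 0))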