Example #1
0
def build_config():
    """Build a ``parallax.Config`` from the command-line ``FLAGS``.

    Assembles checkpoint, parameter-server (PS), MPI and profiling
    sub-configs from the corresponding flags and attaches them, together
    with the run option and redirect path, to a fresh ``parallax.Config``.
    ``average_sparse`` is always set to False.

    Returns:
        parallax.Config: the populated configuration object.

    Raises:
        ValueError: if ``FLAGS.profile_steps`` contains a non-integer entry.
    """

    ckpt_config = parallax.CheckPointConfig(
        ckpt_dir=FLAGS.ckpt_dir, save_ckpt_steps=calculate_ckpt_steps())
    ps_config = parallax.PSConfig(
        replicate_variables=FLAGS.replicate_variables,
        protocol=FLAGS.protocol,
        local_aggregation=FLAGS.local_aggregation,
        boundary_among_servers=FLAGS.boundary_among_servers,
        boundary_between_workers_and_servers=(
            FLAGS.boundary_between_workers_and_servers))
    mpi_config = parallax.MPIConfig(use_allgatherv=FLAGS.use_allgatherv,
                                    mpirun_options=FLAGS.mpirun_options)
    parallax_config = parallax.Config()
    parallax_config.run_option = FLAGS.run_option
    parallax_config.average_sparse = False
    parallax_config.communication_config = parallax.CommunicationConfig(
        ps_config, mpi_config)
    parallax_config.ckpt_config = ckpt_config

    def get_profile_steps():
        """Parse comma-separated ``FLAGS.profile_steps`` into a list of ints.

        Returns an empty list when the flag is unset, empty or blank.
        Blank entries (e.g. from a trailing comma or "10,,20") are skipped
        instead of crashing on ``int('')``. The global flag is read but
        never modified.
        """
        raw = (FLAGS.profile_steps or '').strip()
        # int() tolerates surrounding whitespace, so " 10" parses fine;
        # only entries that are entirely blank need to be dropped.
        return [int(token) for token in raw.split(',') if token.strip()]

    profile_config = parallax.ProfileConfig(profile_dir=FLAGS.profile_dir,
                                            profile_steps=get_profile_steps())
    parallax_config.profile_config = profile_config
    parallax_config.redirect_path = FLAGS.redirect_path

    return parallax_config
def build_config():
    """Create a ``parallax.Config`` populated from command-line ``FLAGS``.

    Wires checkpoint settings, PS and MPI communication settings, the run
    option and the redirect path into a new ``parallax.Config``.
    ``average_sparse`` is forced to False.

    Returns:
        parallax.Config: the assembled configuration.
    """
    checkpoint = parallax.CheckPointConfig(
        ckpt_dir=FLAGS.ckpt_dir,
        save_ckpt_steps=calculate_ckpt_steps(),
    )
    ps = parallax.PSConfig(
        replicate_variables=FLAGS.replicate_variables,
        protocol=FLAGS.protocol,
    )
    mpi = parallax.MPIConfig(
        use_allgatherv=FLAGS.use_allgatherv,
        mpirun_options=FLAGS.mpirun_options,
    )

    config = parallax.Config()
    config.run_option = FLAGS.run_option
    config.average_sparse = False
    # Built after the two scalar settings above, matching the original order.
    communication = parallax.CommunicationConfig(ps, mpi)
    config.communication_config = communication
    config.ckpt_config = checkpoint
    config.redirect_path = FLAGS.redirect_path
    return config
Example #3
0
def build_config():
    """Build a ``parallax.Config`` from the command-line ``FLAGS``.

    Assembles the following from the corresponding flags and attaches them
    to a fresh ``parallax.Config``:

    - checkpoint, parameter-server (PS), MPI and profiling sub-configs;
    - the run option, redirect path and graph-export path.

    ``average_sparse`` is always set to False.

    Returns:
        parallax.Config: the populated configuration object.

    Raises:
        ValueError: if ``FLAGS.profile_steps`` contains a non-integer entry,
            or ``FLAGS.profile_range`` is not exactly two integers
            ``"START,END"``.
    """

    ckpt_config = parallax.CheckPointConfig(
        ckpt_dir=FLAGS.ckpt_dir, save_ckpt_steps=calculate_ckpt_steps())
    ps_config = parallax.PSConfig(
        replicate_variables=FLAGS.replicate_variables,
        protocol=FLAGS.protocol,
        local_aggregation=FLAGS.local_aggregation,
        boundary_among_servers=FLAGS.boundary_among_servers,
        boundary_between_workers_and_servers=(
            FLAGS.boundary_between_workers_and_servers))
    mpi_config = parallax.MPIConfig(mpirun_options=FLAGS.mpirun_options)

    def get_profile_steps():
        """Parse ``FLAGS.profile_steps`` (e.g. "10,20,30") into a list of ints.

        Returns None when the flag is unset or holds no non-blank entries.
        Blank entries (e.g. from a trailing comma) are skipped rather than
        crashing on ``int('')``. The global flag is not modified.
        """
        raw = (FLAGS.profile_steps or '').strip()
        # int() tolerates surrounding whitespace; only fully blank tokens
        # need filtering.
        steps = [int(token) for token in raw.split(',') if token.strip()]
        return steps or None

    def get_profile_range():
        """Parse ``FLAGS.profile_range`` ("START,END") into a pair of ints.

        Returns None when the flag is unset or blank. The global flag is not
        modified.
        """
        raw = (FLAGS.profile_range or '').strip()
        if not raw:
            return None
        bounds = [token for token in raw.split(',') if token.strip()]
        # Require exactly two values so a lone value (previously an opaque
        # IndexError) or surplus values (previously dropped silently) are
        # reported clearly.
        if len(bounds) != 2:
            raise ValueError(
                'profile_range must be "START,END", got %r'
                % FLAGS.profile_range)
        return (int(bounds[0]), int(bounds[1]))

    profile_config = parallax.ProfileConfig(
        profile_dir=FLAGS.profile_dir,
        profile_steps=get_profile_steps(),
        profile_range=get_profile_range(),
        profile_worker=FLAGS.profile_worker)

    parallax_config = parallax.Config()
    parallax_config.run_option = FLAGS.run_option
    parallax_config.average_sparse = False
    parallax_config.communication_config = parallax.CommunicationConfig(
        ps_config, mpi_config)
    parallax_config.ckpt_config = ckpt_config
    parallax_config.profile_config = profile_config
    parallax_config.redirect_path = FLAGS.redirect_path
    parallax_config.export_graph_path = FLAGS.export_graph_path

    return parallax_config
Example #4
0
# Boolean flag 'sync' (default True); forwarded below as
# parallax.parallel_run(sync=FLAGS.sync).
tf.app.flags.DEFINE_boolean('sync', True, '')

# Load MNIST from the 'MNIST_data' directory with one-hot encoded labels.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Build single-GPU rnn model
# rnn() is not defined in this view; it is built inside single_gpu_graph and
# must return a dict with the keys 'train_op', 'loss', 'acc', 'images',
# 'labels' and 'is_training', which are unpacked into module-level names.
single_gpu_graph = tf.Graph()
with single_gpu_graph.as_default():
    ops = rnn()
    train_op = ops['train_op']
    loss = ops['loss']
    acc = ops['acc']
    x = ops['images']
    y = ops['labels']
    is_training = ops['is_training']

# Checkpoints are written under 'parallax_ckpt'. save_ckpt_steps=1 presumably
# means one checkpoint per step (TODO confirm against the parallax docs).
parallax_config = parallax.Config()
ckpt_config = parallax.CheckPointConfig(ckpt_dir='parallax_ckpt',
                                        save_ckpt_steps=1)
parallax_config.ckpt_config = ckpt_config

# Hand the single-GPU graph and the resource-info file to parallax. The
# returned names suggest a session plus this process's worker topology
# (worker count, worker id, replicas per worker); verify against parallax.
sess, num_workers, worker_id, num_replicas_per_worker = parallax.parallel_run(
    single_gpu_graph,
    FLAGS.resource_info_file,
    sync=FLAGS.sync,
    parallax_config=parallax_config)

# Wall-clock start time, presumably used to time the training loop below.
start = time.time()
for i in range(FLAGS.max_steps):
  batch = mnist.train.next_batch(FLAGS.batch_size, shuffle=False)
  _, loss_ = sess.run([train_op, loss], feed_dict={x: [batch[0]],
                                                   y: [batch[1]],