def build_config():
    """Build a parallax.Config for this run from the command-line FLAGS.

    Assembles the checkpoint, parameter-server, MPI, and profiling
    sub-configurations and attaches them to a single parallax.Config.

    Returns:
        A populated parallax.Config instance.
    """
    ckpt_config = parallax.CheckPointConfig(
        ckpt_dir=FLAGS.ckpt_dir,
        save_ckpt_steps=calculate_ckpt_steps())
    ps_config = parallax.PSConfig(
        replicate_variables=FLAGS.replicate_variables,
        protocol=FLAGS.protocol,
        local_aggregation=FLAGS.local_aggregation,
        boundary_among_servers=FLAGS.boundary_among_servers,
        boundary_between_workers_and_servers=
            FLAGS.boundary_between_workers_and_servers)
    mpi_config = parallax.MPIConfig(
        use_allgatherv=FLAGS.use_allgatherv,
        mpirun_options=FLAGS.mpirun_options)

    parallax_config = parallax.Config()
    parallax_config.run_option = FLAGS.run_option
    parallax_config.average_sparse = False
    parallax_config.communication_config = parallax.CommunicationConfig(
        ps_config, mpi_config)
    parallax_config.ckpt_config = ckpt_config

    def get_profile_steps():
        """Parse the comma-separated profile-steps flag into a list of ints.

        Returns [] when the flag is unset.  FIX: strip into a local
        variable instead of writing the stripped value back into FLAGS
        (the original mutated global flag state as a side effect).
        """
        if not FLAGS.profile_steps:
            return []
        steps = FLAGS.profile_steps.strip()
        return [int(step) for step in steps.split(',')]

    profile_config = parallax.ProfileConfig(
        profile_dir=FLAGS.profile_dir,
        profile_steps=get_profile_steps())
    parallax_config.profile_config = profile_config
    parallax_config.redirect_path = FLAGS.redirect_path
    return parallax_config
def build_config():
    """Assemble and return a parallax.Config built from the FLAGS.

    Combines the checkpoint, parameter-server, and MPI sub-configurations
    into a single parallax.Config object.
    """
    checkpoint = parallax.CheckPointConfig(
        ckpt_dir=FLAGS.ckpt_dir,
        save_ckpt_steps=calculate_ckpt_steps())
    ps = parallax.PSConfig(
        replicate_variables=FLAGS.replicate_variables,
        protocol=FLAGS.protocol)
    mpi = parallax.MPIConfig(
        use_allgatherv=FLAGS.use_allgatherv,
        mpirun_options=FLAGS.mpirun_options)

    config = parallax.Config()
    config.run_option = FLAGS.run_option
    config.average_sparse = False
    config.communication_config = parallax.CommunicationConfig(ps, mpi)
    config.ckpt_config = checkpoint
    config.redirect_path = FLAGS.redirect_path
    return config
def build_config():
    """Build the full parallax.Config (checkpoint, PS, MPI, profiling).

    Returns:
        A populated parallax.Config instance built from the FLAGS,
        including profile step/range settings and the export-graph path.
    """
    ckpt_config = parallax.CheckPointConfig(
        ckpt_dir=FLAGS.ckpt_dir,
        save_ckpt_steps=calculate_ckpt_steps())
    ps_config = parallax.PSConfig(
        replicate_variables=FLAGS.replicate_variables,
        protocol=FLAGS.protocol,
        local_aggregation=FLAGS.local_aggregation,
        boundary_among_servers=FLAGS.boundary_among_servers,
        boundary_between_workers_and_servers=
            FLAGS.boundary_between_workers_and_servers)
    mpi_config = parallax.MPIConfig(mpirun_options=FLAGS.mpirun_options)

    def get_profile_steps():
        """Parse the comma-separated profile-steps flag into a list of ints.

        Returns None when the flag is unset.  FIX: strip into a local
        instead of writing the stripped value back into FLAGS (the
        original mutated global flag state as a side effect).
        """
        if not FLAGS.profile_steps:
            return None
        steps = FLAGS.profile_steps.strip()
        return [int(step) for step in steps.split(',')]

    def get_profile_range():
        """Parse the 'start,end' profile-range flag into an int tuple.

        Returns None when the flag is unset.  Same FIX as above: no
        in-place mutation of FLAGS.profile_range.
        """
        if not FLAGS.profile_range:
            return None
        splits = FLAGS.profile_range.strip().split(',')
        return (int(splits[0]), int(splits[1]))

    profile_config = parallax.ProfileConfig(
        profile_dir=FLAGS.profile_dir,
        profile_steps=get_profile_steps(),
        profile_range=get_profile_range(),
        profile_worker=FLAGS.profile_worker)

    parallax_config = parallax.Config()
    parallax_config.run_option = FLAGS.run_option
    parallax_config.average_sparse = False
    parallax_config.communication_config = parallax.CommunicationConfig(
        ps_config, mpi_config)
    parallax_config.ckpt_config = ckpt_config
    parallax_config.profile_config = profile_config
    parallax_config.redirect_path = FLAGS.redirect_path
    parallax_config.export_graph_path = FLAGS.export_graph_path
    return parallax_config
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) # Build single-GPU rnn model single_gpu_graph = tf.Graph() with single_gpu_graph.as_default(): ops = rnn() train_op = ops['train_op'] loss = ops['loss'] acc = ops['acc'] x = ops['images'] y = ops['labels'] is_training = ops['is_training'] parallax_config = parallax.Config() ckpt_config = parallax.CheckPointConfig(ckpt_dir='parallax_ckpt', save_ckpt_steps=1) parallax_config.ckpt_config = ckpt_config sess, num_workers, worker_id, num_replicas_per_worker = parallax.parallel_run( single_gpu_graph, FLAGS.resource_info_file, sync=FLAGS.sync, parallax_config=parallax_config) start = time.time() for i in range(FLAGS.max_steps): batch = mnist.train.next_batch(FLAGS.batch_size, shuffle=False) _, loss_ = sess.run([train_op, loss], feed_dict={x: [batch[0]], y: [batch[1]], is_training: [True]}) if i % FLAGS.log_frequency == 0:
"""Number of iterations to run for each workers.""") tf.app.flags.DEFINE_integer('log_frequency', 50, """How many steps between two logs.""") tf.app.flags.DEFINE_integer('batch_size', 32, """Batch size""") tf.app.flags.DEFINE_boolean('sync', True, '') mnist = input_data.read_data_sets('MNIST_data', one_hot=True) # Build single-GPU LeNet model single_gpu_graph = tf.Graph() with single_gpu_graph.as_default(): model = LeNet() parallax_config = parallax.Config() ckpt_config = parallax.CheckPointConfig(ckpt_dir='ckpt', save_ckpt_steps=FLAGS.log_frequency) parallax_config.ckpt_config = ckpt_config sess, num_workers, worker_id, num_replicas_per_worker = parallax.parallel_run( single_gpu_graph, FLAGS.resource_info_file, sync=FLAGS.sync, parallax_config=parallax_config) start = time.time() for i in range(FLAGS.max_steps): batch = mnist.train.next_batch(FLAGS.batch_size, shuffle=False) _, loss = sess.run([model.train_op, model.loss], feed_dict={model.images: [batch[0]], model.labels: [batch[1]], model.is_training: [True]}) if i % FLAGS.log_frequency == 0:
# Shard the training dataset so each worker reads a distinct subset.
# BUG FIX: the original assigned the sharded dataset to a misspelled
# variable ('train_datsets') and then built the iterator from the
# un-sharded 'train_datasets', so the shard was silently discarded and
# every worker processed the full dataset.
train_datasets = parallax.shard.shard(train_datasets)
iterator = train_datasets.make_one_shot_iterator()
inputs, labels = iterator.get_next()

single_worker_model = build_and_compile_cnn_model(forward_only=True)
logits = single_worker_model(inputs, training=True)
accuracy = tf.metrics.accuracy(
    labels=labels, predictions=tf.argmax(logits, axis=1))[1]
loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
step = tf.train.get_or_create_global_step()
train_op = optimizer.minimize(loss, step)
loss = tf.reduce_mean(loss)

parallax_config = parallax.Config()
parallax_config.run_option = FLAGS.run_option
parallax_config.ckpt_config = parallax.CheckPointConfig(
    ckpt_dir=FLAGS.ckpt_dir,
    save_ckpt_steps=FLAGS.save_ckpt_steps)

sess, num_workers, worker_id, num_replicas_per_worker = \
    parallax.parallel_run(single_gpu_graph,
                          FLAGS.resource_info_file,
                          parallax_config=parallax_config)

# BUG FIX: range() requires an int; '/' yields a float in Python 3,
# which raises TypeError. Use floor division instead.
for i in range(NUM_EPOCHS * STEPS_PER_EPOCH // NUM_WORKERS):
    step_, loss_, accuracy_, _ = sess.run([step, loss, accuracy, train_op])
    if i % 10 == 0:
        print('step:%d, loss: %2f, accuracy: %2f'
              % (step_[0], loss_[0], accuracy_[0]))
print('작업이 깔끔하게 끝났습니다.')