예제 #1
0
def main(_):
    """Build the single-GPU skip-thoughts training graph and run it via Parallax.

    The graph (model, learning-rate schedule, Adam optimizer and slim train
    op) is built once inside ``single_gpu_graph``. The nested ``run``
    callback and the graph are then handed to ``parallax.parallel_run``.

    Args:
        _: Unused positional argument.
    """
    with tf.Graph().as_default() as single_gpu_graph:
        model_config = configuration.model_config(
            input_file_pattern=FLAGS.input_file_pattern,
            batch_size=FLAGS.batch_size)
        training_config = configuration.training_config()
        model = skip_thoughts_model.SkipThoughtsModel(model_config,
                                                      mode="train")
        model.build()

        # Learning rate: exponential decay when a positive decay factor is
        # configured, otherwise a constant.
        if training_config.learning_rate_decay_factor > 0:
            learning_rate = tf.train.exponential_decay(
                learning_rate=float(training_config.learning_rate),
                global_step=model.global_step,
                decay_steps=training_config.learning_rate_decay_steps,
                decay_rate=training_config.learning_rate_decay_factor,
                staircase=False)
        else:
            learning_rate = tf.constant(training_config.learning_rate)

        optimizer = tf.train.AdamOptimizer(learning_rate)

        train_tensor = tf.contrib.slim.learning.create_train_op(
            total_loss=model.total_loss,
            optimizer=optimizer,
            global_step=model.global_step,
            clip_gradient_norm=training_config.clip_gradient_norm)

    def run(sess, num_iters, tensor_or_op_name_to_replica_names, num_workers,
            worker_id, num_replicas_per_worker):
        """Run ``num_iters`` training steps and periodically log progress.

        Args:
            sess: Session object whose ``run`` executes the fetches.
            num_iters: Number of ``sess.run`` calls to make.
            tensor_or_op_name_to_replica_names: Mapping keyed by the name of
                a tensor/op from the single-GPU graph; entry ``[0]`` of each
                value is what gets fetched here.
            num_workers: Unused in this callback.
            worker_id: Unused in this callback.
            num_replicas_per_worker: Unused in this callback.
        """
        fetches = {
            'global_step':
            tensor_or_op_name_to_replica_names[model.global_step.name][0],
            'cost':
            tensor_or_op_name_to_replica_names[model.total_loss.name][0],
            'train_op':
            tensor_or_op_name_to_replica_names[train_tensor.name][0],
        }

        start = time.time()
        # Steps executed since the last log line. The first log line (i == 0)
        # covers a single step, so dividing by FLAGS.log_frequency there
        # would overstate the throughput.
        steps_since_log = 0
        for i in range(num_iters):
            results = sess.run(fetches)
            steps_since_log += 1
            if i % FLAGS.log_frequency == 0:
                end = time.time()
                throughput = float(steps_since_log) / float(end - start)
                parallax.log.info(
                    "global step: %d, loss: %f, throughput: %f steps/sec" %
                    (results['global_step'], results['cost'], throughput))
                start = time.time()
                steps_since_log = 0

    parallax.parallel_run(single_gpu_graph,
                          run,
                          FLAGS.resource_info_file,
                          FLAGS.max_steps,
                          sync=FLAGS.sync,
                          parallax_config=parallax_config.build_config())
def main(_):
    """Build the benchmark_cnn model on a single-GPU graph and run it via Parallax.

    Parameters come from flags through ``benchmark_cnn``; the session config
    returned by ``benchmark_cnn.setup`` is attached to the Parallax config
    before ``parallax.parallel_run`` is called with the nested ``run``
    callback.

    Args:
        _: Unused positional argument.
    """
    # Build benchmark_cnn model
    params = benchmark_cnn.make_params_from_flags()
    params, sess_config = benchmark_cnn.setup(params)
    bench = benchmark_cnn.BenchmarkCNN(params)

    # Print information
    tfversion = cnn_util.tensorflow_version_tuple()
    log_fn('TensorFlow:  %i.%i' % (tfversion[0], tfversion[1]))
    bench.print_info()

    # Build single-GPU benchmark_cnn model
    with tf.Graph().as_default() as single_gpu_graph:
        bench.build_model()

    def run(sess, num_iters, tensor_or_op_name_to_replica_names, num_workers,
            worker_id, num_replicas_per_worker):
        """Run ``num_iters`` training steps and periodically log progress.

        Args:
            sess: Session object whose ``run`` executes the fetches.
            num_iters: Number of ``sess.run`` calls to make.
            tensor_or_op_name_to_replica_names: Mapping keyed by the name of
                a tensor/op from the single-GPU graph; entry ``[0]`` of each
                value is what gets fetched here.
            num_workers: Unused in this callback.
            worker_id: Unused in this callback.
            num_replicas_per_worker: Unused in this callback.
        """
        fetches = {
            'global_step':
            tensor_or_op_name_to_replica_names[bench.global_step.name][0],
            'cost':
            tensor_or_op_name_to_replica_names[bench.cost.name][0],
            'train_op':
            tensor_or_op_name_to_replica_names[bench.train_op.name][0],
        }
        # The learning rate is fetched only when it is a tensor; otherwise
        # bench.lr itself is printed in the log line below.
        if isinstance(bench.lr, tf.Tensor):
            fetches['lr'] = tensor_or_op_name_to_replica_names[
                bench.lr.name][0]

        start = time.time()
        # Steps executed since the last log line. The first log line (i == 0)
        # covers a single step, so dividing by FLAGS.log_frequency there
        # would overstate the throughput.
        steps_since_log = 0
        for i in range(num_iters):
            results = sess.run(fetches)
            steps_since_log += 1
            if i % FLAGS.log_frequency == 0:
                end = time.time()
                throughput = float(steps_since_log) / float(end - start)
                parallax.log.info(
                    "global step: %d, lr: %f, loss: %f, "
                    "throughput: %f steps/sec" %
                    (results['global_step'], results['lr'] if 'lr' in results
                     else bench.lr, results['cost'], throughput))
                start = time.time()
                steps_since_log = 0

    config = parallax_config.build_config()
    config.sess_config = sess_config

    parallax.parallel_run(single_gpu_graph,
                          run,
                          FLAGS.resource_info_file,
                          FLAGS.max_steps,
                          sync=FLAGS.sync,
                          parallax_config=config)
예제 #3
0
def main(_):
    """Fit a two-input linear model with Parallax, feeding each replica by hand.

    The single-GPU graph holds two placeholders, a weight and a bias
    variable, a half-sum-of-squares loss and a gradient-descent train op.
    The nested ``run`` callback is passed to ``parallax.parallel_run``
    together with the graph, a ``resource_info`` file located next to this
    script, and an iteration count of 1000.

    Args:
        _: Unused positional argument.
    """
    with tf.Graph().as_default() as single_gpu_graph:
        global_step = tf.train.get_or_create_global_step()
        x = tf.placeholder(tf.float32, shape=(2))
        y = tf.placeholder(tf.float32, shape=(1))

        weights = tf.get_variable(name='w', shape=(2, 1))
        bias = tf.get_variable(name='b', shape=(1))

        # Lift x to a one-row matrix so it can be multiplied by the weights.
        x_row = tf.expand_dims(x, axis=0)
        pred = tf.nn.bias_add(tf.matmul(x_row, weights), bias)

        # Half of the summed squared error against the (expanded) target.
        y_row = tf.expand_dims(y, axis=0)
        loss = tf.reduce_sum(tf.pow(pred - y_row, 2)) / 2

        sgd = tf.train.GradientDescentOptimizer(args.learning_rate)
        train_op = sgd.minimize(loss, global_step=global_step)

    def run(sess, num_iters, op_name_to_replica_op_names, num_workers,
            worker_id, num_replicas_per_worker):
        """Run ``num_iters`` steps, feeding one sample to every replica.

        Args:
            sess: Session object whose ``run`` executes the fetches.
            num_iters: Number of training steps to run.
            op_name_to_replica_op_names: Mapping keyed by single-GPU
                tensor/op name; each value is indexed by replica number.
            num_workers: Unused in this callback.
            worker_id: Unused in this callback.
            num_replicas_per_worker: How many replicas need a feed per step.
        """

        def build_feed(step):
            # Samples are consumed round-robin: replica r of step s receives
            # sample (s * num_replicas_per_worker + r) modulo num_samples.
            feed = {}
            first = step * num_replicas_per_worker
            for replica in range(num_replicas_per_worker):
                sample = (first + replica) % num_samples
                feed[op_name_to_replica_op_names[x.name][replica]] = \
                    train_x[sample]
                feed[op_name_to_replica_op_names[y.name][replica]] = \
                    train_y[sample]
            return feed

        for step in range(num_iters):
            feed_dict = build_feed(step)
            results = sess.run(
                {
                    'global_step':
                    op_name_to_replica_op_names[global_step.name][0],
                    'loss': loss.name,
                    'train_op': train_op.name
                },
                feed_dict=feed_dict)

            # Report every fifth step.
            if step % 5 == 0:
                print("global step: %d, loss: %f" %
                      (results['global_step'], results['loss']))

    script_dir = os.path.dirname(os.path.abspath(__file__))
    resource_info = os.path.join(script_dir, 'resource_info')
    parallax.parallel_run(single_gpu_graph, run, resource_info, 1000)
예제 #4
0
def train(hps):
    """Run the CIFAR-10 ResNet training loop through a Parallax session.

    Args:
        hps: Hyperparameters passed to ``resnet_model.ResNet``;
            ``hps.batch_size`` is also passed to the input builder.
    """
    # (exclusive upper bound on the global step, learning rate) pairs; the
    # final rate applies once every bound has been passed.
    lr_schedule = ((10000, 0.1), (15000, 0.01), (20000, 0.001))
    final_lr = 0.0001

    single_gpu_graph = tf.Graph()
    with single_gpu_graph.as_default():
        images, labels = cifar_input.build_input(
            'cifar10', FLAGS.train_data_path, hps.batch_size, 'train')
        model = resnet_model.ResNet(hps, images, labels, 'train')
        model.build_graph()

        # Fraction of examples whose arg-max prediction equals the arg-max
        # label.
        truth = tf.argmax(model.labels, axis=1)
        predictions = tf.argmax(model.predictions, axis=1)
        hits = tf.equal(predictions, truth)
        precision = tf.reduce_mean(tf.to_float(hits))

    # Obtain the session for the distributed environment from Parallax,
    # passing the Parallax config as an argument.
    parallax_sess, num_workers, worker_id, num_replicas_per_worker = \
        parallax.parallel_run(
            single_gpu_graph,
            FLAGS.resource_info_file,
            sync=FLAGS.sync,
            parallax_config=parallax_config.build_config())

    train_fetches = [model.train_op, model.global_step, model.cost, precision]
    for i in range(350000):
        _, global_step, cost, precision_ = parallax_sess.run(train_fetches)

        if i % 10 != 0:
            continue

        # Fetched values are indexed with [0] before printing.
        print('step: %d, loss: %.3f, precision: %.3f' %
              (global_step[0], cost[0], precision_[0]))

        # Pick the learning rate for the current global step.
        train_step = global_step[0]
        lrn_rate = next(
            (rate for bound, rate in lr_schedule if train_step < bound),
            final_lr)

        # One learning-rate value per replica on this worker.
        # NOTE(review): the rate is fed only to this extra global_step run;
        # whether it influences later train_op runs depends on how
        # model.lrn_rate is defined - confirm.
        feed_dict = {model.lrn_rate: [lrn_rate] * num_replicas_per_worker}
        parallax_sess.run(model.global_step, feed_dict=feed_dict)
예제 #5
0
def main(_):
    """Fit a two-input linear model using a session returned by Parallax.

    The single-GPU graph holds two placeholders, a weight and a bias
    variable, a half-sum-of-squares loss and a gradient-descent train op.
    ``parallax.parallel_run`` returns the session plus worker/replica
    information, which are then handed to the nested ``run`` loop.

    Args:
        _: Unused positional argument.
    """
    single_gpu_graph = tf.Graph()
    with single_gpu_graph.as_default():
        global_step = tf.train.get_or_create_global_step()
        x = tf.placeholder(tf.float32, shape=(2))
        y = tf.placeholder(tf.float32, shape=(1))

        weights = tf.get_variable(name='w', shape=(2, 1))
        bias = tf.get_variable(name='b', shape=(1))

        # Lift x to a one-row matrix so it can be multiplied by the weights.
        x_row = tf.expand_dims(x, axis=0)
        pred = tf.nn.bias_add(tf.matmul(x_row, weights), bias)

        # Half of the summed squared error against the (expanded) target.
        y_row = tf.expand_dims(y, axis=0)
        loss = tf.reduce_sum(tf.pow(pred - y_row, 2)) / 2

        sgd = tf.train.GradientDescentOptimizer(args.learning_rate)
        train_op = sgd.minimize(loss, global_step=global_step)

    def run(sess, num_workers, worker_id, num_replicas_per_worker):
        """Run 1000 steps, feeding one sample per replica to each placeholder.

        Args:
            sess: Session object whose ``run`` executes the fetches.
            num_workers: Unused in this loop.
            worker_id: Unused in this loop.
            num_replicas_per_worker: Length of the list fed to each
                placeholder.
        """
        fetches = {
            'global_step': global_step,
            'loss': loss,
            'train_op': train_op
        }

        for step in range(1000):
            # Samples are consumed round-robin, num_replicas_per_worker per
            # step.
            first = step * num_replicas_per_worker
            picks = [(first + offset) % num_samples
                     for offset in range(num_replicas_per_worker)]
            feed_dict = {
                x: [train_x[k] for k in picks],
                y: [train_y[k] for k in picks],
            }

            results = sess.run(fetches, feed_dict=feed_dict)

            # Report every fifth step; fetched values are indexed with [0].
            if step % 5 == 0:
                print("global step: %d, loss: %f" %
                      (results['global_step'][0], results['loss'][0]))

    script_dir = os.path.dirname(os.path.abspath(__file__))
    resource_info = os.path.join(script_dir, 'resource_info')
    sess, num_workers, worker_id, num_replicas_per_worker = \
        parallax.parallel_run(single_gpu_graph, resource_info)
    run(sess, num_workers, worker_id, num_replicas_per_worker)
def main(_):
    """Train the benchmark_cnn model through a session returned by Parallax.

    Parameters come from flags through ``benchmark_cnn``; the session config
    returned by ``benchmark_cnn.setup`` is attached to the Parallax config.
    The loop runs ``FLAGS.max_steps`` steps and logs loss and throughput
    after every ``FLAGS.log_frequency`` completed steps.

    Args:
        _: Unused positional argument.
    """
    # Model construction from command-line flags.
    params = benchmark_cnn.make_params_from_flags()
    params, sess_config = benchmark_cnn.setup(params)
    bench = benchmark_cnn.BenchmarkCNN(params)

    # Report the TensorFlow version and the benchmark settings.
    version = cnn_util.tensorflow_version_tuple()
    major, minor = version[0], version[1]
    log_fn('TensorFlow:  %i.%i' % (major, minor))
    bench.print_info()

    # The model is built on a dedicated single-GPU graph.
    single_gpu_graph = tf.Graph()
    with single_gpu_graph.as_default():
        bench.build_model()

    config = parallax_config.build_config()
    config.sess_config = sess_config

    sess, num_workers, worker_id, num_replicas_per_worker = \
        parallax.parallel_run(
            single_gpu_graph,
            FLAGS.resource_info_file,
            sync=FLAGS.sync,
            parallax_config=config)

    fetches = dict(global_step=bench.global_step,
                   cost=bench.cost,
                   train_op=bench.train_op)

    window_start = time.time()
    for completed in range(1, FLAGS.max_steps + 1):
        results = sess.run(fetches)
        if completed % FLAGS.log_frequency != 0:
            continue

        # A full window of log_frequency steps has finished.
        window_end = time.time()
        throughput = float(FLAGS.log_frequency) / float(
            window_end - window_start)
        parallax.log.info(
            "global step: %d, loss: %f, throughput: %f steps/sec" %
            (results['global_step'][0] + 1, results['cost'][0],
             throughput))
        window_start = time.time()
예제 #7
0
    # Build the model ops with rnn() and pull the individual entries out of
    # the returned mapping by key. (The enclosing block header is not part of
    # this excerpt.)
    ops = rnn()
    train_op = ops['train_op']
    loss = ops['loss']
    acc = ops['acc']
    x = ops['images']
    y = ops['labels']
    is_training = ops['is_training']

# Parallax configuration with checkpointing into 'parallax_ckpt'
# (save_ckpt_steps=1).
parallax_config = parallax.Config()
ckpt_config = parallax.CheckPointConfig(ckpt_dir='parallax_ckpt',
                                        save_ckpt_steps=1)
parallax_config.ckpt_config = ckpt_config

# parallel_run returns the session plus worker/replica information.
sess, num_workers, worker_id, num_replicas_per_worker = parallax.parallel_run(
    single_gpu_graph,
    FLAGS.resource_info_file,
    sync=FLAGS.sync,
    parallax_config=parallax_config)

start = time.time()
for i in range(FLAGS.max_steps):
  # One training step on the next (unshuffled) MNIST batch. Every feed value
  # is wrapped in a one-element list.
  batch = mnist.train.next_batch(FLAGS.batch_size, shuffle=False)
  _, loss_ = sess.run([train_op, loss], feed_dict={x: [batch[0]],
                                                   y: [batch[1]],
                                                   is_training: [True]})
  if i % FLAGS.log_frequency == 0:
    end = time.time()
    # NOTE(review): at i == 0 only one step has run, so this overstates the
    # throughput for the first window; `start` is also never reset in the
    # visible code - confirm against the full script.
    throughput = float(FLAGS.log_frequency) / float(end - start)
    # Accuracy on the full MNIST test set with is_training fed as False.
    # NOTE(review): throughput and acc_ are not used in the visible lines;
    # the excerpt appears to end here.
    acc_ = sess.run(acc, feed_dict={x: [mnist.test.images],
                                    y: [mnist.test.labels],
                                    is_training: [False]})[0]
예제 #8
0
def main(_):
    """Fine-tune a BERT classifier on a GLUE-style task through Parallax.

    Validates the flags, builds the task processor and tokenizer, writes the
    training examples to a TFRecord file, builds the model and optimizer on
    ``single_gpu_graph``, obtains a session from ``parallax.parallel_run``
    and then calls the nested ``run`` training loop.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If none of ``do_train``/``do_eval``/``do_predict`` is
            set, if ``max_seq_length`` exceeds the model's
            ``max_position_embeddings``, or if the task name is unknown.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    # Supported task names mapped to their data processor classes.
    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True."
        )

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    # NOTE(review): the TPU resolver and run_config below are built but not
    # referenced again in this function (the Estimator code that consumed
    # them is disabled further down) - confirm they are still needed.
    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    # Training examples and the derived step counts are only populated when
    # do_train is set.
    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size *
            FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
    # The two bare string literals below hold disabled Estimator-based code;
    # they are evaluated as no-op expressions.
    """
  model_fn = model_fn_builder(
      bert_config=bert_config,
      num_labels=len(label_list),
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)
  """

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    """
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)
  """

    # Serialize the training examples to a TFRecord file in the output dir.
    # NOTE(review): train_examples is None when FLAGS.do_train is false, so
    # the conversion and len() calls below would fail in that case - confirm
    # this script is only ever run with do_train set.
    train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
    file_based_convert_examples_to_features(train_examples, label_list,
                                            FLAGS.max_seq_length, tokenizer,
                                            train_file)
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num examples = %d", len(train_examples))
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)

    # Build the input pipeline, model and optimizer on the single-GPU graph
    # that is handed to Parallax.
    single_gpu_graph = tf.Graph()
    with single_gpu_graph.as_default():
        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)

        dataset = train_input_fn(FLAGS.train_batch_size)
        features = dataset.make_one_shot_iterator().get_next()

        model_fn = model_fn_builder(bert_config=bert_config,
                                    num_labels=len(label_list),
                                    init_checkpoint=FLAGS.init_checkpoint,
                                    learning_rate=FLAGS.learning_rate,
                                    num_train_steps=num_train_steps,
                                    num_warmup_steps=num_warmup_steps,
                                    use_tpu=FLAGS.use_tpu,
                                    use_one_hot_embeddings=FLAGS.use_tpu)

        total_loss = model_fn(features)
        train_op, _ = optimization.create_optimizer(total_loss,
                                                    FLAGS.learning_rate,
                                                    num_train_steps,
                                                    num_warmup_steps,
                                                    FLAGS.use_tpu)

    def run(sess, num_workers, worker_id, num_replicas_per_worker):
        """Run ``num_train_steps`` training steps, printing progress each step.

        Args:
            sess: Session object whose ``run`` executes the fetches.
            num_workers: Worker count; printed and passed to the input fn.
            worker_id: This worker's id; printed and passed to the input fn.
            num_replicas_per_worker: Multiplier applied to the train batch
                size for the dataset built here.
        """
        print('num_workers: {} | worker_id: {}'.format(num_workers, worker_id))
        sys.stdout.flush()
        # NOTE(review): this rebuilds the dataset, model and optimizer after
        # parallel_run has already returned a session, and outside the
        # `with single_gpu_graph.as_default()` block above - confirm the
        # session can actually run these newly created ops.
        # NOTE(review): train_input_fn is called with one argument above and
        # three arguments here - verify against its definition.
        dataset = train_input_fn(
            FLAGS.train_batch_size * num_replicas_per_worker, num_workers,
            worker_id)
        features = dataset.make_one_shot_iterator().get_next()

        total_loss = model_fn(features)
        train_op, global_step = optimization.create_optimizer(
            total_loss, FLAGS.learning_rate, num_train_steps, num_warmup_steps,
            FLAGS.use_tpu)

        # Last fetched global step, printed before and after each run call.
        _global_step = 0
        for i in range(num_train_steps):
            print("[before session run - worker:{} | global step:{}]".format(
                worker_id, _global_step))
            sys.stdout.flush()
            loss, _global_step, _ = sess.run(
                [total_loss, global_step, train_op])
            print("[after session run - worker:{} | global step:{} - loss:{}]".
                  format(worker_id, _global_step, loss))
            sys.stdout.flush()

            # Additional summary line every 100 local iterations.
            if i % 100 == 0:
                print('[worker:{} | global step:{} - loss:{}]'.format(
                    worker_id, _global_step, loss))


    # Obtain the session from Parallax (always synchronous here) and train.
    sess, num_workers, worker_id, num_replicas_per_worker = \
        parallax.parallel_run(single_gpu_graph,
                              FLAGS.resource_info_file,
                              sync=True,
                              parallax_config=parallax_config.build_config())
    run(sess, num_workers, worker_id, num_replicas_per_worker)
예제 #9
0
def main(_):
    """Train an NMT model through a session returned by Parallax.

    Loads the hyperparameters, picks the model class from the attention
    settings, builds the training model, obtains a session from
    ``parallax.parallel_run`` on ``train_model.graph`` and runs the nested
    ``run`` training loop.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If attention is enabled and
            ``hparams.attention_architecture`` is not a handled value.
    """
    default_hparams = nmt.create_hparams(FLAGS)
    ## Train / Decode
    out_dir = FLAGS.out_dir
    if not tf.gfile.Exists(out_dir):
        tf.gfile.MakeDirs(out_dir)

    # Load hparams.
    hparams = nmt.create_or_load_hparams(out_dir,
                                         default_hparams,
                                         FLAGS.hparams_path,
                                         save_hparams=False)

    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    steps_per_stats = hparams.steps_per_stats

    # Choose the model class from the attention settings.
    if not hparams.attention:
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model =\
        model_helper.create_train_model(model_creator, hparams, scope=None)

    # NOTE(review): config_proto is built but not used anywhere in this
    # function - confirm whether it should be attached to the Parallax
    # config.
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=1,
        num_inter_threads=36)

    def run(train_sess, num_workers, worker_id, num_replicas_per_worker):
        """Run up to ``FLAGS.max_steps`` training steps.

        Worker 0 fetches the full set of statistics each step and logs them;
        other workers only fetch the global step and the update op.

        Args:
            train_sess: Session object whose ``run`` executes the fetches.
            num_workers: Worker count, used to derive steps per epoch.
            worker_id: This worker's id; also offsets the random seed.
            num_replicas_per_worker: Replica count, used to derive steps per
                epoch and to size the skip-count feed.
        """
        # Random
        random_seed = FLAGS.random_seed
        if random_seed is not None and random_seed > 0:
            utils.print_out("# Set random seed to %d" % random_seed)
            random.seed(random_seed + worker_id)
            np.random.seed(random_seed + worker_id)

        # Log and output files
        log_file = os.path.join(out_dir, "log_%d" % time.time())
        log_f = tf.gfile.GFile(log_file, mode="a")
        # The log file is closed on every exit path (normal completion, the
        # overflow break, or an exception) instead of being leaked.
        try:
            utils.print_out("# log_file=%s" % log_file, log_f)

            global_step = train_sess.run(train_model.model.global_step)[0]
            last_stats_step = global_step

            # This is the training loop.
            stats, info, start_train_time = before_train(
                train_model, train_sess, global_step, hparams, log_f,
                num_replicas_per_worker)

            # NOTE(review): `/` is true division on Python 3, so epoch_steps
            # may be a float and the modulo test below may never hit exactly
            # zero - confirm the intended Python version.
            epoch_steps = FLAGS.epoch_size / (FLAGS.batch_size * num_workers *
                                              num_replicas_per_worker)

            for _ in range(FLAGS.max_steps):
                ### Run a step ###
                start_time = time.time()
                if (hparams.epoch_step != 0
                        and hparams.epoch_step % epoch_steps == 0):
                    # Epoch boundary: re-initialize the input iterator with
                    # a skip count of zero for every replica.
                    hparams.epoch_step = 0
                    skip_count = train_model.skip_count_placeholder
                    feed_dict = {skip_count: [0] * num_replicas_per_worker}
                    init = train_model.iterator.initializer
                    train_sess.run(init, feed_dict=feed_dict)

                if worker_id == 0:
                    results = train_sess.run([
                        train_model.model.update,
                        train_model.model.train_loss,
                        train_model.model.predict_count,
                        train_model.model.train_summary,
                        train_model.model.global_step,
                        train_model.model.word_count,
                        train_model.model.batch_size,
                        train_model.model.grad_norm,
                        train_model.model.learning_rate
                    ])
                    # Keep element [0] of every fetched value.
                    step_result = [r[0] for r in results]

                else:
                    global_step, _ = train_sess.run([
                        train_model.model.global_step,
                        train_model.model.update
                    ])
                hparams.epoch_step += 1

                if worker_id == 0:
                    # Process step_result, accumulate stats, and write summary
                    global_step, info["learning_rate"], step_summary = \
                        train.update_stats(stats, start_time, step_result)

                    # Once in a while, we print statistics.
                    if global_step - last_stats_step >= steps_per_stats:
                        last_stats_step = global_step
                        is_overflow = train.process_stats(
                            stats, info, global_step, steps_per_stats, log_f)
                        train.print_step_info(
                            "  ", global_step, info,
                            train._get_best_results(hparams), log_f)
                        if is_overflow:
                            break

                        # Reset statistics
                        stats = train.init_stats()
        finally:
            log_f.close()

    sess, num_workers, worker_id, num_replicas_per_worker = \
        parallax.parallel_run(train_model.graph,
                              FLAGS.resource_info_file,
                              sync=FLAGS.sync,
                              parallax_config=parallax_config.build_config())
    run(sess, num_workers, worker_id, num_replicas_per_worker)
예제 #10
0
  # Sharding must be applied to the dataset. The result has to be bound back
  # to `train_datasets` (the original assigned it to the misspelled name
  # `train_datsets`, so the sharded dataset was never used).
  train_datasets = parallax.shard.shard(train_datasets)
  iterator = train_datasets.make_one_shot_iterator()
  inputs, labels = iterator.get_next()
  single_worker_model = build_and_compile_cnn_model(forward_only=True)

  logits = single_worker_model(inputs, training=True)
  # tf.metrics.accuracy returns a pair; element [1] is kept here.
  accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))[1]
  loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
  step = tf.train.get_or_create_global_step()
  # The train op minimizes the un-reduced loss; the mean taken afterwards is
  # what the training loop fetches for reporting.
  train_op = optimizer.minimize(loss, step)
  loss = tf.reduce_mean(loss)

# Parallax configuration: run option and checkpoint settings come from FLAGS.
parallax_config = parallax.Config()
parallax_config.run_option = FLAGS.run_option
parallax_config.ckpt_config = parallax.CheckPointConfig(ckpt_dir=FLAGS.ckpt_dir,
                                  save_ckpt_steps=FLAGS.save_ckpt_steps)

# parallel_run returns the session plus worker/replica information.
sess, num_workers, worker_id, num_replicas_per_worker = \
        parallax.parallel_run(single_gpu_graph,
                              FLAGS.resource_info_file,
                              parallax_config=parallax_config)

# Floor division: range() requires an int, and `/` yields a float on
# Python 3 (TypeError). For ints on Python 2 the result is unchanged.
for i in range(NUM_EPOCHS * STEPS_PER_EPOCH // NUM_WORKERS):
  step_, loss_, accuracy_, _ = sess.run([step, loss, accuracy, train_op])
  # Report every tenth iteration; fetched values are indexed with [0].
  if i % 10 == 0:
    print('step:%d, loss: %2f, accuracy: %2f' % (step_[0], loss_[0], accuracy_[0]))
print('작업이 깔끔하게 끝났습니다.')

예제 #11
0
def main(_):
    """Train the language model through a session returned by Parallax.

    Loads the vocabulary and dataset from ``FLAGS.datadir``, builds the model
    on a single-GPU graph under the ``"model"`` variable scope, obtains a
    session from ``parallax.parallel_run`` and runs the nested ``run`` loop.

    Args:
        _: Unused positional argument.
    """
    vocab_path = os.path.join(FLAGS.datadir, "1b_word_vocab.txt")
    vocab = Vocabulary.from_file(vocab_path)
    data_glob = os.path.join(FLAGS.datadir,
                             "training-monolingual.tokenized.shuffled/*")
    dataset = Dataset(vocab, data_glob)

    single_gpu_graph = tf.Graph()
    with single_gpu_graph.as_default():
        with tf.variable_scope("model"):
            model = language_model_graph.build_model()

    def run(sess, num_workers, worker_id, num_replicas_per_worker):
        """Run ``FLAGS.max_steps`` steps, carrying the recurrent state over.

        Args:
            sess: Session object whose ``run`` executes the fetches.
            num_workers: Worker count, passed to the dataset iterator.
            worker_id: This worker's id, passed to the dataset iterator.
            num_replicas_per_worker: Number of per-replica pieces each input
                batch is split into, and number of state arrays kept.
        """
        # One zero-initialized state array per replica.
        state_c = [
            np.zeros([FLAGS.batch_size, model.state_size], dtype=np.float32)
            for _ in range(num_replicas_per_worker)
        ]
        state_h = [
            np.zeros([FLAGS.batch_size, model.projected_size],
                     dtype=np.float32)
            for _ in range(num_replicas_per_worker)
        ]

        prev_global_step = sess.run(model.global_step)[0]
        prev_time = time.time()
        data_iterator = dataset.iterate_forever(
            FLAGS.batch_size * num_replicas_per_worker, FLAGS.num_steps,
            num_workers, worker_id)
        fetches = dict(global_step=model.global_step,
                       loss=model.loss,
                       train_op=model.train_op,
                       final_state_c=model.final_state_c,
                       final_state_h=model.final_state_h)

        for local_step in range(FLAGS.max_steps):
            if FLAGS.use_synthetic:
                # Random token ids with all-ones weights, sized for every
                # replica on this worker.
                batch_shape = (FLAGS.batch_size * num_replicas_per_worker,
                               FLAGS.num_steps)
                x = np.random.randint(low=0,
                                      high=model.vocab_size,
                                      size=batch_shape)
                y = np.random.randint(low=0,
                                      high=model.vocab_size,
                                      size=batch_shape)
                w = np.ones(batch_shape)
            else:
                x, y, w = next(data_iterator)

            # Each input is split into one piece per replica; the states
            # fetched on the previous step are fed back in.
            feeds = {
                model.x: np.split(x, num_replicas_per_worker),
                model.y: np.split(y, num_replicas_per_worker),
                model.w: np.split(w, num_replicas_per_worker),
                model.initial_state_c: state_c,
                model.initial_state_h: state_h,
            }
            fetched = sess.run(fetches, feeds)

            state_c = fetched['final_state_c']
            state_h = fetched['final_state_h']

            if local_step % FLAGS.log_frequency != 0:
                continue

            # Words per second over the global steps advanced since the
            # previous log line.
            cur_time = time.time()
            elapsed_time = cur_time - prev_time
            num_words = FLAGS.batch_size * FLAGS.num_steps
            current_step = fetched['global_step'][0]
            wps = (current_step -
                   prev_global_step) * num_words / elapsed_time
            prev_global_step = current_step
            parallax.log.info(
                "Iteration %d, time = %.2fs, wps = %.0f, train loss = %.4f"
                % (current_step, elapsed_time, wps, fetched['loss'][0]))
            prev_time = cur_time

    sess, num_workers, worker_id, num_replicas_per_worker = \
        parallax.parallel_run(
            single_gpu_graph,
            FLAGS.resource_info_file,
            sync=FLAGS.sync,
            parallax_config=parallax_config.build_config())
    run(sess, num_workers, worker_id, num_replicas_per_worker)