def memory_test(size):
  """Evaluates gradient, returns memory in MB's and gradient eval time in
  seconds."""
  global sess, RESNET_SIZE

  RESNET_SIZE = size
  
  tf.reset_default_graph()
  
  loss = create_loss()

  grads = tf.group(tf.gradients(loss, tf.trainable_variables()))

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  times = []
  memories = []
  for i in range(3):
    start_time = time.perf_counter()
    sessrun(grads)
    elapsed_time = time.perf_counter() - start_time
    times.append(elapsed_time)
    mem_use = mem_util.peak_memory(run_metadata)['/gpu:0']/1e6
    memories.append(mem_use)

  return np.min(memories), np.min(times)
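
# The examples in this file assume shared helpers along these lines. This is
# a minimal sketch (an assumption inferred from how sess, run_metadata and
# mem_util are used above, not the original definitions): sessrun() traces
# every step into a global RunMetadata so mem_util.peak_memory can read
# per-device peak allocations from it.
def create_session():
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  return tf.Session(config=config)

def sessrun(*args, **kwargs):
  global run_metadata
  run_metadata = tf.RunMetadata()
  run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  return sess.run(*args, options=run_options, run_metadata=run_metadata,
                  **kwargs)
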
def gradient_memory_measure_mb():
  """Evaluates gradient, prints peak memory in MBs."""
  global sess
  
  start_time0 = time.perf_counter()
  loss = create_loss()

  if DUMP_GRAPHDEF:
    with open('graphdef.txt', 'w') as f:
      f.write(str(tf.get_default_graph().as_graph_def()))

  # use block_layer1, block_layer2, block_layer3 as checkpoint nodes
  g = tf.get_default_graph()
  ops = g.get_operations()
  for op in ge.filter_ops_from_regex(ops, "block_layer"):
    tf.add_to_collection("checkpoints", op.outputs[0])

  grads = tf.gradients(loss, tf.trainable_variables())

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  # evaluate the gradients twice: the first run absorbs one-time setup costs,
  # the second gives a steady-state memory profile
  sessrun(grads)
  sessrun(grads)

  mem_use = mem_util.peak_memory(run_metadata)['/gpu:0']/1e6
  
  print("Memory used: %.2f MB "%(mem_use))
  total_time = time.perf_counter()-start_time0
  print("Total time: %.2f sec"%(total_time))
  assert total_time < 100
  return mem_use
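
# How the "checkpoints" collection above is consumed is not shown in this
# file. A sketch of the usual wiring (an assumption based on the
# gradient-checkpointing project's memory_saving_gradients module):
# tf.gradients is monkey-patched so backprop recomputes activations between
# the collected checkpoint tensors instead of keeping every activation alive.
import memory_saving_gradients

def gradients_collection(ys, xs, grad_ys=None, **kwargs):
  return memory_saving_gradients.gradients(ys, xs, grad_ys,
                                           checkpoints='collection', **kwargs)

tf.__dict__["gradients"] = gradients_collection
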
def gradient_memory_mbs():
  """Evaluates gradient, prints peak memory."""
  start_time0 = time.perf_counter()
  start_time = start_time0
  tf.reset_default_graph()
  tf.set_random_seed(1)
  
  train_op, loss = create_train_op_and_loss()
  print("Graph construction: %.2f ms" %(1000*(time.perf_counter()-start_time)))

  g = tf.get_default_graph()
  ops = g.get_operations()
  
  for op in ge.filter_ops_from_regex(ops, "block_layer"):
    tf.add_to_collection("checkpoints", op.outputs[0])

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  start_time = time.perf_counter()
  sessrun(train_op)
  print("Compute time: %.2f ms" %(1000*(time.perf_counter()-start_time)))
  print("loss %f"%(sess.run(loss),))

  mem_use = mem_util.peak_memory(run_metadata)['/gpu:0']/1e6
  print("Memory used: %.2f MB "%(mem_use))
  total_time = time.perf_counter()-start_time0
  assert total_time < 100
  return mem_use
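
# A hypothetical driver for the benchmarks above (function name and sizes are
# illustrative, not from the original file): sweep model sizes with
# memory_test and report the best-of-3 numbers it returns.
def run_memory_benchmarks():
  for size in [18, 34, 50]:
    mem_mb, step_time = memory_test(size)
    print("size %d: %.1f MB peak, %.3f sec per gradient eval"
          % (size, mem_mb, step_time))
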
def test_peak_gpu():
  global sess, run_metadata
  tf.reset_default_graph()
  
  assert tf.test.is_gpu_available(), "This test requires GPU"
  # create backprop for A0->A1->A2->A3
  with tf.device("/cpu:0"):
    b0 = _chain_backprop(3)

  # create the same backprop chain on the GPU
  with tf.device("/gpu:0"):
    c0 = _chain_backprop(3)

  sess = create_session()
  sessrun(tf.group(b0.op, c0.op))
  peak_cpu = mem_util.peak_memory(run_metadata)['/cpu:0']
  peak_gpu = mem_util.peak_memory(run_metadata)['/gpu:0']
  assert abs(peak_cpu - 4e6) < 1e4
  assert abs(peak_gpu - 4e6) < 1e4
def test_peak():
  global sess, run_metadata
  tf.reset_default_graph()
  
  # create backprop for A0->A1->A2->A3
  with tf.device("/cpu:0"):
    b0 = _chain_backprop(3)

  # this needs 4 MB of memory
  # A0/A1 share memory since A0 is not consumed by anyone, therefore at peak
  # we have A1,A2,A3,B0 stored in memory
  
  sess = create_session()
  sessrun(b0.op)
  peak_cpu = mem_util.peak_memory(run_metadata)['/cpu:0']
  assert abs(peak_cpu - 4e6) < 1e4
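
# The two tests above call a _chain_backprop helper that is not defined in
# this file. A hedged sketch of what it must look like, inferred from the
# comments (each node is a 1 MB fill; for n=3 the peak holds A1, A2, A3, B0,
# i.e. 4 MB, because A0's buffer is reused for A1):
def _chain_backprop(n):
  from tensorflow.python.ops import gen_math_ops
  size = 250000  # 250000 float32 values = 1 MB per tensor
  A = [tf.fill((size,), 1.0, name="A0")]
  for i in range(1, n + 1):
    A.append(tf.tanh(A[-1], name="A%d" % i))
  b = tf.fill((size,), 1.0, name="B%d" % n)
  for i in range(n - 1, -1, -1):
    b = gen_math_ops._tanh_grad(A[i + 1], b, name="B%d" % i)
  return b
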
def gradient_memory_measure_mb():
    """Evaluates gradient, prints peak memory in MBs."""
    global sess

    assert tf.test.is_gpu_available()
    tf.reset_default_graph()
    tf.set_random_seed(1)
    np.random.seed(1)

    start_time0 = time.perf_counter()
    loss = create_loss()

    if DUMP_GRAPHDEF:
        open('graphdef.txt',
             'w').write(str(tf.get_default_graph().as_graph_def()))

    # use block_layer1, block_layer2, block_layer3 as checkpoint nodes
    # this is only active when checkpoint strategy=collection is used
    g = tf.get_default_graph()
    ops = g.get_operations()
    for op in ge.filter_ops_from_regex(ops, "block_layer"):
        tf.add_to_collection("checkpoints", op.outputs[0])

    # build the gradients once and a plain SGD training op from them
    tvars = tf.trainable_variables()
    grads = tf.gradients(loss, tvars)
    grads_and_vars = zip(grads, tvars)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-5)
    train_op = optimizer.apply_gradients(grads_and_vars)

    sess = create_session()
    sessrun(tf.global_variables_initializer())
    # evaluate the gradients twice so one-time setup costs are excluded from
    # the steady-state memory profile
    sessrun(grads)
    sessrun(grads)

    # without checkpoints we expect following sequence of losses
    # Loss 35.49785, memory 626.32 MB
    # Loss 32.18098, memory 626.31 MB
    # Loss 29.42088, memory 628.37 MB
    # Loss 28.29715, memory 628.37 MB
    # Loss 26.50492, memory 628.37 MB
    # Loss 25.59675, memory 628.37 MB
    # Loss 24.45332, memory 628.37 MB
    # Loss 23.91770, memory 628.37 MB
    # Loss 22.29025, memory 626.31 MB
    # Loss 22.42356, memory 626.31 MB

    loss0 = sess.run(loss)
    assert loss0 > 35
    for i in range(10):
        sessrun(train_op)
        loss0 = sess.run(loss)
        mem_use = mem_util.peak_memory(run_metadata)['/gpu:0'] / 1e6
        print("Loss %.5f, memory %.2f MB" % (loss0, mem_use))

    assert loss0 < 22

    mem_use = mem_util.peak_memory(run_metadata)['/gpu:0'] / 1e6

    print("Memory used: %.2f MB " % (mem_use))
    total_time = time.perf_counter() - start_time0
    print("Total time: %.2f sec" % (total_time))
    assert total_time < 100
    return mem_use
def train_mnist():
    global sess

    # restrict to cpu:0
    tf.reset_default_graph()
    tf.set_random_seed(1)
    np.random.seed(1)
    tf_dev = tf.device(TEST_DEVICE)
    tf_dev.__enter__()

    #  FLAGS = parse_flags()
    # Train the model

    # replace Dataset ops with constant images because gradient rewriting
    # tries to differentiate graphs containing IteratorGetNext
    # TODO: make it work with Dataset ops
    images = tf.Variable(tf.random_uniform((FLAGS_batch_size, 28**2)))
    labels = tf.Variable(
        tf.concat(
            [tf.ones((FLAGS_batch_size, 1)),
             tf.zeros((FLAGS_batch_size, 9))],
            axis=1))

    def train_input_fn():
        dataset = train_dataset(FLAGS_data_dir)
        dataset = dataset.batch(FLAGS_batch_size)
        (images, labels) = dataset.make_one_shot_iterator().get_next()
        num_images = FLAGS_batch_size
        return (images[:num_images], labels[:num_images])

    if USE_REAL_DATA:
        images, labels = train_input_fn()
        # images = tf.stop_gradient(images)
        # labels = tf.stop_gradient(labels)

    logits = mnist_model(images, tf.estimator.ModeKeys.TRAIN, 'channels_last')
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits,
                                                    onehot_labels=labels)
    loss = cross_entropy
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-2)

    tvars = tf.trainable_variables()  # avoid shadowing the builtin vars()
    grads = tf.gradients(loss, tvars)
    grads_and_vars = zip(grads, tvars)
    train_op = optimizer.apply_gradients(grads_and_vars)

    sess = create_session()
    sess.run(tf.global_variables_initializer())
    print("Loss %.5f" % (sess.run(loss)))
    for i in range(10):
        sessrun(train_op)
        mem_use = mem_util.peak_memory(run_metadata)[TEST_DEVICE] / 1e6
        print("Loss %.5f, memory %.2f MB" % (sess.run(loss), mem_use))

    # should print something like this for actual dataset
    # 2.12764
    # 1.87759
    # 1.54445
    # 1.29149
    # 1.18474
    # 0.884424
    # 0.69454
    # 0.770236
    # 0.629259
    # 0.654465

    assert sess.run(loss) < 100
def cpu_peak():
    return mem_util.peak_memory(run_metadata)['/cpu:0']
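
# Hypothetical GPU counterpart following the same pattern as cpu_peak:
def gpu_peak():
    return mem_util.peak_memory(run_metadata)['/gpu:0']
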
def test_chain_constant_memory():
  """Test that backprop on a chain of length n takes constant memory."""
  global sess, run_metadata

  from tensorflow.python.ops import gen_math_ops
  tanh_grad = gen_math_ops._tanh_grad

  size_mbs = 1   # size of each node in MB
  size = size_mbs * 250000  # 250000 float32 values = 1 MB

  gg = tf.get_default_graph()
  
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()
  
  n = 20  
  A = [None]*(n+1)
  A[0] = tf.fill((size,), 1.0, name="A0")
  for L in range(1, n+1):
    name = "A"+str(L)
    A[L] = tf.tanh(A[L-1], name=name)

  B = [None]*(n+1)
  B[n] = tf.fill((size,), 1.0, name="B"+str(n))
    
  run_after(B[n].op, A[n].op)
  for L in range(n-1, -1, -1):
    name = "B"+str(L)
    B[L] = tanh_grad(A[L+1], B[L+1], name=name)

  # for each op, obtain steps during which any output of this op is consumed
  execution_order = linearize_lib.get_execution_order(B[0])
  consuming_schedule = OrderedDict()
  for op in gg.get_operations():
    consuming_ops = OrderedSet()  # OrderedSet for determinism
    for output in op.outputs:
      consuming_ops.update(output.consumers())
    consuming_schedule[op] = [execution_order.index(c) for c in consuming_ops]

  for step, op in enumerate(execution_order):
    for op_input in op.inputs:
      # get all the times when this input is consumed
      consume_times = consuming_schedule[op_input.op]
      assert step in consume_times

      # if it's been consumed before, save memory by recomputing it
      consumed_before = len([t for t in consume_times if t<step]) > 0
      if consumed_before:
        assert step>0
        # want recomputation to happen as late as possible, schedule to run
        # it after the op that was scheduled to execute right before this op
        prev_op = execution_order[step-1]
        new_input = recompute_tensor(op_input, known_values=[A[0]],
                                     preceding_op=prev_op)
        replace_input(op, old_input=op_input, new_input=new_input)

  sess = create_session()
  sessrun(B[0].op)
  peak_cpu = mem_util.peak_memory(run_metadata)['/cpu:0']

  # chain of length 20: backprop should use 3 MB instead of 20
  print("Memory to backprop on chain of length %d: %.1f MB" %(n, peak_cpu/1e6,))
  assert abs(peak_cpu - 3e6) < 1e4
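
# Hedged sketch of the run_after helper used above (assumed semantics): force
# a_op to execute after b_op by adding a control dependency through the same
# graph-editor module the other examples use as ge. recompute_tensor and
# replace_input are assumed to come from the project's own graph-rewriting
# utilities and are not sketched here.
def run_after(a_op, b_op):
  ge.add_control_inputs(a_op, [b_op])
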
# The run() method below belongs to a larger trainer class whose remaining
# definition is not part of this file; the class line is a stub (an
# assumption) that keeps the snippet parseable.
class Trainer(object):
    def run(self, iterations, log_freq=500):
        global global_sess

        with tf.Session(graph=self.graph, config=self.config) as sess:
            global_sess = sess
            self._print_trainable_variables()

            tf.global_variables_initializer().run(session=sess)

            start = timer()
            header = ""
            header1 = ""
            format_str = ""
            header += "%9s %8s %8s" % ('Iteration', 'Epoch', 'Loss')
            header1 += "%9s %8s %8s" % (' ', ' ', ' ')
            header += " | %9s %8s %9s" % ('Cross_ent', 'Accuracy', 'Accuracy5')
            header1 += " | %9s %8s %9s" % (' ', ' Batch  ', ' ')
            format_str += "%(Iterations)9d %(Epoch)8.1f %(Loss)8.5f"
            format_str += " | %(Batch_Cross_Entropy)9.5f %(Batch_Accuracy)8.3f %(Batch_Accuracy5)9.3f"
            if self.calc_train_acc:
                header += " | %9s %8s %9s" % ('Cross_ent', 'Accuracy',
                                              'Accuracy5')
                header1 += " | %9s %8s %9s" % (' ', 'Training', ' ')
                format_str += " | %(Train_Cross_Entropy)9.5f %(Train_Accuracy)8.3f %(Train_Accuracy5)9.3f"
            if self.calc_valid_acc:
                header += " | %9s %8s %9s" % ('Cross_ent', 'Accuracy',
                                              'Accuracy5')
                header1 += " | %8s %10s %8s" % (' ', 'Validation', ' ')
                format_str += " | %(Valid_Cross_Entropy)9.5f %(Valid_Accuracy)8.3f %(Valid_Accuracy5)9.3f"
            header += " | %10s %8s %10s %8s" % ('Batch', 'Ex/sec', 'Major_It',
                                                'Total')
            header1 += " | %13s %11s %13s" % ('', 'Time (secs)', '')
            format_str += " | %(Batch_Time)10.7f %(Ex_per_sec)8d %(Maj_Time)10.6f %(Total_Time)8.6g"

            print(header1)
            print(header)
            batch_times = [0]
            ex_per_batch = [0]

            for step in range(iterations):
                feed_dict = self.dataset.get_next_batch()
                print("Step: ", step)

                if (step % log_freq == 0) or step == iterations - 1:
                    summaries = list()
                    if self.train_sum_freq == 'all' or (
                            self.train_sum_freq == 'end'
                            and step == iterations - 1):
                        summaries.append(self.training_summary_op)

                    start_acc = timer()
                    if step == 0:  # calculate loss etc without train_op
                        l, batch_xent, batch_acc, batch_acc5,\
                        train_xent, train_acc, train_acc5,\
                        valid_xent, valid_acc, valid_acc5, summary=sessrun([
                          self.loss, self.batch_xent, self.batch_accuracy, self.batch_accuracy5,
                          self.train_xent, self.train_accuracy, self.train_accuracy5,
                          self.valid_xent, self.valid_accuracy, self.valid_accuracy5, summaries],
                          feed_dict=feed_dict)
                    else:
                        _, l, batch_xent, batch_acc, batch_acc5,\
                        train_xent, train_acc, train_acc5,\
                        valid_xent, valid_acc, valid_acc5, summary=sessrun([self.train_op,
                          self.loss, self.batch_xent, self.batch_accuracy, self.batch_accuracy5,
                          self.train_xent, self.train_accuracy, self.train_accuracy5,
                          self.valid_xent, self.valid_accuracy, self.valid_accuracy5, summaries],
                          feed_dict=feed_dict)

                    end_acc = timer()
                    end = timer()

                    epoch = float(
                        self.global_step * self.dataset.batch_size) / float(
                            self.dataset.num_train_ex)

                    for s in summary:
                        if s is not None:
                            self.train_writer.add_summary(s, self.global_step)
                    acc_time = (end_acc - start_acc)
                    tot_time = (end - start)

                    RES = {
                        'Iterations': step,
                        'Epoch': epoch,
                        'Loss': l,
                        'Batch_Cross_Entropy': batch_xent,
                        'Batch_Accuracy': batch_acc,
                        'Batch_Accuracy5': batch_acc5
                    }
                    if self.calc_train_acc:
                        RES.update({
                            'Train_Cross_Entropy': train_xent,
                            'Train_Accuracy': train_acc,
                            'Train_Accuracy5': train_acc5
                        })
                    if self.calc_valid_acc:
                        RES.update({
                            'Valid_Cross_Entropy': valid_xent,
                            'Valid_Accuracy': valid_acc,
                            'Valid_Accuracy5': valid_acc5
                        })
                    RES.update({
                        'Batch_Time': np.mean(batch_times),
                        'Ex_per_sec': np.mean(ex_per_batch),
                        'Maj_Time': acc_time,
                        'Total_Time': tot_time
                    })

                    print(format_str % RES)

                    self.Results.append(RES)

                    if self.check_divergence:
                        if np.isnan(l):
                            print("Loss is NaN")
                            break
                        if step > (0.25 * iterations):
                            if train_acc < (100 /
                                            self.dataset.num_classes) * 1.5:
                                print(
                                    "Training accuracy is near random - stopped due to failure to train"
                                )
                                break
                            if valid_acc < (100 /
                                            self.dataset.num_classes) * 1.5:
                                print(
                                    "Validation accuracy is near random - stopped due to failure to train"
                                )
                                break

                else:
                    start_it = timer()
                    _ = sessrun([self.train_op], feed_dict=feed_dict)
                    end_it = timer()
                    batch_times.extend([end_it - start_it])
                    ex_per_batch.extend(
                        [self.dataset.batch_size / batch_times[-1]])

                    # keep a running window of the last 100 batch timings
                    if len(batch_times) > 100:
                        batch_times = batch_times[-100:]
                        ex_per_batch = ex_per_batch[-100:]

                self.global_step += 1
                if run_metadata:
                    mem_use = mem_util.peak_memory(
                        run_metadata)['/gpu:0'] / 1e6
                    print("Memory: %.2f MB" % (mem_use, ))

        with open(self.save_path + "/Results.pickle", "wb") as f:
            pickle.dump(self.Results, f)
        self.train_writer.close()
# Fragment of an MNIST early-stopping training loop. The scaffold down to the
# inner loop header (import, hyperparameters, session and loop lines) is a
# hedged reconstruction to make the fragment self-contained; the model tensors
# (X, y, training, loss, training_op, accuracy) and the get_model_params /
# restore_model_params helpers are assumed to be defined earlier.
from tensorflow.contrib.memory_stats.python.ops import memory_stats_ops

n_epochs = 10
batch_size = 50
check_interval = 500
max_checks_without_progress = 20
checks_since_last_progress = 0
best_loss_val = np.infty
best_model_params = None

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for epoch in range(n_epochs):
        for iteration in range(mnist.train.num_examples // batch_size):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch, training: True},
                     options=run_options,
                     run_metadata=run_metadata)
            if iteration % check_interval == 0:
                loss_val = loss.eval(feed_dict={X: mnist.validation.images,
                                                y: mnist.validation.labels})
                if loss_val < best_loss_val:
                    best_loss_val = loss_val
                    checks_since_last_progress = 0
                    best_model_params = get_model_params()
                else:
                    checks_since_last_progress += 1
                mem_use = mem_util.peak_memory(run_metadata)['/gpu:0']/1e6
                print("Memory used: %.2f MB "%(mem_use))
                max_bytes_in_use = sess.run(memory_stats_ops.MaxBytesInUse())/1e6
                print("Max Memory used: %.2f MB "%(max_bytes_in_use))

        acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
        acc_val = accuracy.eval(feed_dict={X: mnist.validation.images,
                                           y: mnist.validation.labels})
        print("Epoch {}, train accuracy: {:.4f}%, valid. accuracy: {:.4f}%, valid. best loss: {:.6f}".format(
                  epoch, acc_train * 100, acc_val * 100, best_loss_val))
        if checks_since_last_progress > max_checks_without_progress:
            print("Early stopping!")
            break

    if best_model_params:
        restore_model_params(best_model_params)