Example #1
    def test_bucket_by_quantiles(self):
        with self.test_session() as sess:
            data = tf.data.Dataset.from_tensor_slices(list(range(10))).repeat()
            data = data.apply(
                ops.bucket_by_quantiles(len_fn=lambda x: x,
                                        batch_size=4,
                                        n_buckets=2,
                                        hist_bounds=[2, 4, 6, 8]))
            it = data.make_initializable_iterator()
            sess.run(it.initializer)
            sess.run(tf.local_variables_initializer())
            next_op = it.get_next()

            # Let the model gather statistics: it sees 4*5 = 20 elements
            # (2 epochs of 0..9), so each of the 5 histogram buckets
            # accumulates 4 elements and the cumulative counts are
            # [4, 8, 12, 16, 20].
            for _ in range(5):
                sess.run(next_op)

            counts = sess.run(tf.local_variables()[0])
            self.assertEqual(counts.tolist(), [4, 8, 12, 16, 20])

            # At this point the model should perfectly quantize the input
            for _ in range(4):
                out = sess.run(next_op)
                if out[0] < 5:
                    self.assertAllInRange(out, 0, 5)
                else:
                    self.assertAllInRange(out, 5, 10)
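The expected counts above are cumulative: assuming the bucketing op keeps one running counter per histogram bound plus a final catch-all (this reconstruction of the counting rule is an assumption, not taken from ops.bucket_by_quantiles), two epochs of 0..9 give exactly the asserted values. A quick NumPy sanity check:

import numpy as np

seen = np.tile(np.arange(10), 2)            # the 4*5 = 20 elements consumed above
hist_bounds = [2, 4, 6, 8]
cumulative = [int(np.sum(seen < b)) for b in hist_bounds] + [len(seen)]
print(cumulative)                           # [4, 8, 12, 16, 20]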
Example #2
def count_variables_by_type(variables=None):
    """Returns a dict mapping dtypes to number of variables and scalars.

  Args:
    variables: iterable of `tf.Variable`s, or None. If None is passed, then all
      global and local variables in the current graph are used.

  Returns:
    A dict mapping tf.dtype keys to a dict containing the keys 'num_scalars' and
      'num_variables'.
  """
    if variables is None:
        variables = tf.global_variables() + tf.local_variables()
    unique_types = set(v.dtype.base_dtype for v in variables)
    results_dict = {}
    for dtype in unique_types:
        if dtype == tf.string:
            logging.warning(
                "NB: string Variables present. The memory usage for these  Variables "
                "will not be accurately computed as it depends on the exact strings "
                "stored in a particular session.")
        vars_of_type = [v for v in variables if v.dtype.base_dtype == dtype]
        num_scalars = sum(v.shape.num_elements() for v in vars_of_type)
        results_dict[dtype] = {
            "num_variables": len(vars_of_type),
            "num_scalars": num_scalars
        }
    return results_dict
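A minimal usage sketch, assuming TF1-style graph mode (the variable names below are illustrative, not from the original):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
w = tf.get_variable("w", shape=[3, 4], dtype=tf.float32)
b = tf.get_variable("b", shape=[4], dtype=tf.float32)
step = tf.get_variable("step", shape=[], dtype=tf.int64,
                       initializer=tf.zeros_initializer())

print(count_variables_by_type())
# e.g. {tf.float32: {'num_variables': 2, 'num_scalars': 16},
#       tf.int64: {'num_variables': 1, 'num_scalars': 1}}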
Example #3
def load_checkpoint(use_placeholder=False, session=None):
  dataset = build("data")
  model = build("model")
  if use_placeholder:
    inputs = dataset.get_placeholders()
  else:
    inputs = dataset()

  info = model.eval(inputs)
  if session is None:
    session = tf.Session()
  saver = tf.train.Saver()
  checkpoint_dir = get_checkpoint_dir()
  checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
  saver.restore(session, checkpoint_file)

  print('Successfully restored Checkpoint "{}"'.format(checkpoint_file))
  # print variables
  variables = tf.global_variables() + tf.local_variables()
  for row in snt.format_variables(variables, join_lines=False):
    print(row)

  return {
      "session": session,
      "model": model,
      "info": info,
      "inputs": inputs,
      "dataset": dataset,
  }
Example #4
 def initialize_variables(self):
     """Initialize global variables."""
     train_vars = tf.trainable_variables()
     other_vars = [
         var for var in tf.global_variables() + tf.local_variables()
         if var not in train_vars
     ]
     self.sess.run([v.initializer for v in train_vars])
     self.sess.run([v.initializer for v in other_vars])
Example #5
  def testUsage(self, custom_getter_fn):
    # Create a module with no custom getters.
    linear = snt.Linear(10)

    # Create a module within the scope of an 'override args' custom getter.
    local_custom_getter = custom_getter_fn(
        collections=[tf.GraphKeys.LOCAL_VARIABLES])
    with tf.variable_scope("", custom_getter=local_custom_getter):
      local_linear = snt.Linear(10)

    # Connect both modules to the graph, creating their variables.
    inputs = tf.placeholder(dtype=tf.float32, shape=(7, 11))
    linear(inputs)
    local_linear(inputs)

    self.assertIn(linear.w, tf.global_variables())
    self.assertNotIn(linear.w, tf.local_variables())
    self.assertIn(local_linear.w, tf.local_variables())
    self.assertNotIn(local_linear.w, tf.global_variables())
Example #6
    def add_inference(self, cnn):
        assert cnn.top_layer.shape[1:] == (3, 224, 224)
        cnn.conv(1, 1, 1, 1, 1, use_batch_norm=True)
        cnn.mpool(1, 1, 1, 1, num_channels_in=1)
        cnn.reshape([-1, 224 * 224])
        cnn.affine(1, activation=None)

        # Assert that the batch norm variables are filtered out for L2 loss.
        variables = tf.global_variables() + tf.local_variables()
        assert len(variables) > len(self.filter_l2_loss_vars(variables))
Example #7
    def clean_acc_history(self):
        """Cleans accumulated counter in metrics.accuracy."""

        if not hasattr(self, 'clean_accstate_op'):
            self.clean_accstate_op = [
                a.assign(0)
                for a in utils.get_var(tf.local_variables(), 'accuracy')
            ]
            logging.info('Create {} clean accuracy state ops'.format(
                len(self.clean_accstate_op)))
        self.sess.run(self.clean_accstate_op)
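The `utils.get_var` helper is not shown in this snippet; the code above appears to rely on a name-based filter over a variable list. A hypothetical stand-in under that assumption (not the original implementation):

def get_var(var_list, name_substring):
    # Hypothetical stand-in for utils.get_var: keep only the variables whose
    # name contains the given substring (here, the 'accuracy' metric state).
    return [v for v in var_list if name_substring in v.name]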
Example #8
  def testExplicitArgOverridden(self):
    # Create a variable within the scope of an 'override args' custom getter.
    local_custom_getter = snt.custom_getters.override_args(
        collections=[tf.GraphKeys.LOCAL_VARIABLES])
    with tf.variable_scope("", custom_getter=local_custom_getter):
      # Explicitly specify an arg that disagrees with the custom getter.
      v = tf.get_variable("v", (), collections=[
          tf.GraphKeys.GLOBAL_VARIABLES])

    # The custom getter should win.
    self.assertIn(v, tf.local_variables())
    self.assertNotIn(v, tf.global_variables())
Example #9
def summarize_variables(variables=None):
    """Logs a summary of variable information.

    This function groups Variables by dtype and prints out the number of Variables
    and the total number of scalar values for each datatype, as well as the total
    memory consumed.

    For Variables of type tf.string, the memory usage cannot be accurately
    calculated from the Graph as the memory requirements change based on what
    strings are actually stored, which can only be determined inside a session.
    In this case, the amount of memory used to store the pointers to the strings
    is logged, along with a warning.

    Args:
      variables: iterable of variables; if not provided, then all variables
        (in the default graph) are summarized.
    """
    if variables is None:
        variables = tf.global_variables() + tf.local_variables()
    total_num_scalars = 0
    total_num_bytes = 0
    # Sort by string representation of type name, so output is deterministic.
    unique_types_ordered = sorted(
        set([v.dtype.base_dtype for v in variables]),
        key=lambda dtype: "%r" % dtype,
    )
    for dtype in unique_types_ordered:
        if dtype == tf.string:
            tf.logging.warning(
                "NB: string Variables present. The memory usage for these  Variables "
                "will not be accurately computed as it depends on the exact strings "
                "stored in a particular session.")
        vars_of_type = [v for v in variables if v.dtype.base_dtype == dtype]
        num_scalars = sum(v.shape.num_elements() for v in vars_of_type)
        num_bytes = num_scalars * dtype.size
        tf.logging.info(
            "%r: %d variables comprising %d scalars, %s",
            dtype,
            len(vars_of_type),
            num_scalars,
            _num_bytes_to_human_readable(num_bytes),
        )
        total_num_scalars += num_scalars
        total_num_bytes += num_bytes
    tf.logging.info(
        "Total: %d variables comprising %d scalars, %s",
        len(variables),
        total_num_scalars,
        _num_bytes_to_human_readable(total_num_bytes),
    )
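`_num_bytes_to_human_readable` is referenced but not defined in this snippet; a plausible sketch inferred from its name and usage (an assumption, not the original helper):

def _num_bytes_to_human_readable(num_bytes):
    # Hypothetical helper: render a byte count with a readable unit suffix.
    num_bytes = float(num_bytes)
    for unit in ("B", "KB", "MB", "GB"):
        if num_bytes < 1024.0:
            return "%.2f %s" % (num_bytes, unit)
        num_bytes /= 1024.0
    return "%.2f TB" % num_bytes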
Example #10
def log_variables(variables=None):
    """Logs variable information.

    This function logs the name, shape, type, collections, and device for either
    all variables or a given iterable of variables.

    Args:
      variables: iterable of variables; if not provided, then all variables
          (in the default graph) are logged.
    """
    if variables is None:
        variables = tf.global_variables() + tf.local_variables()
    for row in format_variables(variables, join_lines=False):
        tf.logging.info(row)
Example #11
 def initialize_variables(self):
     """Initializes global variables."""
     if FLAGS.pretrained_ckpt:
         # Used for imagenet pretraining
         self.net.init_model_weight(FLAGS.pretrained_ckpt,
                                    include_top=FLAGS.mode == 'evaluation',
                                    mode=FLAGS.pretrained_ckpt_mode)
     train_vars = tf.trainable_variables()
     other_vars = [
         var for var in tf.global_variables() + tf.local_variables()
         if var not in train_vars
     ]
     self.sess.run([v.initializer for v in train_vars])
     self.sess.run([v.initializer for v in other_vars])
Example #12
  def testWithNested(self, custom_getter_fn):
    # Create a module with a custom getter, within the scope of an
    # 'override args' custom getter.
    local_custom_getter = custom_getter_fn(
        collections=[tf.GraphKeys.LOCAL_VARIABLES])
    with tf.variable_scope("", custom_getter=local_custom_getter):
      local_linear = snt.Linear(10, custom_getter=_suffix_custom_getter)

    # Connect the module to the graph, creating its variables.
    inputs = tf.placeholder(dtype=tf.float32, shape=(7, 11))
    local_linear(inputs)

    # Both custom getters should be effective.
    self.assertIn(local_linear.w, tf.local_variables())
    self.assertNotIn(local_linear.w, tf.global_variables())
    self.assertEqual("linear/w_test", local_linear.w.op.name)
Example #13
def log_variables(variables=None):
    """Logs variable information.

  This function logs the name, shape, type, collections, and device for either
  all variables or a given iterable of variables. In the "Device" columns,
  the nature of the variable (legacy or resource (for ResourceVariables)) is
  also specified in parenthesis.

  Args:
    variables: iterable of variables; if not provided, then all variables
        (in the default graph) are logged.
  """
    if variables is None:
        variables = tf.global_variables() + tf.local_variables()
    for row in format_variables(variables, join_lines=False):
        logging.info(row)
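A minimal usage sketch, assuming TF1-style graph mode and that `format_variables` (from the same module) yields one row per variable; the variable names below are illustrative:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.get_variable("linear/w", shape=[11, 10])
tf.get_variable("linear/b", shape=[10])
log_variables()   # one INFO line per variable: name, shape, type, collections, device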
Example #14
    def get_post_init_ops(self):
        # Copy initialized variables for variables on the parameter server
        # to the local copy of the variable.

        local_vars = tf.local_variables()
        local_var_by_name = dict([(self._strip_port(v.name), v)
                                  for v in local_vars])
        post_init_ops = []
        for v in tf.global_variables():
            if v.name.startswith(variable_mgr_util.PS_SHADOW_VAR_PREFIX +
                                 '/v0/'):
                prefix = self._strip_port(
                    v.name[len(variable_mgr_util.PS_SHADOW_VAR_PREFIX +
                               '/v0'):])
                for i in range(self.benchmark_cnn.num_gpus):
                    name = 'v%s%s' % (i, prefix)
                    if name in local_var_by_name:
                        copy_to = local_var_by_name[name]
                        post_init_ops.append(copy_to.assign(v.read_value()))
        return post_init_ops
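`_strip_port` is not shown in this snippet; the name matching above suggests it drops the output index from a variable name (e.g. 'v0/conv/w:0' -> 'v0/conv/w'). A hypothetical sketch under that assumption:

def _strip_port(self, name):
    # Hypothetical stand-in: drop a trailing ':<output index>' so parameter-server
    # shadow variables and per-GPU local copies can be matched by name.
    if ':' in name:
        return name[:name.index(':')]
    return name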
Example #15
  def testGetAllLocalVariables(self, get_non_trainable_variables):
    def local_custom_getter(getter, *args, **kwargs):
      kwargs["trainable"] = False
      if "collections" in kwargs and kwargs["collections"] is not None:
        kwargs["collections"] += [tf.GraphKeys.LOCAL_VARIABLES]
      else:
        kwargs["collections"] = [tf.GraphKeys.LOCAL_VARIABLES]
      return getter(*args, **kwargs)

    inputs = tf.ones(dtype=tf.float32, shape=[10, 10])
    # Create a new ModuleWithSubmodules that uses all local variables
    with tf.variable_scope("", custom_getter=local_custom_getter):
      submodule_a = SimpleModule(name="simple_submodule")
      submodule_b = ComplexModule(name="complex_submodule")
      local_module = ModuleWithSubmodules(
          submodule_a=submodule_a, submodule_b=submodule_b)
    local_module(inputs)  # pylint: disable=not-callable

    self.assertEmpty(local_module.get_all_variables())
    self.assertEmpty(tf.all_variables())
    self.assertLen(tf.local_variables(), 12)

    all_variables = get_non_trainable_variables(local_module)
    all_variable_names = sorted([str(v.name) for v in all_variables])
    self.assertEqual([
        "complex_submodule/linear_1/b:0",
        "complex_submodule/linear_1/w:0",
        "complex_submodule/linear_2/b:0",
        "complex_submodule/linear_2/w:0",
        "module_with_submodules/complex_build/linear_1/b:0",
        "module_with_submodules/complex_build/linear_1/w:0",
        "module_with_submodules/complex_build/linear_2/b:0",
        "module_with_submodules/complex_build/linear_2/w:0",
        "module_with_submodules/simple_build/b:0",
        "module_with_submodules/simple_build/w:0",
        "simple_submodule/b:0",
        "simple_submodule/w:0",
    ], all_variable_names)
Example #16
 def savable_variables(self):
     """Returns a list/dict of savable variables to pass to tf.train.Saver."""
     params = {}
     for v in tf.global_variables():
         assert (v.name.startswith(variable_mgr_util.PS_SHADOW_VAR_PREFIX +
                                   '/v0/')
                 or v.name in ('global_step:0', 'loss_scale:0',
                               'loss_scale_normal_steps:0')), (
                                   'Invalid global variable: %s' % v)
         # We store variables in the checkpoint with the shadow variable prefix
         # removed so we can evaluate checkpoints in non-distributed replicated
         # mode. The checkpoints can also be loaded for training in
         # distributed_replicated mode.
         name = self._strip_port(
             self._remove_shadow_var_prefix_if_present(v.name))
         params[name] = v
     for v in tf.local_variables():
         # Non-trainable variables, such as batch norm moving averages, do not have
         # corresponding global shadow variables, so we add them here. Trainable
         # local variables have corresponding global shadow variables, which were
         # added in the global variable loop above.
         if v.name.startswith('v0/') and v not in tf.trainable_variables():
             params[self._strip_port(v.name)] = v
     return params
Example #17
    def eval_metrics_host_call_fn(policy_output,
                                  value_output,
                                  pi_tensor,
                                  value_tensor,
                                  policy_cost,
                                  value_cost,
                                  l2_cost,
                                  combined_cost,
                                  step,
                                  est_mode=tf.estimator.ModeKeys.TRAIN):
        policy_entropy = -tf.reduce_mean(
            tf.reduce_sum(policy_output * tf.log(policy_output), axis=1))
        # pi_tensor is one_hot when generated from sgfs (for supervised learning)
        # and soft-max when using self-play records. argmax normalizes the two.
        policy_target_top_1 = tf.argmax(pi_tensor, axis=1)

        policy_output_in_top1 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=1))
        policy_output_in_top3 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=3))

        policy_top_1_confidence = tf.reduce_max(policy_output, axis=1)
        policy_target_top_1_confidence = tf.boolean_mask(
            policy_output,
            tf.one_hot(policy_target_top_1,
                       tf.shape(policy_output)[1]))

        value_cost_normalized = value_cost / params['value_cost_weight']
        avg_value_observed = tf.reduce_mean(value_tensor)

        with tf.variable_scope('metrics'):
            metric_ops = {
                'policy_cost':
                tf.metrics.mean(policy_cost),
                'value_cost':
                tf.metrics.mean(value_cost),
                'value_cost_normalized':
                tf.metrics.mean(value_cost_normalized),
                'l2_cost':
                tf.metrics.mean(l2_cost),
                'policy_entropy':
                tf.metrics.mean(policy_entropy),
                'combined_cost':
                tf.metrics.mean(combined_cost),
                'avg_value_observed':
                tf.metrics.mean(avg_value_observed),
                'policy_accuracy_top_1':
                tf.metrics.mean(policy_output_in_top1),
                'policy_accuracy_top_3':
                tf.metrics.mean(policy_output_in_top3),
                'policy_top_1_confidence':
                tf.metrics.mean(policy_top_1_confidence),
                'policy_target_top_1_confidence':
                tf.metrics.mean(policy_target_top_1_confidence),
                'value_confidence':
                tf.metrics.mean(tf.abs(value_output)),
            }

        if est_mode == tf.estimator.ModeKeys.EVAL:
            return metric_ops

        # NOTE: global_step is rounded to a multiple of FLAGS.summary_steps.
        eval_step = tf.reduce_min(step)

        # Create summary ops so that they show up in SUMMARIES collection
        # That way, they get logged automatically during training
        summary_writer = contrib_summary.create_file_writer(FLAGS.work_dir)
        with summary_writer.as_default(), \
                contrib_summary.record_summaries_every_n_global_steps(
                    params['summary_steps'], eval_step):
            for metric_name, metric_op in metric_ops.items():
                contrib_summary.scalar(metric_name,
                                       metric_op[1],
                                       step=eval_step)

        # Reset metrics occasionally so that they are mean of recent batches.
        reset_op = tf.variables_initializer(tf.local_variables('metrics'))
        cond_reset_op = tf.cond(
            tf.equal(eval_step % params['summary_steps'], tf.to_int64(1)),
            lambda: reset_op, lambda: tf.no_op())

        return contrib_summary.all_summary_ops() + [cond_reset_op]
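The reset trick above relies on two things: `tf.metrics.*` ops register their accumulators in the LOCAL_VARIABLES collection, and `tf.local_variables` accepts a scope filter. A stripped-down sketch of the same pattern (the placeholder below is illustrative, not part of the original):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
loss = tf.placeholder(tf.float32, shape=[])
with tf.variable_scope('metrics'):
    mean_loss, update_mean_loss = tf.metrics.mean(loss)

# Re-initializing only the 'metrics' local variables clears the running means
# without touching any other local state in the graph.
reset_op = tf.variables_initializer(tf.local_variables('metrics'))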
Example #18
def train_net(Net,
              training_data,
              base_lr,
              loss_weight,
              train_mode,
              num_epochs=[1, None, None],
              batch_size=64,
              weight_decay=4e-3,
              load_model=False,
              load_filename=None,
              save_model=False,
              save_filename=None,
              num_iter_to_save=10000,
              gpu_memory_fraction=1):

    images = []
    labels = []
    tasks = ['cls', 'bbx', 'pts']
    shape = 12
    if Net.__name__ == 'RNet':
        shape = 24
    elif Net.__name__ == 'ONet':
        shape = 48
    for index in range(train_mode):
        image, label = inputs(filename=[training_data[index]],
                              batch_size=batch_size,
                              num_epochs=num_epochs[index],
                              label_type=tasks[index],
                              shape=shape)
        images.append(image)
        labels.append(label)
    while len(images) != 3:
        images.append(tf.placeholder(tf.float32, [None, shape, shape, 3]))
        labels.append(tf.placeholder(tf.float32))
    net = Net((('cls', images[0]), ('bbx', images[1]), ('pts', images[2])),
              weight_decay_coeff=weight_decay)

    print('all trainable variables:')
    all_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES)
    for var in all_vars:
        print(var)

    print('all local variables:')
    local_variables = tf.local_variables()
    for l_v in local_variables:
        print(l_v.name)

    prefix = str(all_vars[0].name[0:5])
    out_put = net.get_all_output()
    cls_output = tf.reshape(out_put[0], [-1, 2])
    bbx_output = tf.reshape(out_put[1], [-1, 4])
    pts_output = tf.reshape(out_put[2], [-1, 10])

    # cls loss
    softmax_loss = loss_weight[0] * \
        tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=labels[0],
                                                logits=cls_output))
    weight_losses_cls = net.get_weight_decay()['cls']
    losses_cls = softmax_loss + tf.add_n(weight_losses_cls)

    # bbx loss
    square_bbx_loss = loss_weight[1] * \
        tf.reduce_mean(tf.squared_difference(bbx_output, labels[1]))
    weight_losses_bbx = net.get_weight_decay()['bbx']
    losses_bbx = square_bbx_loss + tf.add_n(weight_losses_bbx)

    # pts loss
    square_pts_loss = loss_weight[2] * \
        tf.reduce_mean(tf.squared_difference(pts_output, labels[2]))
    weight_losses_pts = net.get_weight_decay()['pts']
    losses_pts = square_pts_loss + tf.add_n(weight_losses_pts)

    global_step_cls = tf.Variable(1, name='global_step_cls', trainable=False)
    global_step_bbx = tf.Variable(1, name='global_step_bbx', trainable=False)
    global_step_pts = tf.Variable(1, name='global_step_pts', trainable=False)

    train_cls = tf.train.AdamOptimizer(learning_rate=base_lr) \
                        .minimize(losses_cls, global_step=global_step_cls)
    train_bbx = tf.train.AdamOptimizer(learning_rate=base_lr) \
                        .minimize(losses_bbx, global_step=global_step_bbx)
    train_pts = tf.train.AdamOptimizer(learning_rate=base_lr) \
                        .minimize(losses_pts, global_step=global_step_pts)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
    config.gpu_options.allow_growth = True

    loss_agg_cls = [0]
    loss_agg_bbx = [0]
    loss_agg_pts = [0]
    step_value = [1, 1, 1]

    with tf.Session(config=config) as sess:
        sess.run(init_op)
        saver = tf.train.Saver(max_to_keep=200000)
        if load_model:
            saver.restore(sess, load_filename)
        else:
            net.load(load_filename, sess, prefix)
        if save_model:
            save_dir = os.path.split(save_filename)[0]
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                choic = np.random.randint(0, train_mode)
                if choic == 0:
                    _, loss_value_cls, step_value[0] = sess.run(
                        [train_cls, softmax_loss, global_step_cls])
                    loss_agg_cls.append(loss_value_cls)
                elif choic == 1:
                    _, loss_value_bbx, step_value[1] = sess.run(
                        [train_bbx, square_bbx_loss, global_step_bbx])
                    loss_agg_bbx.append(loss_value_bbx)
                else:
                    _, loss_value_pts, step_value[2] = sess.run(
                        [train_pts, square_pts_loss, global_step_pts])
                    loss_agg_pts.append(loss_value_pts)

                if sum(step_value) % (100 * train_mode) == 0:
                    agg_cls = sum(loss_agg_cls) / len(loss_agg_cls)
                    agg_bbx = sum(loss_agg_bbx) / len(loss_agg_bbx)
                    agg_pts = sum(loss_agg_pts) / len(loss_agg_pts)
                    print('Step %d for cls: loss = %.5f' %
                          (step_value[0], agg_cls),
                          end='. ')
                    print('Step %d for bbx: loss = %.5f' %
                          (step_value[1], agg_bbx),
                          end='. ')
                    print('Step %d for pts: loss = %.5f' %
                          (step_value[2], agg_pts))
                    loss_agg_cls = [0]
                    loss_agg_bbx = [0]
                    loss_agg_pts = [0]

                if save_model and (step_value[0] % num_iter_to_save == 0):
                    saver.save(sess, save_filename, global_step=step_value[0])

        except tf.errors.OutOfRangeError:
            print('Done training for %d epochs, %d steps.' %
                  (num_epochs[0], step_value[0]))
        finally:
            coord.request_stop()

        coord.join(threads)
Example #19
global_step = tf.train.create_global_step()
learning_rate = configure_learning_rate(global_step, TRAIN_SAMPLES, FLAGS)
tf.summary.scalar('learning_rate', learning_rate)

optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                       momentum=FLAGS.momentum)
grads = optimizer.compute_gradients(total_loss)
train_op = optimizer.apply_gradients(grads, global_step=global_step)
summary_op = tf.summary.merge_all()

saver = tf.train.Saver(tf.trainable_variables())

############################################
#           For   validation               #
############################################
var_exclude = [v.name for v in tf.local_variables()]
images_val, labels_val = get_data(FLAGS.data_dir, 'validation',
                                  FLAGS.batch_size)
logits_val = resnet(images_val, training=False)
accuracy_val = get_accuracy(logits_val, labels_val)

# clear former accuracy information for validation
var_to_refresh = [v for v in tf.local_variables() if v.name not in var_exclude]
init_local_val = tf.variables_initializer(var_to_refresh)

#### HYPER PARAMETERS

print("\nHyper parameters: ")
print("TRAIN_SAMPLES: ", TRAIN_SAMPLES)
print("VAL_SAMPLES: ", VAL_SAMPLES)
print("batch_size: ", FLAGS.batch_size)
Example #20
def main(argv):
    del argv  # unused
    if tf.io.gfile.exists(FLAGS.model_dir):
        tf.logging.warning("Warning: deleting old log directory at {}".format(
            FLAGS.model_dir))
        tf.io.gfile.rmtree(FLAGS.model_dir)
    tf.io.gfile.makedirs(FLAGS.model_dir)

    if FLAGS.fake_data:
        (x_train, y_train), (x_test, y_test) = build_fake_data()
    else:
        (x_train, y_train), (x_test,
                             y_test) = tf.keras.datasets.cifar10.load_data()

    (images, labels, handle, training_iterator,
     heldout_iterator) = build_input_pipeline(x_train, x_test, y_train, y_test,
                                              FLAGS.batch_size, 500)

    if FLAGS.architecture == "resnet":
        model_fn = bayesian_resnet
    else:
        model_fn = bayesian_vgg

    import pdb
    pdb.set_trace()

    model = model_fn(
        IMAGE_SHAPE,
        num_classes=10,
        kernel_posterior_scale_mean=FLAGS.kernel_posterior_scale_mean,
        kernel_posterior_scale_constraint=FLAGS.
        kernel_posterior_scale_constraint)
    logits = model(images)
    labels_distribution = tfd.Categorical(logits=logits)

    # Perform KL annealing. The optimal number of annealing steps
    # depends on the dataset and architecture.
    t = tf.compat.v2.Variable(0.0)
    kl_regularizer = t / (FLAGS.kl_annealing * len(x_train) / FLAGS.batch_size)

    # Compute the -ELBO as the loss. The kl term is annealed from 0 to 1 over
    # the epochs specified by the kl_annealing flag.
    log_likelihood = labels_distribution.log_prob(labels)
    neg_log_likelihood = -tf.reduce_mean(input_tensor=log_likelihood)
    kl = sum(model.losses) / len(x_train) * tf.minimum(1.0, kl_regularizer)
    loss = neg_log_likelihood + kl

    # Build metrics for evaluation. Predictions are formed from a single forward
    # pass of the probabilistic layers. They are cheap but noisy
    # predictions.
    predictions = tf.argmax(input=logits, axis=1)
    with tf.name_scope("train"):
        train_accuracy, train_accuracy_update_op = tf.metrics.accuracy(
            labels=labels, predictions=predictions)
        opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
        train_op = opt.minimize(loss)
        update_step_op = tf.assign(t, t + 1)

    with tf.name_scope("valid"):
        valid_accuracy, valid_accuracy_update_op = tf.metrics.accuracy(
            labels=labels, predictions=predictions)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    stream_vars_valid = [v for v in tf.local_variables() if "valid/" in v.name]
    reset_valid_op = tf.variables_initializer(stream_vars_valid)

    with tf.Session() as sess:
        sess.run(init_op)

        # Run the training loop
        train_handle = sess.run(training_iterator.string_handle())
        heldout_handle = sess.run(heldout_iterator.string_handle())
        training_steps = int(
            round(FLAGS.epochs * (len(x_train) / FLAGS.batch_size)))
        for step in range(training_steps):
            _ = sess.run([train_op, train_accuracy_update_op, update_step_op],
                         feed_dict={handle: train_handle})

            # Periodically print the loss, accuracy, and KL term
            if step % 10 == 0:
                loss_value, accuracy_value, kl_value = sess.run(
                    [loss, train_accuracy, kl],
                    feed_dict={handle: train_handle})
                print("Step: {:>3d} Loss: {:.3f} Accuracy: {:.3f} KL: {:.3f}".
                      format(step, loss_value, accuracy_value, kl_value))

            if (step + 1) % FLAGS.eval_freq == 0:
                # Compute log prob of heldout set by averaging draws from the model:
                # p(heldout | train) = int_model p(heldout|model) p(model|train)
                #                   ~= 1/n * sum_{i=1}^n p(heldout | model_i)
                # where model_i is a draw from the posterior
                # p(model|train).
                # probs = np.asarray(sess.run(labels_distribution.prob(np.zeros((1,500))), feed_dict={handle:heldout_handle}))
                probs = (np.asarray([
                    sess.run((labels_distribution.prob(
                        np.tile(np.arange(10), [500, 1]).T)),
                             feed_dict={handle: heldout_handle})
                    for _ in range(FLAGS.num_monte_carlo)
                ]))
                mean_probs = np.mean(probs, axis=0)

                _, label_vals = sess.run((images, labels),
                                         feed_dict={handle: heldout_handle})
                # heldout_lp = np.mean(np.log(mean_probs[np.arange(mean_probs.shape[0]),
                #  label_vals.flatten()]))
                heldout_lp = np.mean(
                    np.log(mean_probs[:, label_vals.flatten()]))
                print(" ... Held-out nats: {:.3f}".format(heldout_lp))

                # Calculate validation accuracy
                for _ in range(20):
                    sess.run(valid_accuracy_update_op,
                             feed_dict={handle: heldout_handle})
                valid_value = sess.run(valid_accuracy,
                                       feed_dict={handle: heldout_handle})

                print(" ... Validation Accuracy: {:.3f}".format(valid_value))

                sess.run(reset_valid_op)
    a = 2
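The "valid/" filter above works because `tf.metrics.accuracy` creates its streaming `total` and `count` accumulators in the LOCAL_VARIABLES collection, named under the enclosing name scope. The same reset pattern in isolation (a minimal sketch; the placeholders are illustrative):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
labels = tf.placeholder(tf.int64, shape=[None])
predictions = tf.placeholder(tf.int64, shape=[None])
with tf.name_scope("valid"):
    valid_acc, valid_acc_update = tf.metrics.accuracy(labels=labels,
                                                      predictions=predictions)

# Local variables such as "valid/accuracy/total:0" hold the streaming state;
# re-running their initializer resets only the validation metric.
stream_vars = [v for v in tf.local_variables() if "valid/" in v.name]
reset_valid = tf.variables_initializer(stream_vars)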