def testLossCostDecorated(self):
    """Checks cost and regularization of one conv under a DummyDecorator.

    A 1x1 conv with 2 outputs and scaled batch norm is built, both of its
    batch-norm gammas are set to 1, and a GammaFlopsRegularizer decorated
    with DummyDecorator(scale=0.5) is evaluated on the conv op.
    """
    layer_kwargs = dict(
        trainable=True,
        normalizer_fn=slim.batch_norm,
        normalizer_params={'scale': True})

    with slim.arg_scope([slim.layers.conv2d], **layer_kwargs):
      zero_input = tf.constant(0.0, shape=[1, 3, 3, NUM_CHANNELS])
      conv_out = slim.layers.conv2d(
          zero_input, 2, kernel_size=1, padding='SAME', scope='conv1')
    with self.cached_session():
      tf.global_variables_initializer().run()
      by_name = {var.op.name: var for var in tf.global_variables()}
      # Force both output-channel gammas to 1.
      by_name['conv1/BatchNorm/gamma'].assign([1] * 2).eval()

    self.gamma_flop_reg = flop_regularizer.GammaFlopsRegularizer(
        [conv_out.op],
        gamma_threshold=0.1,
        regularizer_decorator=dummy_decorator.DummyDecorator,
        decorator_parameters={'scale': 0.5})

    graph = tf.get_default_graph()
    conv_op = graph.get_operation_by_name('conv1/Conv2D')
    # Expected values:
    #   reg  = coeff * (inputs * (gamma sum 2 * scale 0.5)
    #                   + outputs * (input regularization 0))
    #   cost = coeff * inputs * outputs
    with self.cached_session():
      reg_value = self.gamma_flop_reg.get_regularization_term([conv_op]).eval()
      self.assertEqual(_coeff(conv_op) * NUM_CHANNELS * 1, reg_value)
      cost_value = self.gamma_flop_reg.get_cost([conv_op]).eval()
      self.assertEqual(_coeff(conv_op) * 2 * NUM_CHANNELS, cost_value)
Example #2
0
    def BuildModel(self):
        """Builds the two-branch, two-output test model and its regularizer.

        Sets self.conv3, self.conv4, self.name_to_var and self.gamma_flop_reg.
        """
        # Topology:
        #
        #         -> conv1 --+     -> conv3 -->
        #        /           |    /
        #  image          [concat]
        #        \           |    \
        #         -> conv2 --+     -> conv4 -->
        #
        # conv3 and conv4 are the two outputs of the model.
        inputs = tf.constant(0.0, shape=[1, 17, 19, NUM_CHANNELS])
        upper = layers.conv2d(inputs, 13, [7, 5], padding='SAME', scope='conv1')
        lower = layers.conv2d(inputs, 23, [1, 1], padding='SAME', scope='conv2')
        merged = tf.concat([upper, lower], 3)

        self.conv3 = layers.conv2d(
            merged, 29, [3, 3], stride=2, padding='SAME', scope='conv3')
        self.conv4 = layers.conv2d(
            merged, 31, [1, 1], stride=1, padding='SAME', scope='conv4')
        self.name_to_var = {var.op.name: var for var in tf.global_variables()}

        output_ops = [self.conv3.op, self.conv4.op]
        self.gamma_flop_reg = flop_regularizer.GammaFlopsRegularizer(
            output_ops, gamma_threshold=0.45)
Example #3
0
    def BuildModel(self):
        """Builds the single-output test model with depthwise branches.

        Sets self.conv3, self.name_to_var and self.gamma_flop_reg.
        """
        # Topology:
        #
        #         -> dw1 --> conv1 --+
        #        /                   |
        #  image                     [concat] --> conv3
        #        \                   |
        #         -> conv2 --> dw2 --+
        #
        # conv3 is the only output of the model.
        inputs = tf.constant(0.0, shape=[1, 17, 19, NUM_CHANNELS])

        # Upper branch: depthwise 3x3 followed by a 7x5 conv.
        upper_dw = layers.separable_conv2d(inputs,
                                           None, [3, 3],
                                           depth_multiplier=1,
                                           stride=1,
                                           scope='dw1')
        upper = layers.conv2d(upper_dw, 13, [7, 5], padding='SAME',
                              scope='conv1')

        # Lower branch: 1x1 conv followed by a depthwise 5x5.
        lower_conv = layers.conv2d(inputs, 23, [1, 1], padding='SAME',
                                   scope='conv2')
        lower = layers.separable_conv2d(lower_conv,
                                        None, [5, 5],
                                        depth_multiplier=1,
                                        stride=1,
                                        scope='dw2')

        merged = tf.concat([upper, lower], 3)
        self.conv3 = layers.conv2d(
            merged, 29, [3, 3], stride=2, padding='SAME', scope='conv3')
        self.name_to_var = {var.op.name: var for var in tf.global_variables()}

        self.gamma_flop_reg = flop_regularizer.GammaFlopsRegularizer(
            [self.conv3.op], gamma_threshold=0.45)
        def clone_fn(batch_queue):
            """Builds one replica of `network_fn` together with its losses.

            Dequeues a batch, runs `network_fn`, attaches a GammaFlopsRegularizer
            between the batch tensors and the logits, and creates the slim
            softmax cross-entropy losses. Their return values are discarded,
            so they are presumably collected by slim.

            Args:
              batch_queue: queue whose `dequeue()` yields (images, labels).

            Returns:
              (end_points, network_regularizer, regularizer_loss), where
              regularizer_loss is the regularization term times 7e-9.
            """
            batch_images, batch_labels = batch_queue.dequeue()
            net_logits, endpoints = network_fn(batch_images)

            regularizer = flop_regularizer.GammaFlopsRegularizer(
                output_boundary=[net_logits.op],
                input_boundary=[batch_images.op, batch_labels.op],
                gamma_threshold=1e-3)
            strength = 7e-9
            scaled_reg = regularizer.get_regularization_term() * strength

            # Loss function: auxiliary head (weight 0.4) if present, then the
            # main head (weight 1.0).
            if 'AuxLogits' in endpoints:
                slim.losses.softmax_cross_entropy(
                    endpoints['AuxLogits'],
                    batch_labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                net_logits,
                batch_labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return endpoints, regularizer, scaled_reg
Example #5
0
  def testInceptionV2_TotalCost(self):
    """Checks Inception-V2 total cost under four gamma-based regularizers."""
    batch_norm_params = {
        'is_training': False,
        'decay': 0.9997,
        'scale': True,
        'epsilon': 0.001,
    }
    conv_params = dict(
        activation_fn=tf.nn.relu6,
        weights_regularizer=contrib_layers.l2_regularizer(0.00004),
        weights_initializer=tf.random_normal_initializer(stddev=0.03),
        trainable=True,
        biases_initializer=tf.constant_initializer(0.0),
        normalizer_fn=contrib_layers.batch_norm,
        normalizer_params=batch_norm_params)

    tf.reset_default_graph()
    scoped_layers = [slim.layers.conv2d, slim.layers.separable_conv2d]
    with slim.arg_scope(scoped_layers, **conv_params):
      # Inception-V2 trunk on a single 224x224 RGB image plus a linear head.
      inputs = tf.zeros([1, 224, 224, 3])
      trunk, _ = inception.inception_v2_base(inputs)
      logits = slim.layers.fully_connected(
          trunk,
          1001,
          activation_fn=None,
          scope='logits',
          weights_initializer=tf.random_normal_initializer(stddev=1e-3),
          biases_initializer=tf.constant_initializer(0.0))

    flop_reg = flop_regularizer.GammaFlopsRegularizer(
        [logits.op], gamma_threshold=0.5)
    p100_reg = latency_regularizer.GammaLatencyRegularizer(
        [logits.op], gamma_threshold=0.5, hardware='P100')
    v100_reg = latency_regularizer.GammaLatencyRegularizer(
        [logits.op], gamma_threshold=0.5, hardware='V100')
    model_size_reg = model_size_regularizer.GammaModelSizeRegularizer(
        [logits.op], gamma_threshold=0.5)

    with self.cached_session():
      tf.global_variables_initializer().run()

    expectations = (
        (3.86972e+09, flop_reg),
        (517536.0, p100_reg),
        (173330.453125, v100_reg),
        (1.11684e+07, model_size_reg),
    )
    for expected_cost, regularizer in expectations:
      self.assertAllClose(expected_cost, regularizer.get_cost())
  def testLossDecorated(self):
    """Checks the regularization loss when op regularizers are decorated."""
    self.BuildWithBatchNorm(True)
    self.AddRegularizer()
    # Replace the plain regularizer with one whose op regularizers are wrapped
    # by DummyDecorator using scale 0.5.
    self.gamma_flop_reg = flop_regularizer.GammaFlopsRegularizer(
        [self.conv3.op, self.conv4.op],
        gamma_threshold=0.45,
        regularizer_decorator=dummy_decorator.DummyDecorator,
        decorator_parameters={'scale': 0.5})

    graph_ops = tf.get_default_graph().get_operations()
    conv_ops = [op for op in graph_ops if op.type == 'Conv2D']
    undecorated_reg_term = 1410376.375
    # Both an explicit list of every conv and an empty list must give the
    # undecorated term scaled by 0.5.
    for ops in (conv_ops, []):
      self.assertAllClose(undecorated_reg_term * 0.5, self.GetLoss(ops))
Example #7
0
  def testStructureExporter(self,
                            use_batch_norm,
                            remove_common_prefix=False,
                            export=False):
    """Tests the export of alive counts.

    Builds the test model, scales each layer's regularized variable by its
    ALIVE entry, and compares the exporter's alive counts with
    EXPECTED_COUNTS[remove_common_prefix].

    Args:
      use_batch_norm: A Boolean. Indicates if batch norm should be used. With
        batch norm, a GammaFlopsRegularizer over the BatchNorm gammas is used;
        without it, a GroupLassoFlopsRegularizer over the conv weights.
      remove_common_prefix: A Boolean, passed to StructureExporter ctor.
      export: A Boolean. Indicates if the alive counts should also be saved to
        a MockFile and read back as JSON.
    """
    sc = self._batch_norm_scope() if use_batch_norm else []
    with tf.contrib.framework.arg_scope(sc):
      with tf.variable_scope(tf.get_variable_scope()):
        final_op = _build_model()
    variables = {v.name: v for v in tf.trainable_variables()}
    if use_batch_norm:
      network_regularizer = flop_regularizer.GammaFlopsRegularizer(
          [final_op], gamma_threshold=1e-6)
      var_name_template = 'X/{}/BatchNorm/gamma:0'
    else:
      network_regularizer = flop_regularizer.GroupLassoFlopsRegularizer(
          [final_op], threshold=1e-6)
      var_name_template = 'X/{}/weights:0'

    # Scale each layer's regularized variable by its ALIVE entry (presumably
    # a 0/1 per-channel mask) so dead channels fall below the threshold.
    update_vars = []
    for layer in LAYERS:
      force_alive = ALIVE['X/{}/Conv2D'.format(layer)]
      var = variables[var_name_template.format(layer)]
      update_vars.append(var.assign(force_alive * var))

    structure_exporter = se.StructureExporter(
        network_regularizer.op_regularizer_manager, remove_common_prefix)
    with self.cached_session() as sess:
      tf.global_variables_initializer().run()
      sess.run(update_vars)
      structure_exporter.populate_tensor_values(
          sess.run(structure_exporter.tensors))
    expected = EXPECTED_COUNTS[remove_common_prefix]
    self.assertEqual(
        expected,
        structure_exporter.get_alive_counts())
    if export:
      f = MockFile()
      structure_exporter.save_alive_counts(f)
      self.assertEqual(expected, json.loads(f.read()))
Example #8
0
    def testCost(self):
        """Checks per-op and total FLOP costs of the two-unit ResNet v1 graph.

        unit_1/shortcut, unit_1/conv3 and unit_2/conv3 share the residual
        channels. A residual channel is counted as alive if its gamma exceeds
        the threshold in any of those three layers.
        """
        self.buildGraphWithBatchNorm(resnet_v1.resnet_v1,
                                     resnet_v1.resnet_v1_block)
        self.initGamma()
        res_alive = np.logical_or(
            np.logical_or(
                self.getGamma('unit_1/shortcut') > self._threshold,
                self.getGamma('unit_1/conv3') > self._threshold),
            self.getGamma('unit_2/conv3') > self._threshold)
        # Number of alive residual channels; used by several layers below.
        num_res_alive = np.sum(res_alive)

        self.gamma_flop_reg = flop_regularizer.GammaFlopsRegularizer(
            [self.net.op], self._threshold)

        expected = {}
        expected['unit_1/shortcut'] = (self.getCoeff('unit_1/shortcut') *
                                       num_res_alive * NUM_CHANNELS)
        expected['unit_1/conv1'] = (self.getCoeff('unit_1/conv1') *
                                    self.numAlive('unit_1/conv1') *
                                    NUM_CHANNELS)
        expected['unit_1/conv2'] = (self.getCoeff('unit_1/conv2') *
                                    self.numAlive('unit_1/conv2') *
                                    self.numAlive('unit_1/conv1'))
        expected['unit_1/conv3'] = (self.getCoeff('unit_1/conv3') *
                                    num_res_alive *
                                    self.numAlive('unit_1/conv2'))
        expected['unit_2/conv1'] = (self.getCoeff('unit_2/conv1') *
                                    self.numAlive('unit_2/conv1') *
                                    num_res_alive)
        expected['unit_2/conv2'] = (self.getCoeff('unit_2/conv2') *
                                    self.numAlive('unit_2/conv2') *
                                    self.numAlive('unit_2/conv1'))
        expected['unit_2/conv3'] = (self.getCoeff('unit_2/conv3') *
                                    num_res_alive *
                                    self.numAlive('unit_2/conv2'))
        expected['FC'] = 2.0 * num_res_alive * 23.0

        # TODO: Is there a way to use Parametrized Tests to make this more
        # elegant?
        # cached_session() replaces the deprecated test_session(), matching
        # the other tests.
        with self.cached_session():
            for short_name, expected_cost in expected.items():
                cost = self.gamma_flop_reg.get_cost(
                    [self.getOp(short_name)]).eval()
                self.assertEqual(expected_cost, cost)

            self.assertEqual(sum(expected.values()),
                             self.gamma_flop_reg.get_cost().eval())
Example #9
0
    def _build_train_op(self):
        """Builds a training op.

        The minimized objective is the mean Huber loss between the
        gradient-stopped target Q-values and the Q-values of the replayed
        actions, plus a GammaFlopsRegularizer term scaled by 1e-10. When a
        summary writer is set, loss and cost scalars are also recorded.

        Returns:
          train_op: An op performing one step of training from replay data.
        """
        replay_action_one_hot = tf.one_hot(self._replay.actions,
                                           self.num_actions,
                                           1.,
                                           0.,
                                           name='action_one_hot')
        # `axis` replaces the deprecated `reduction_indices` alias.
        replay_chosen_q = tf.reduce_sum(self._replay_net_outputs.q_values *
                                        replay_action_one_hot,
                                        axis=1,
                                        name='replay_chosen_q')

        target = tf.stop_gradient(self._build_target_q_op())

        self.network_regularizer = flop_regularizer.GammaFlopsRegularizer(
            output_boundary=[self._net_outputs.q_values.op],
            input_boundary=[self.state_ph.op, target.op],
            gamma_threshold=1e-3)
        regularization_strength = 1e-10
        regularizer_loss = (
            self.network_regularizer.get_regularization_term() *
            regularization_strength)

        loss = tf.losses.huber_loss(target,
                                    replay_chosen_q,
                                    reduction=tf.losses.Reduction.NONE)
        if self.summary_writer is not None:
            with tf.variable_scope('Losses'):
                tf.summary.scalar('HuberLoss', tf.reduce_mean(loss))
                tf.summary.scalar('RegularizationLoss', regularizer_loss)
                tf.summary.scalar(self.network_regularizer.cost_name,
                                  self.network_regularizer.get_cost())
        return self.optimizer.minimize(tf.reduce_mean(loss) + regularizer_loss)
Example #10
0
 def AddRegularizer(self, input_boundary=None):
     """Sets self.gamma_flop_reg to a GammaFlopsRegularizer on conv3/conv4.

     Uses gamma threshold 0.45; `input_boundary` is forwarded unchanged.
     """
     output_ops = [self.conv3.op, self.conv4.op]
     regularizer = flop_regularizer.GammaFlopsRegularizer(
         output_ops, gamma_threshold=0.45, input_boundary=input_boundary)
     self.gamma_flop_reg = regularizer
Example #11
0
    def test_simple_conv3d(self, threshold, expected_alive):
        """Checks FLOP cost/regularization of a Conv3D followed by batch norm.

        Args:
          threshold: gamma threshold passed to GammaFlopsRegularizer.
          expected_alive: how many of the 3 output channels (gammas fixed to
            [0.5, 0.3, 0.2] below) are expected to be alive at `threshold`.
        """
        # TODO(e1) remove when gamma is supported.
        # This test works if reshape is not set to be handled by
        # leaf_op_handler.LeafOpHandler() in op_handlers.py. However, this
        # change breaks other tests for reasons still to be investigated.
        if SKIP_GAMMA_CONV3D:
            # Report a skip instead of returning early, which would make the
            # test silently pass without checking anything.
            self.skipTest('SKIP_GAMMA_CONV3D is set; gamma-regularized Conv3D '
                          'is not supported yet.')

        def fused_batch_norm3d(*args, **kwargs):
            """Applies slim.batch_norm to a 5-D tensor through a 4-D reshape.

            The reshape ops are named 'Reshape/to2d' and 'Reshape/to3d' so
            they can be grouped with the conv via `force_group` below.
            """
            if args:
                inputs = args[0]
                args = args[1:]
            else:
                inputs = kwargs.pop('inputs')
            shape = inputs.shape
            # inputs is assumed to be NHWTC (T is for time).
            batch_size = shape[0]
            # The space x time cube is reshaped to be 2D with dims:
            # H, W, T --> H, W * T
            # Batch norm only needs this to collect spatial stats.
            target_shape = [
                batch_size, shape[1], shape[2] * shape[3], shape[4]
            ]
            inputs = tf.reshape(inputs, target_shape, name='Reshape/to2d')
            normalized = slim.batch_norm(inputs, *args, **kwargs)
            return tf.reshape(normalized, shape, name='Reshape/to3d')

        gamma_val = [0.5, 0.3, 0.2]
        num_inputs = 4
        batch_size = 2
        video = tf.zeros([batch_size, 8, 8, 8, num_inputs])
        kernel = [5, 5, 5]
        num_outputs = 3
        net = slim.conv3d(video,
                          num_outputs,
                          kernel,
                          padding='SAME',
                          normalizer_fn=fused_batch_norm3d,
                          normalizer_params={
                              'scale': True,
                              'fused': True
                          },
                          scope='vconv1')
        self.assertLen(net.shape.as_list(), 5)
        shape = net.shape.as_list()
        # The number of applications is the number of elements in the [HWT] tensor.
        num_applications = shape[1] * shape[2] * shape[3]
        application_cost = num_inputs * kernel[0] * kernel[1] * kernel[2]
        name_to_var = {v.op.name: v for v in tf.global_variables()}
        flop_reg = flop_regularizer.GammaFlopsRegularizer(
            [
                net.op,
                tf.get_default_graph().get_operation_by_name('vconv1/Conv3D')
            ],
            threshold,
            force_group=[
                'vconv1/Reshape/to3d|vconv1/Reshape/to2d|vconv1/Conv3D'
            ])
        gamma = name_to_var['vconv1/BatchNorm/gamma']
        with self.session():
            tf.global_variables_initializer().run()
            gamma.assign(gamma_val).eval()

            # Cost counts only alive output channels.
            self.assertAllClose(
                flop_reg.get_cost(),
                2 * expected_alive * num_applications * application_cost)

            # Regularization scales the all-channel cost by the mean gamma.
            raw_cost = 2 * num_outputs * num_applications * application_cost
            self.assertAllClose(flop_reg.get_regularization_term(),
                                raw_cost * np.mean(gamma_val))
Example #12
0
def _build_network_regularizer(args, logits, X_ph, y_ph):
    """Builds the network regularizer selected by `args.reg_type`.

    Args:
      args: parsed arguments; reads reg_type, gamma_threshold and, for the
        latency regularizer, hardware.
      logits: model output tensor; its op is the output boundary.
      X_ph: image placeholder; its op is part of the input boundary.
      y_ph: label placeholder; its op is part of the input boundary.

    Returns:
      A GammaActivationRegularizer, GammaFlopsRegularizer or
      GammaLatencyRegularizer.

    Raises:
      ValueError: if `args.reg_type` is not "activation", "flop" or "latency".
    """
    common_kwargs = dict(output_boundary=[logits.op],
                         input_boundary=[X_ph.op, y_ph.op],
                         gamma_threshold=args.gamma_threshold)
    if args.reg_type == "activation":
        return activation_regularizer.GammaActivationRegularizer(
            **common_kwargs)
    if args.reg_type == "flop":
        return flop_regularizer.GammaFlopsRegularizer(**common_kwargs)
    if args.reg_type == "latency":
        return latency_regularizer.GammaLatencyRegularizer(
            hardware=args.hardware, **common_kwargs)
    # Previously an unknown value left the regularizer unbound (NameError).
    raise ValueError(
        "Unsupported reg_type %r; expected 'activation', 'flop' or "
        "'latency'." % (args.reg_type,))


def main(args):
    """Trains the MNIST model with a gamma-based network regularizer.

    Each iteration evaluates loss, accuracy and regularizer cost on a random
    5000-example test subset, then runs one Adam step on a random training
    batch. Every 1000 steps the alive counts are saved under `args.outdir`.

    Args:
      args: parsed arguments; reads reg_type, gamma_threshold, hardware
        (latency only), reg_penalty, lr, steps, outdir and batch_size.

    Raises:
      ValueError: if `args.reg_type` is not "activation", "flop" or "latency".
    """
    num_val_examples = 5000
    export_every_n_steps = 1000

    # Load MNIST Data
    train_data, test_data = tf.keras.datasets.mnist.load_data()
    X_train, y_train = train_data[0], train_data[1]
    X_test, y_test = test_data[0], test_data[1]
    global_step = tf.train.get_or_create_global_step()

    N, H, W = X_train.shape
    X_ph = tf.placeholder(tf.float32, [None, H, W, 1])
    y_ph = tf.placeholder(tf.int64, [None])

    # Defining Model
    logits, pred = mnist_model(X_ph, scope='base')
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=y_ph,
                                                     logits=logits)
    acc_op = tf.reduce_mean(tf.cast(tf.equal(pred, y_ph), tf.float32))

    # Setting Regularizer and Loss Ops
    network_regularizer = _build_network_regularizer(args, logits, X_ph, y_ph)
    reg_loss_op = network_regularizer.get_regularization_term(
    ) * args.reg_penalty
    cost_op = network_regularizer.get_cost()
    exporter = structure_exporter.StructureExporter(
        network_regularizer.op_regularizer_manager)

    optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)
    train_op = optimizer.minimize(loss_op + reg_loss_op,
                                  global_step=global_step)

    hooks = [
        tf.train.StopAtStepHook(last_step=args.steps + 1),
        tf.train.LoggingTensorHook(tensors={
            'step': global_step,
            'loss': loss_op
        },
                                   every_n_iter=10)
    ]
    # Let TensorFlow grow GPU memory on demand instead of reserving it all.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Training Loop
    with tf.train.MonitoredTrainingSession(checkpoint_dir=args.outdir,
                                           hooks=hooks,
                                           config=config) as mon_sess:
        while not mon_sess.should_stop():
            idx = np.random.choice(N, args.batch_size, replace=False)
            x_t, y_t = np.expand_dims(X_train[idx], axis=-1), y_train[idx]
            train_dict = {X_ph: x_t, y_ph: y_t}

            val_idx = np.random.choice(X_test.shape[0], num_val_examples,
                                       replace=False)
            x_v, y_v = np.expand_dims(X_test[val_idx],
                                      axis=-1), y_test[val_idx]
            val_dict = {X_ph: x_v, y_ph: y_v}

            global_step_val = mon_sess.run(global_step, feed_dict=train_dict)
            export_now = global_step_val % export_every_n_steps == 0

            # Fetch the exporter tensors only when they will be exported; they
            # are still read before the training step, as before.
            fetches = [loss_op, acc_op, cost_op]
            if export_now:
                fetches.append(exporter.tensors)
            results = mon_sess.run(fetches, feed_dict=val_dict)
            v_loss, v_acc, reg_cost = results[:3]
            mon_sess.run(train_op, feed_dict=train_dict)

            print("Step: ", global_step_val)
            print("Validation Loss: ", v_loss)
            print("Validation Acc: ", v_acc)
            print("Reg Cost: ", reg_cost)

            # exporting model to JSON
            if export_now:
                exporter.populate_tensor_values(results[3])
                exporter.create_file_and_save_alive_counts(
                    args.outdir, global_step_val)
Example #13
0
  def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    When `train_config.flop_loss` is set (TRAIN and EVAL), a
    GammaFlopsRegularizer is attached to the op producing
    `prediction_dict['class_predictions_with_background']`. Its regularization
    term, scaled by 1e-9, is added to the total loss. In non-TPU training its
    FLOP cost is also written as a 'FLOPs' summary.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations.
    """
    params = params or {}
    total_loss, train_op, detections, export_outputs = None, None, None, None
    # Created only when train_config.flop_loss is set. It is initialized here
    # so the TRAIN summary code below never hits an unbound name.
    network_reg = None
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Make sure to set the Keras learning phase. True during training,
    # False for inference.
    tf.keras.backend.set_learning_phase(is_training)
    detection_model = detection_model_fn(is_training=is_training,
                                         add_summaries=(not use_tpu))
    scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
      labels = unstack_batch(
          labels,
          unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
    elif mode == tf.estimator.ModeKeys.EVAL:
      # For evaling on train data, it is necessary to check whether groundtruth
      # must be unpadded.
      boxes_shape = (
          labels[fields.InputDataFields.groundtruth_boxes].get_shape()
          .as_list())
      unpad_groundtruth_tensors = boxes_shape[1] is not None
      labels = unstack_batch(
          labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
      gt_classes_list = labels[fields.InputDataFields.groundtruth_classes]
      gt_masks_list = None
      if fields.InputDataFields.groundtruth_instance_masks in labels:
        gt_masks_list = labels[
            fields.InputDataFields.groundtruth_instance_masks]
      gt_keypoints_list = None
      if fields.InputDataFields.groundtruth_keypoints in labels:
        gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
      gt_weights_list = None
      if fields.InputDataFields.groundtruth_weights in labels:
        gt_weights_list = labels[fields.InputDataFields.groundtruth_weights]
      # Default to None like the other optional fields. Without this, the
      # call below raises UnboundLocalError when is_crowd is absent.
      gt_is_crowd_list = None
      if fields.InputDataFields.groundtruth_is_crowd in labels:
        gt_is_crowd_list = labels[fields.InputDataFields.groundtruth_is_crowd]
      detection_model.provide_groundtruth(
          groundtruth_boxes_list=gt_boxes_list,
          groundtruth_classes_list=gt_classes_list,
          groundtruth_masks_list=gt_masks_list,
          groundtruth_keypoints_list=gt_keypoints_list,
          groundtruth_weights_list=gt_weights_list,
          groundtruth_is_crowd_list=gt_is_crowd_list)

    preprocessed_images = features[fields.InputDataFields.image]
    prediction_dict = detection_model.predict(
        preprocessed_images, features[fields.InputDataFields.true_image_shape])
    if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
      detections = detection_model.postprocess(
          prediction_dict, features[fields.InputDataFields.true_image_shape])

    if mode == tf.estimator.ModeKeys.TRAIN:
      if train_config.fine_tune_checkpoint and hparams.load_pretrained:
        if not train_config.fine_tune_checkpoint_type:
          # train_config.from_detection_checkpoint field is deprecated. For
          # backward compatibility, set train_config.fine_tune_checkpoint_type
          # based on train_config.from_detection_checkpoint.
          if train_config.from_detection_checkpoint:
            train_config.fine_tune_checkpoint_type = 'detection'
          else:
            train_config.fine_tune_checkpoint_type = 'classification'
        asg_map = detection_model.restore_map(
            fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
            load_all_detection_checkpoint_vars=(
                train_config.load_all_detection_checkpoint_vars))
        available_var_map = (
            variables_helper.get_variables_available_in_checkpoint(
                asg_map, train_config.fine_tune_checkpoint,
                include_global_step=False))
        if use_tpu:
          def tpu_scaffold():
            tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                          available_var_map)
            return tf.train.Scaffold()
          scaffold_fn = tpu_scaffold
        else:
          tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                        available_var_map)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      losses_dict = detection_model.loss(
          prediction_dict, features[fields.InputDataFields.true_image_shape])
      losses = list(losses_dict.values())
      if train_config.add_regularization_loss:
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
        if regularization_losses:
          regularization_loss = tf.add_n(regularization_losses,
                                         name='regularization_loss')
          losses.append(regularization_loss)
          losses_dict['Loss/regularization_loss'] = regularization_loss
      if train_config.flop_loss:
        gamma_thresh = 1e-3
        reg_strength = 1e-9
        network_reg = flop_regularizer.GammaFlopsRegularizer(
            [prediction_dict['class_predictions_with_background'].op],
            gamma_thresh)
        flop_loss = reg_strength * network_reg.get_regularization_term()
        # Add the penalty to the optimized objective as well as the reported
        # losses. Previously it was only reported and never trained on.
        losses.append(flop_loss)
        losses_dict['Loss/flop_loss'] = flop_loss
      total_loss = tf.add_n(losses, name='total_loss')
      losses_dict['Loss/total_loss'] = total_loss

      if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=is_training)
        graph_rewriter_fn()

      # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
      # can write learning rate summaries on TPU without host calls.
      global_step = tf.train.get_or_create_global_step()
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
            training_optimizer)

      # Optionally freeze some layers by setting their gradients to be zero.
      include_variables = train_config.update_trainable_variables or None
      exclude_variables = train_config.freeze_variables or None
      trainable_variables = tf.contrib.framework.filter_variables(
          tf.trainable_variables(),
          include_patterns=include_variables,
          exclude_patterns=exclude_variables)

      clip_gradients_value = None
      if train_config.gradient_clipping_by_norm > 0:
        clip_gradients_value = train_config.gradient_clipping_by_norm

      if not use_tpu:
        for var in optimizer_summary_vars:
          tf.summary.scalar(var.op.name, var)
        if network_reg is not None:
          tf.summary.scalar('FLOPs', network_reg.get_cost())
        else:
          tf.logging.info('not adding FLOPS to summaries')
      summaries = [] if use_tpu else None
      train_op = tf.contrib.layers.optimize_loss(
          loss=total_loss,
          global_step=global_step,
          learning_rate=None,
          clip_gradients=clip_gradients_value,
          optimizer=training_optimizer,
          variables=trainable_variables,
          summaries=summaries,
          name='')  # Preventing scope prefix on all variables.

    if mode == tf.estimator.ModeKeys.PREDICT:
      export_outputs = {
          tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
              tf.estimator.export.PredictOutput(detections)
      }

    eval_metric_ops = None
    scaffold = None
    if mode == tf.estimator.ModeKeys.EVAL:
      class_agnostic = (fields.DetectionResultFields.detection_classes
                        not in detections)
      groundtruth = _prepare_groundtruth_for_eval(
          detection_model, class_agnostic)
      use_original_images = fields.InputDataFields.original_image in features
      eval_images = (
          features[fields.InputDataFields.original_image] if use_original_images
          else features[fields.InputDataFields.image])
      eval_dict = eval_util.result_dict_for_single_example(
          eval_images[0:1],
          features[inputs.HASH_KEY][0],
          detections,
          groundtruth,
          class_agnostic=class_agnostic,
          scale_to_absolute=True)

      if class_agnostic:
        category_index = label_map_util.create_class_agnostic_category_index()
      else:
        category_index = label_map_util.create_category_index_from_labelmap(
            eval_input_config.label_map_path)
      img_summary = None
      if not use_tpu and use_original_images:
        detection_and_groundtruth = (
            vis_utils.draw_side_by_side_evaluation_image(
                eval_dict,
                category_index,
                max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
                min_score_thresh=eval_config.min_score_threshold,
                use_normalized_coordinates=False))
        img_summary = tf.summary.image('Detections_Left_Groundtruth_Right',
                                       detection_and_groundtruth)

      # Eval metrics on a single example.
      eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
          eval_config,
          list(category_index.values()),
          eval_dict)
      for loss_key, loss_tensor in iter(losses_dict.items()):
        eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
      for var in optimizer_summary_vars:
        eval_metric_ops[var.op.name] = (var, tf.no_op())
      if img_summary is not None:
        eval_metric_ops['Detections_Left_Groundtruth_Right'] = (
            img_summary, tf.no_op())
      eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

      if eval_config.use_moving_averages:
        variable_averages = tf.train.ExponentialMovingAverage(0.0)
        variables_to_restore = variable_averages.variables_to_restore()
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            variables_to_restore,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
        scaffold = tf.train.Scaffold(saver=saver)

    # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
    if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          scaffold_fn=scaffold_fn,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metrics=eval_metric_ops,
          export_outputs=export_outputs)
    else:
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs,
          scaffold=scaffold)
def resnet_model_fn(features, labels, mode, model_class,
                    resnet_size, weight_decay, learning_rate_fn, momentum,
                    data_format, resnet_version, loss_scale,
                    loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE,
                    fine_tune=False, label_smoothing=0.0,
                    regularization_strength=1e-2, gamma_threshold=1e-3):
  """Shared functionality for different resnet model_fns.

  Initializes the ResnetModel representing the model layers
  and uses that model to build the necessary EstimatorSpecs for
  the `mode` in question. For training, this means building losses,
  the optimizer, and the train op that get passed into the EstimatorSpec.
  For evaluation and prediction, the EstimatorSpec is returned without
  a train op, but with the necessary parameters for the given mode.

  Besides cross entropy and L2 weight decay, the TRAIN/EVAL loss includes a
  MorphNet `GammaFlopsRegularizer` term computed over the network between
  `features` and the logits, multiplied by `regularization_strength`.

  Args:
    features: tensor representing input images
    labels: tensor representing class labels for all input images. May be
      None (the estimator passes None in `PREDICT` mode).
    mode: current estimator mode; should be one of
      `tf.estimator.ModeKeys.TRAIN`, `EVAL`, `PREDICT`
    model_class: a class representing a TensorFlow model that has a __call__
      function. We assume here that this is a subclass of ResnetModel.
    resnet_size: A single integer for the size of the ResNet model.
    weight_decay: weight decay loss rate used to regularize learned variables.
    learning_rate_fn: function that returns the current learning rate given
      the current global_step
    momentum: momentum term used for optimization
    data_format: Input format ('channels_last', 'channels_first', or None).
      If set to None, the format is dependent on whether a GPU is available.
    resnet_version: Integer representing which version of the ResNet network to
      use. See README for details. Valid values: [1, 2]
    loss_scale: The factor to scale the loss for numerical stability. A detailed
      summary is present in the arg parser help text.
    loss_filter_fn: function that takes a string variable name and returns
      True if the var should be included in loss calculation, and False
      otherwise. If None, batch_normalization variables will be excluded
      from the loss.
    dtype: the TensorFlow dtype to use for calculations.
    fine_tune: If True only train the dense layers(final layers).
    label_smoothing: If greater than 0 then smooth the labels. The one-hot
      depth is taken from the last dimension of the logits.
    regularization_strength: multiplier applied to the MorphNet FLOP
      regularization term before it is added to the loss.
    gamma_threshold: value passed as `gamma_threshold` to
      `flop_regularizer.GammaFlopsRegularizer`.

  Returns:
    EstimatorSpec parameterized according to the input params and the
    current mode.
  """

  # Generate a summary node for the images
  tf.compat.v1.summary.image('images', features, max_outputs=6)
  # Checks that features/images have same data type being used for calculations.
  assert features.dtype == dtype

  model = model_class(resnet_size, data_format, resnet_version=resnet_version,
                      dtype=dtype)

  logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)

  # Use the MorphNet FLOP regularizer on the un-cast logits. `labels` is None
  # when the estimator predicts, so it is only added to the input boundary
  # when present; dereferencing `.op` on None would raise AttributeError.
  input_boundary = [features.op]
  if labels is not None:
    input_boundary.append(labels.op)
  network_regularizer = flop_regularizer.GammaFlopsRegularizer(
      output_boundary=[logits.op],
      input_boundary=input_boundary,
      gamma_threshold=gamma_threshold)
  regularizer_loss = (
      network_regularizer.get_regularization_term() * regularization_strength)

  # This acts as a no-op if the logits are already in fp32 (provided logits are
  # not a SparseTensor). If dtype is is low precision, logits must be cast to
  # fp32 for numerical stability.
  logits = tf.cast(logits, tf.float32)

  predictions = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Return the predictions and the specification for serving a SavedModel
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'predict': tf.estimator.export.PredictOutput(predictions)
        })

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  if label_smoothing != 0.0:
    # Take the one-hot depth from the logits rather than assuming a fixed
    # 1001-class head, so label smoothing works for any number of classes.
    num_classes = tf.compat.dimension_value(logits.shape[-1])
    if num_classes is None:
      num_classes = tf.shape(input=logits)[-1]
    one_hot_labels = tf.one_hot(labels, num_classes)
    cross_entropy = tf.compat.v1.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=one_hot_labels,
        label_smoothing=label_smoothing)
  else:
    cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
        logits=logits, labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.compat.v1.summary.scalar('cross_entropy', cross_entropy)

  # If no loss_filter_fn is passed, assume we want the default behavior,
  # which is that batch_normalization variables are excluded from loss.
  def exclude_batch_norm(name):
    return 'batch_normalization' not in name
  loss_filter_fn = loss_filter_fn or exclude_batch_norm

  # Add weight decay to the loss.
  l2_loss = weight_decay * tf.add_n(
      # loss is computed using fp32 for numerical stability.
      [
          tf.nn.l2_loss(tf.cast(v, tf.float32))
          for v in tf.compat.v1.trainable_variables()
          if loss_filter_fn(v.name)
      ])
  tf.compat.v1.summary.scalar('l2_loss', l2_loss)

  loss = cross_entropy + l2_loss + regularizer_loss

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.compat.v1.train.get_or_create_global_step()

    learning_rate = learning_rate_fn(global_step)

    # Create a tensor named learning_rate for logging purposes
    tf.identity(learning_rate, name='learning_rate')
    tf.compat.v1.summary.scalar('learning_rate', learning_rate)

    if flags.FLAGS.enable_lars:
      optimizer = tf.contrib.opt.LARSOptimizer(
          learning_rate,
          momentum=momentum,
          weight_decay=weight_decay,
          skip_list=['batch_normalization', 'bias'])
    else:
      optimizer = tf.compat.v1.train.MomentumOptimizer(
          learning_rate=learning_rate,
          momentum=momentum
      )

    fp16_implementation = getattr(flags.FLAGS, 'fp16_implementation', None)
    if fp16_implementation == 'graph_rewrite':
      optimizer = (
          tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
              optimizer, loss_scale=loss_scale))

    def _dense_grad_filter(gvs):
      """Only apply gradient updates to the final layer.

      This function is used for fine tuning.

      Args:
        gvs: list of tuples with gradients and variable info
      Returns:
        filtered gradients so that only the dense layer remains
      """
      return [(g, v) for g, v in gvs if 'dense' in v.name]

    if loss_scale != 1 and fp16_implementation != 'graph_rewrite':
      # When computing fp16 gradients, often intermediate tensor values are
      # so small, they underflow to 0. To avoid this, we multiply the loss by
      # loss_scale to make these tensor values loss_scale times bigger.
      scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)

      if fine_tune:
        scaled_grad_vars = _dense_grad_filter(scaled_grad_vars)

      # Once the gradient computation is complete we can scale the gradients
      # back to the correct scale before passing them to the optimizer.
      unscaled_grad_vars = [(grad / loss_scale, var)
                            for grad, var in scaled_grad_vars]
      minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
    else:
      grad_vars = optimizer.compute_gradients(loss)
      if fine_tune:
        grad_vars = _dense_grad_filter(grad_vars)
      minimize_op = optimizer.apply_gradients(grad_vars, global_step)

    # Run collected update ops (e.g. batch-norm moving statistics) with the
    # optimizer step.
    update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
    train_op = tf.group(minimize_op, update_ops)
  else:
    train_op = None

  accuracy = tf.compat.v1.metrics.accuracy(labels, predictions['classes'])
  accuracy_top_5 = tf.compat.v1.metrics.mean(
      tf.nn.in_top_k(predictions=logits, targets=labels, k=5, name='top_5_op'))
  metrics = {'accuracy': accuracy,
             'accuracy_top_5': accuracy_top_5}

  # Create a tensor named train_accuracy for logging purposes
  tf.identity(accuracy[1], name='train_accuracy')
  tf.identity(accuracy_top_5[1], name='train_accuracy_top_5')
  tf.compat.v1.summary.scalar('train_accuracy', accuracy[1])
  tf.compat.v1.summary.scalar('train_accuracy_top_5', accuracy_top_5[1])
  tf.compat.v1.summary.scalar('RegularizationLoss', regularizer_loss)
  tf.compat.v1.summary.scalar(network_regularizer.cost_name,
                              network_regularizer.get_cost())

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics)