Example #1
    def testLossCostDecorated(self):
        image = tf.constant(0.0, shape=[1, 3, 3, 3])
        kernel = tf.ones([1, 1, 3, 2])

        pred = tf.nn.conv2d(image,
                            kernel,
                            strides=[1, 1, 1, 1],
                            padding='SAME')
        conv = pred.op

        self.group_lasso_reg = flop_regularizer.GroupLassoFlopsRegularizer(
            [conv],
            0.1,
            l1_fraction=0,
            regularizer_decorator=dummy_decorator.DummyDecorator,
            decorator_parameters={'scale': 0.5})
        # We compare the computed cost and regularization term, calculated as follows:
        # reg_term = op_coeff * (number_of_inputs * (regularization=2 * 0.5) +
        # number_of_outputs * (input_regularization=0))
        # number_of_flops = coeff * number_of_inputs * number_of_outputs.
        with self.cached_session():
            pred_reg = self.group_lasso_reg.get_regularization_term([conv
                                                                     ]).eval()
            self.assertEqual(_coeff(conv) * 3 * 1, pred_reg)
            pred_cost = self.group_lasso_reg.get_cost([conv]).eval()
            self.assertEqual(_coeff(conv) * 2 * NUM_CHANNELS, pred_cost)
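A quick back-of-the-envelope check of the two assertions above, following the formulas in the comments; this is only a sketch, and it assumes (as the comments state) that _coeff(conv) is the per-(input, output)-pair FLOP coefficient of the conv op and that the DummyDecorator scales each regularization term by 0.5:

# Hypothetical arithmetic check (sketch only; mirrors the comments above).
num_inputs = 3            # input channels of the 1x1 conv (kernel [1, 1, 3, 2])
num_outputs = 2           # output channels
decorated_reg = 2 * 0.5   # the comment's "(regularization=2 * 0.5)"
reg_factor = num_inputs * decorated_reg + num_outputs * 0   # -> 3, times _coeff(conv)
cost_factor = num_inputs * num_outputs                      # -> 6, times _coeff(conv)

Example #2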
  def testFlopRegularizerWithMatMul(self):
    """Test the MatMul op regularizer with FLOP network regularizer.

    Set up a two-layer fully connected network.
    """
    tf.reset_default_graph()
    tf.set_random_seed(1234)
    # Create the variables, and corresponding values.
    x = tf.constant(1.0, shape=[2, 6], name='x', dtype=tf.float32)
    w = tf.get_variable('w', shape=(6, 4), dtype=tf.float32)
    b = tf.get_variable('b', shape=(4), dtype=tf.float32)
    w2 = tf.get_variable('w2', shape=(4, 1), dtype=tf.float32)
    b2 = tf.get_variable('b2', shape=(1), dtype=tf.float32)
    w_value = np.arange(24).reshape((6, 4)).astype('float32')
    b_value = np.arange(4).reshape(4).astype('float32')
    w2_value = np.arange(21, 25).reshape((4, 1)).astype('float32')
    b2_value = np.arange(1).astype('float32')
    # Build the test network model.
    net = tf.nn.relu(tf.matmul(x, w, name='matmul1') + b)
    output = tf.nn.relu(tf.matmul(net, w2, name='matmul2') + b2)
    # Assign values to network parameters.
    with self.cached_session() as session:
      session.run([
          w.assign(w_value),
          b.assign(b_value),
          w2.assign(w2_value),
          b2.assign(b2_value)
      ])
    # Create FLOPs network regularizer.
    threshold = 32.0
    flop_reg = flop_regularizer.GroupLassoFlopsRegularizer([output.op],
                                                           threshold, 0)

    # Compute expected regularization vector and alive vector.
    def group_norm(weights, axis=(0, 1, 2)):  # pylint: disable=invalid-name
      return np.sqrt(np.mean(weights**2, axis=axis))
    expected_reg_vector1 = group_norm(w_value, axis=(0,))
    expected_reg_vector2 = group_norm(w2_value, axis=(0,))
    # The threshold is 32 and the L2 norms of the columns of w are
    # (29.66479301, 31.71750259, 33.82307053, 35.97220993), so the alive vector
    # for w should be (0, 0, 1, 1). The alive vector for w2 is [1] since the L2
    # norm of w2_value is 45.055521 > 32.
    # Compute the expected FLOPs cost and expected regularization term.
    matmul1_live_input = 6
    matmul1_live_output = sum(expected_reg_vector1 > threshold)
    matmul2_live_output = sum(expected_reg_vector2 > threshold)
    expected_flop_cost = (
        _coeff(_get_op('matmul1')) * matmul1_live_input * matmul1_live_output +
        _coeff(_get_op('matmul2')) * matmul1_live_output * matmul2_live_output)
    regularizer1 = np.sum(expected_reg_vector1)
    regularizer2 = np.sum(expected_reg_vector2)
    expected_reg_term = (
        _coeff(_get_op('matmul1')) * matmul1_live_input * regularizer1 +
        _coeff(_get_op('matmul2')) * (matmul1_live_output * regularizer2 +
                                      matmul2_live_output * regularizer1))
    with self.cached_session() as session:
      self.assertEqual(
          round(flop_reg.get_cost().eval()), round(expected_flop_cost))
      self.assertNearRelatively(flop_reg.get_regularization_term().eval(),
                                expected_reg_term)
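The alive-vector reasoning in the comment can be checked directly. Note a subtlety: the group_norm helper above averages the squares over the rows (a root-mean-square), while the numbers quoted in the comment are plain column L2 norms (root-sum-square), which differ by a factor of sqrt(6); the sketch below reproduces the quoted values:

# Reproduces the per-column L2 norms quoted in the comment above (sketch only).
import numpy as np

w_value = np.arange(24).reshape((6, 4)).astype('float32')
w2_value = np.arange(21, 25).reshape((4, 1)).astype('float32')
print(np.sqrt(np.sum(w_value**2, axis=0)))   # [29.66  31.72  33.82  35.97]
print(np.sqrt(np.sum(w2_value**2, axis=0)))  # [45.06]

Example #3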
  def __init__(self, ops, regularizer_strength=1e-3, **kwargs):
    self._network_regularizer = flop_regularizer.GroupLassoFlopsRegularizer(
        ops, **kwargs)
    self._regularization_strength = regularizer_strength
    self._regularizer_loss = (
        self._network_regularizer.get_regularization_term() *
        self._regularization_strength)
    self._sess = K.get_session()
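This constructor is part of a small wrapper that exposes the MorphNet regularization term as an extra loss tensor through the Keras session. A minimal usage sketch, assuming a hypothetical MorphNetLoss class built around exactly this __init__ (the class name and the helper methods are illustrative, not part of the MorphNet API):

# Hypothetical wrapper class assembled around the __init__ above (sketch only).
from tensorflow.keras import backend as K
from morph_net.network_regularizers import flop_regularizer


class MorphNetLoss(object):

  def __init__(self, ops, regularizer_strength=1e-3, **kwargs):
    self._network_regularizer = flop_regularizer.GroupLassoFlopsRegularizer(
        ops, **kwargs)
    self._regularization_strength = regularizer_strength
    self._regularizer_loss = (
        self._network_regularizer.get_regularization_term() *
        self._regularization_strength)
    self._sess = K.get_session()

  @property
  def loss(self):
    # Tensor to add to the model's training loss.
    return self._regularizer_loss

  def current_cost(self):
    # Current FLOP cost, evaluated in the Keras session.
    return self._sess.run(self._network_regularizer.get_cost())


# e.g. morphnet_loss = MorphNetLoss([logits.op], 1e-5, threshold=1e-2)

Example #4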
  def testFlopRegularizerDontConvertToVariable(self):
    tf.reset_default_graph()
    tf.set_random_seed(1234)

    x = tf.constant(1.0, shape=[2, 6], name='x', dtype=tf.float32)
    w = tf.Variable(tf.truncated_normal([6, 4], stddev=1.0), use_resource=True)
    net = tf.matmul(x, w)

    # Create FLOPs network regularizer.
    threshold = 0.9
    flop_reg = flop_regularizer.GroupLassoFlopsRegularizer(
        [net.op], threshold, 0, convert_to_variable=False)

    with self.cached_session():
      tf.global_variables_initializer().run()
      flop_reg.get_regularization_term().eval()
Example #5
  def testStructureExporter(self,
                            use_batch_norm,
                            remove_common_prefix=False,
                            export=False):
    """Tests the export of alive counts.

    Args:
      use_batch_norm: A Boolean. Indicates if batch norm should be used.
      remove_common_prefix: A Boolean, passed to StructureExporter ctor.
      export: A Boolean. Indicates if the result should be exported to test dir.
    """
    sc = self._batch_norm_scope() if use_batch_norm else []
    with tf.contrib.framework.arg_scope(sc):
      with tf.variable_scope(tf.get_variable_scope()):
        final_op = _build_model()
    variables = {v.name: v for v in tf.trainable_variables()}
    update_vars = []
    if use_batch_norm:
      network_regularizer = flop_regularizer.GammaFlopsRegularizer(
          [final_op], gamma_threshold=1e-6)
      for layer in LAYERS:
        force_alive = ALIVE['X/{}/Conv2D'.format(layer)]
        gamma = variables['X/{}/BatchNorm/gamma:0'.format(layer)]
        update_vars.append(gamma.assign(force_alive * gamma))
    else:
      network_regularizer = flop_regularizer.GroupLassoFlopsRegularizer(
          [final_op], threshold=1e-6)
      print(variables)
      for layer in LAYERS:
        force_alive = ALIVE['X/{}/Conv2D'.format(layer)]
        weights = variables['X/{}/weights:0'.format(layer)]
        update_vars.append(weights.assign(force_alive * weights))
    structure_exporter = se.StructureExporter(
        network_regularizer.op_regularizer_manager, remove_common_prefix)
    with self.cached_session() as sess:
      tf.global_variables_initializer().run()
      sess.run(update_vars)
      structure_exporter.populate_tensor_values(
          sess.run(structure_exporter.tensors))
    expected = EXPECTED_COUNTS[remove_common_prefix]
    self.assertEqual(
        expected,
        structure_exporter.get_alive_counts())
    if export:
      f = MockFile()
      structure_exporter.save_alive_counts(f)
      self.assertEqual(expected, json.loads(f.read()))
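Outside of a unit test, the same exporter calls can write the alive counts to an actual JSON file. A minimal sketch, assuming a structure_exporter constructed as above (the output path is hypothetical):

# Hypothetical export of the alive counts to JSON (sketch only).
with tf.Session() as sess:
  tf.global_variables_initializer().run()
  structure_exporter.populate_tensor_values(
      sess.run(structure_exporter.tensors))
  with open('/tmp/alive_counts.json', 'w') as f:
    structure_exporter.save_alive_counts(f)

Example #6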
  def testFlopRegularizerExpectErrorWithConvertToVariable(self):
    """Tests the functionality of the convert_to_variable parameter.

    This test exercises the failure condition: the use of a resource variable
    that was not created with tf.get_variable.
    """
    tf.reset_default_graph()
    tf.set_random_seed(1234)

    x = tf.constant(1.0, shape=[2, 6], name='x', dtype=tf.float32)
    w = tf.Variable(tf.truncated_normal([6, 4], stddev=1.0), use_resource=True)
    net = tf.matmul(x, w)

    # Create FLOPs network regularizer.
    threshold = 0.9
    with self.assertRaises(ValueError):
      _ = flop_regularizer.GroupLassoFlopsRegularizer(
          [net.op], threshold, 0, convert_to_variable=True)
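For contrast, a minimal sketch of the case the docstring implies should not raise: the same graph, but with the weight created via tf.get_variable so that convert_to_variable=True has a variable it can resolve (this counterpart is an assumption based on the docstring above, not a test from the MorphNet repository):

# Hypothetical passing counterpart (sketch only).
tf.reset_default_graph()
x = tf.constant(1.0, shape=[2, 6], name='x', dtype=tf.float32)
w = tf.get_variable('w', shape=(6, 4), dtype=tf.float32, use_resource=True)
net = tf.matmul(x, w)

flop_reg = flop_regularizer.GroupLassoFlopsRegularizer(
    [net.op], 0.9, 0, convert_to_variable=True)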
Example #7
    def test_group_lasso_conv3d(self):
        shape = [3, 3, 3]
        video = tf.zeros([2, 3, 3, 3, 1])
        net = slim.conv3d(video,
                          5,
                          shape,
                          padding='VALID',
                          weights_initializer=tf.glorot_normal_initializer(),
                          scope='vconv1')
        conv3d_op = tf.get_default_graph().get_operation_by_name(
            'vconv1/Conv3D')
        conv3d_weights = conv3d_op.inputs[1]

        threshold = 0.09
        flop_reg = flop_regularizer.GroupLassoFlopsRegularizer(
            [net.op], threshold=threshold)
        norm = tf.sqrt(tf.reduce_mean(tf.square(conv3d_weights), [0, 1, 2, 3]))
        alive = tf.reduce_sum(tf.cast(norm > threshold, tf.float32))
        with self.session():
            flop_coeff = 2 * shape[0] * shape[1] * shape[2]
            tf.compat.v1.global_variables_initializer().run()
            self.assertAllClose(flop_reg.get_cost(), flop_coeff * alive)
            self.assertAllClose(flop_reg.get_regularization_term(),
                                flop_coeff * tf.reduce_sum(norm))
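The flop_coeff above is the per-output-position cost of the 3-D convolution: 2 FLOPs (a multiply and an add) per kernel element per alive output channel, with a single input channel. Because the 3x3x3 kernel is applied to a 3x3x3 input with VALID padding, the output has exactly one spatial position, so the total cost is flop_coeff times the number of alive channels. A quick sketch of the arithmetic:

# Per-output-position FLOPs of the VALID 3-D convolution above (sketch).
kernel_volume = 3 * 3 * 3   # kernel depth * height * width
in_channels = 1
flop_coeff = 2 * kernel_volume * in_channels   # == 54, as used in the test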
Example #8
    def testFlopRegularizer(self):
        tf.reset_default_graph()
        tf.set_random_seed(7907)
        with arg_scope([layers.conv2d, layers.conv2d_transpose],
                       weights_initializer=tf.random_normal_initializer):
            # Our test model is:
            #
            #         -> conv1 --+
            #        /           |--[concat]
            #  image --> conv2 --+
            #        \
            #         -> convt
            #
            # (the model has two "outputs", convt and concat).
            #
            image = tf.constant(0.0, shape=[1, 17, 19, NUM_CHANNELS])
            conv1 = layers.conv2d(image,
                                  13, [7, 5],
                                  padding='SAME',
                                  scope='conv1')
            conv2 = layers.conv2d(image,
                                  23, [1, 1],
                                  padding='SAME',
                                  scope='conv2')
            self.concat = tf.concat([conv1, conv2], 3)
            self.convt = layers.conv2d_transpose(image,
                                                 29, [7, 5],
                                                 stride=3,
                                                 padding='SAME',
                                                 scope='convt')
            self.name_to_var = {v.op.name: v for v in tf.global_variables()}
        with self.test_session():
            tf.global_variables_initializer().run()

        threshold = 1.0
        flop_reg = flop_regularizer.GroupLassoFlopsRegularizer(
            [self.concat.op, self.convt.op], threshold=threshold)

        with self.test_session() as s:
            evaluated_vars = s.run(self.name_to_var)

        def group_norm(weights, axis=(0, 1, 2)):  # pylint: disable=invalid-name
            return np.sqrt(np.mean(weights**2, axis=axis))

        reg_vectors = {
            'conv1': group_norm(evaluated_vars['conv1/weights'], (0, 1, 2)),
            'conv2': group_norm(evaluated_vars['conv2/weights'], (0, 1, 2)),
            'convt': group_norm(evaluated_vars['convt/weights'], (0, 1, 3))
        }

        num_alive = {
            k: np.sum(r > threshold)
            for k, r in reg_vectors.items()
        }
        total_outputs = (reg_vectors['conv1'].shape[0] +
                         reg_vectors['conv2'].shape[0])
        total_alive_outputs = sum(num_alive.values())
        assert total_alive_outputs > 0, (
            'All outputs are dead - test is trivial. Decrease the threshold.')
        assert total_alive_outputs < total_outputs, (
            'All outputs are alive - test is trivial. Increase the threshold.')

        coeff1 = _coeff(_get_op('conv1/Conv2D'))
        coeff2 = _coeff(_get_op('conv2/Conv2D'))
        coefft = _coeff(_get_op('convt/conv2d_transpose'))

        expected_flop_cost = NUM_CHANNELS * (coeff1 * num_alive['conv1'] +
                                             coeff2 * num_alive['conv2'] +
                                             coefft * num_alive['convt'])
        expected_reg_term = NUM_CHANNELS * (
            coeff1 * np.sum(reg_vectors['conv1']) +
            coeff2 * np.sum(reg_vectors['conv2']) +
            coefft * np.sum(reg_vectors['convt']))
        with self.test_session():
            self.assertEqual(round(expected_flop_cost),
                             round(flop_reg.get_cost().eval()))
            self.assertNearRelatively(
                expected_reg_term,
                flop_reg.get_regularization_term().eval())
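The one subtlety in the reg_vectors above is the axis choice for the transposed convolution: a conv2d_transpose kernel is laid out as [height, width, output_channels, input_channels], so the group norm for 'convt' reduces over axes (0, 1, 3) and leaves its 29 output channels on axis 2, whereas the ordinary convolutions keep their output channels on the last axis and reduce over (0, 1, 2). A standalone shape check (a sketch; NUM_CHANNELS = 3 is assumed from the test constants):

# Group-norm axes for a transposed-convolution kernel (sketch only).
import numpy as np

NUM_CHANNELS = 3  # assumed value of the test module constant
convt_weights = np.zeros((7, 5, 29, NUM_CHANNELS))  # [h, w, out_channels, in_channels]
reg_vector = np.sqrt(np.mean(convt_weights**2, axis=(0, 1, 3)))
print(reg_vector.shape)  # (29,) -- one group per output channel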
Example #9
    def testFlopRegularizerWithContribFC(self):
        """Test MatMul Flop regularizer with tf.contrib.fully_connected layer.

    The structure of the fully connected network used in this test is the same
    with that used in testFlopRegularizerWithMatMul.
    """
        tf.reset_default_graph()
        tf.set_random_seed(1234)
        # Create test networks with tf.contrib.layers.fully_connected and initialize
        # the variables.
        with slim.arg_scope([contrib_layers.fully_connected],
                            weights_initializer=tf.random_normal_initializer,
                            biases_initializer=tf.random_normal_initializer):
            x = tf.constant(1.0, shape=[2, 6], name='x', dtype=tf.float32)
            net = contrib_layers.fully_connected(x, 4, scope='matmul1')
            net = contrib_layers.fully_connected(net, 1, scope='matmul2')
            name_to_variable = {v.op.name: v for v in tf.global_variables()}
        with self.cached_session():
            tf.global_variables_initializer().run()

        # Create FLOPs network regularizer.
        threshold = 0.9
        flop_reg = flop_regularizer.GroupLassoFlopsRegularizer([net.op],
                                                               threshold, 0)
        with self.cached_session() as session:
            evaluated_vars = session.run(name_to_variable)

        # Compute the regularizer vector for each layer.
        def group_norm(weights, axis=(0, 1, 2)):  # pylint: disable=invalid-name
            return np.sqrt(np.mean(weights**2, axis=axis))

        regularizer_vec = {
            'matmul1': group_norm(evaluated_vars['matmul1/weights'],
                                  axis=(0, )),
            'matmul2': group_norm(evaluated_vars['matmul2/weights'],
                                  axis=(0, ))
        }

        # Sanity check: not all outputs should be alive, and not all should be dead.
        total_outputs = (regularizer_vec['matmul1'].shape[0] +
                         regularizer_vec['matmul2'].shape[0])
        total_alive = sum(
            [np.sum(val > threshold) for val in regularizer_vec.values()])
        assert total_alive > 0, (
            'All outputs are dead. Decrease the threshold.')
        assert total_alive < total_outputs, (
            'All outputs are alive. Increase the threshold.')

        # Compute the expected flop cost and regularization term. The L2 norms
        # of the columns in the weight matrix of layer matmul1 are [2.15381098,
        # 2.57671237, 2.12560201, 2.2081387] and that of layer matmul2 is
        # [1.72404861]. With a threshold of 2.2, two outputs of the matmul1
        # layer are alive.
        matmul1_live_input = 6
        matmul1_live_output = sum(regularizer_vec['matmul1'] > threshold)
        expected_flop_cost = (_coeff(_get_op('matmul1/MatMul')) *
                              matmul1_live_input * matmul1_live_output)
        regularizer1 = np.sum(regularizer_vec['matmul1'])
        regularizer2 = np.sum(regularizer_vec['matmul2'])
        expected_reg_term = (_coeff(_get_op('matmul1/MatMul')) *
                             matmul1_live_input * regularizer1 +
                             _coeff(_get_op('matmul2/MatMul')) *
                             matmul1_live_output * regularizer2)
        with self.cached_session() as session:
            self.assertEqual(round(flop_reg.get_cost().eval()),
                             round(expected_flop_cost))
            self.assertNearRelatively(
                flop_reg.get_regularization_term().eval(), expected_reg_term)
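One detail that may look inconsistent at first: the code sets threshold = 0.9, while the comment reasons with a threshold of 2.2. The two agree if, as the group_norm helper suggests, the regularization vector is a root-mean-square over the 6 input rows: comparing that vector to 0.9 is the same as comparing the plain column L2 norms quoted in the comment to 0.9 * sqrt(6) ~= 2.2. A quick check under that assumption:

# Reconciling the 0.9 threshold with the comment's 2.2 (sketch only).
import numpy as np

quoted_l2 = np.array([2.15381098, 2.57671237, 2.12560201, 2.2081387])
rescaled_threshold = 0.9 * np.sqrt(6)     # ~= 2.2045
print(quoted_l2 > rescaled_threshold)     # [False  True False  True] -> two alive
print(quoted_l2 / np.sqrt(6) > 0.9)       # the same comparison in RMS terms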
Example #10
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""
    model = create_model(params['data_format'])
    image = features
    if isinstance(image, dict):
        image = features['image']

    if mode == tf.estimator.ModeKeys.PREDICT:
        logits = model(image, training=False)
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

        logits = model(image, training=True)

        network_regularizer = flop_regularizer.GroupLassoFlopsRegularizer(
            output_boundary=[logits.op],
            input_boundary=[image.op, labels.op],
            threshold=1e-2)
        regularization_strength = 1e-5
        regularizer_loss = (network_regularizer.get_regularization_term() *
                            regularization_strength)

        cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                               logits=logits)
        loss = cross_entropy + regularizer_loss
        # loss = cross_entropy

        accuracy = tf.metrics.accuracy(labels=labels,
                                       predictions=tf.argmax(logits, axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(LEARNING_RATE, 'learning_rate')
        tf.identity(loss, 'cross_entropy')
        tf.identity(accuracy[1], name='train_accuracy')

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar('train_accuracy', accuracy[1])
        tf.summary.scalar('RegularizationLoss', regularizer_loss)
        tf.summary.scalar(network_regularizer.cost_name,
                          network_regularizer.get_cost())

        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=loss,
            train_op=optimizer.minimize(loss,
                                        tf.train.get_or_create_global_step()))
    if mode == tf.estimator.ModeKeys.EVAL:
        logits = model(image, training=False)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                      logits=logits)
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=loss,
            eval_metric_ops={
                'accuracy':
                tf.metrics.accuracy(labels=labels,
                                    predictions=tf.argmax(logits, axis=1)),
            })
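Once an Estimator trained with this model_fn has produced a checkpoint, the learned structure can be read out with the same StructureExporter shown in Example #5. A minimal sketch, assuming the graph and network_regularizer are rebuilt as in the TRAIN branch, that structure_exporter lives under morph_net.tools, and that the checkpoint and output paths are hypothetical:

# Hypothetical structure export after training (sketch only).
from morph_net.tools import structure_exporter as se

exporter = se.StructureExporter(network_regularizer.op_regularizer_manager)
with tf.Session() as sess:
  tf.train.Saver().restore(sess, '/tmp/mnist_model/model.ckpt')  # hypothetical path
  exporter.populate_tensor_values(sess.run(exporter.tensors))
  with open('/tmp/alive_counts.json', 'w') as f:
    exporter.save_alive_counts(f)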
Example #11
def main(_):
    if FLAGS.self_test:
        print('Running self-test.')
        train_data, train_labels = fake_data(256)
        validation_data, validation_labels = fake_data(EVAL_BATCH_SIZE)
        test_data, test_labels = fake_data(EVAL_BATCH_SIZE)
        num_epochs = 1
    else:
        # Get the data.
        train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
        train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
        test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
        test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')

        # Extract it into numpy arrays.
        train_data = extract_data(train_data_filename, 60000)
        train_labels = extract_labels(train_labels_filename, 60000)
        test_data = extract_data(test_data_filename, 10000)
        test_labels = extract_labels(test_labels_filename, 10000)

        # Generate a validation set.
        validation_data = train_data[:VALIDATION_SIZE, ...]
        validation_labels = train_labels[:VALIDATION_SIZE]
        train_data = train_data[VALIDATION_SIZE:, ...]
        train_labels = train_labels[VALIDATION_SIZE:]
        num_epochs = NUM_EPOCHS
    train_size = train_labels.shape[0]

    # This is where training samples and labels are fed to the graph.
    # These placeholder nodes will be fed a batch of training data at each
    # training step using the {feed_dict} argument to the Run() call below.
    train_data_node = tf.placeholder(data_type(),
                                     shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE,
                                            NUM_CHANNELS))
    train_labels_node = tf.placeholder(tf.int64, shape=(BATCH_SIZE, ))
    eval_data = tf.placeholder(data_type(),
                               shape=(EVAL_BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE,
                                      NUM_CHANNELS))

    # The variables below hold all the trainable weights. They are passed an
    # initial value which will be assigned when we call:
    # {tf.global_variables_initializer().run()}
    conv1_weights = tf.Variable(
        tf.truncated_normal(
            [5, 5, NUM_CHANNELS, 32],  # 5x5 filter, depth 32.
            stddev=0.1,
            seed=SEED,
            dtype=data_type()))
    conv1_biases = tf.Variable(tf.zeros([32], dtype=data_type()))
    conv2_weights = tf.Variable(
        tf.truncated_normal([5, 5, 32, 64],
                            stddev=0.1,
                            seed=SEED,
                            dtype=data_type()))
    conv2_biases = tf.Variable(tf.constant(0.1, shape=[64], dtype=data_type()))
    fc1_weights = tf.Variable(  # fully connected, depth 512.
        tf.truncated_normal([IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64, 512],
                            stddev=0.1,
                            seed=SEED,
                            dtype=data_type()))
    fc1_biases = tf.Variable(tf.constant(0.1, shape=[512], dtype=data_type()))
    fc2_weights = tf.Variable(
        tf.truncated_normal([512, NUM_LABELS],
                            stddev=0.1,
                            seed=SEED,
                            dtype=data_type()))
    fc2_biases = tf.Variable(
        tf.constant(0.1, shape=[NUM_LABELS], dtype=data_type()))

    # We will replicate the model structure for the training subgraph, as well
    # as the evaluation subgraphs, while sharing the trainable parameters.
    def model(data, train=False):
        """The Model definition."""
        # 2D convolution, with 'SAME' padding (i.e. the output feature map has
        # the same size as the input). Note that {strides} is a 4D array whose
        # shape matches the data layout: [image index, y, x, depth].
        conv = tf.nn.conv2d(data,
                            conv1_weights,
                            strides=[1, 1, 1, 1],
                            padding='SAME')
        # Bias and rectified linear non-linearity.
        relu = tf.nn.relu(tf.nn.bias_add(conv, conv1_biases))
        # Max pooling. The kernel size spec {ksize} also follows the layout of
        # the data. Here we have a pooling window of 2, and a stride of 2.
        pool = tf.nn.max_pool(relu,
                              ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1],
                              padding='SAME')
        conv = tf.nn.conv2d(pool,
                            conv2_weights,
                            strides=[1, 1, 1, 1],
                            padding='SAME')
        relu = tf.nn.relu(tf.nn.bias_add(conv, conv2_biases))
        pool = tf.nn.max_pool(relu,
                              ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1],
                              padding='SAME')
        # Reshape the feature map cuboid into a 2D matrix to feed it to the
        # fully connected layers.
        pool_shape = pool.get_shape().as_list()
        reshape = tf.reshape(
            pool,
            [pool_shape[0], pool_shape[1] * pool_shape[2] * pool_shape[3]])
        # Fully connected layer. Note that the '+' operation automatically
        # broadcasts the biases.
        hidden = tf.nn.relu(tf.matmul(reshape, fc1_weights) + fc1_biases)
        # Add a 50% dropout during training only. Dropout also scales
        # activations such that no rescaling is needed at evaluation time.
        if train:
            hidden = tf.nn.dropout(hidden, 0.5, seed=SEED)
        return tf.matmul(hidden, fc2_weights) + fc2_biases

    # Training computation: logits + cross-entropy loss.
    logits = model(train_data_node, True)

    # Use morphnet flop regularizer
    network_regularizer = flop_regularizer.GroupLassoFlopsRegularizer(
        output_boundary=[logits.op],
        input_boundary=[train_data_node.op, train_labels_node.op],
        threshold=1e-2)
    regularization_strength = 1e-5
    regularizer_loss = (network_regularizer.get_regularization_term() *
                        regularization_strength)

    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=train_labels_node, logits=logits))

    # L2 regularization for the fully connected parameters.
    regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) +
                    tf.nn.l2_loss(fc2_weights) + tf.nn.l2_loss(fc2_biases))
    # Add the regularization term to the loss.
    loss += 5e-4 * regularizers + regularizer_loss
    flop_cost = tf.identity(network_regularizer.get_cost())
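    # (Optional sketch, mirroring Example #10: the FLOP cost and regularization
    # loss could also be exported as TensorBoard summaries; `cost_name` is the
    # attribute used there. Commented out because this script never creates a
    # summary writer.)
    # tf.summary.scalar(network_regularizer.cost_name, flop_cost)
    # tf.summary.scalar('RegularizationLoss', regularizer_loss)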

    # Optimizer: set up a variable that's incremented once per batch and
    # controls the learning rate decay.
    batch = tf.Variable(0, dtype=data_type())
    # Decay once per epoch, using an exponential schedule starting at 0.01.
    learning_rate = tf.train.exponential_decay(
        0.01,  # Base learning rate.
        batch * BATCH_SIZE,  # Current index into the dataset.
        train_size,  # Decay step.
        0.95,  # Decay rate.
        staircase=True)
    # Use simple momentum for the optimization.
    optimizer = tf.train.MomentumOptimizer(learning_rate,
                                           0.9).minimize(loss,
                                                         global_step=batch)

    # Predictions for the current training minibatch.
    train_prediction = tf.nn.softmax(logits)

    # Predictions for the test and validation, which we'll compute less often.
    eval_prediction = tf.nn.softmax(model(eval_data))

    # Small utility function to evaluate a dataset by feeding batches of data to
    # {eval_data} and pulling the results from {eval_prediction}.
    # Saves memory and enables this to run on smaller GPUs.
    def eval_in_batches(data, sess):
        """Get all predictions for a dataset by running it in small batches."""
        size = data.shape[0]
        if size < EVAL_BATCH_SIZE:
            raise ValueError("batch size for evals larger than dataset: %d" %
                             size)
        predictions = numpy.ndarray(shape=(size, NUM_LABELS),
                                    dtype=numpy.float32)
        for begin in xrange(0, size, EVAL_BATCH_SIZE):
            end = begin + EVAL_BATCH_SIZE
            if end <= size:
                predictions[begin:end, :] = sess.run(
                    eval_prediction,
                    feed_dict={eval_data: data[begin:end, ...]})
            else:
                batch_predictions = sess.run(
                    eval_prediction,
                    feed_dict={eval_data: data[-EVAL_BATCH_SIZE:, ...]})
                predictions[begin:, :] = batch_predictions[begin - size:, :]
        return predictions

    # Create a local session to run the training.
    start_time = time.time()

    # global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.Session() as sess:
        # Run all the initializers to prepare the trainable parameters.
        tf.global_variables_initializer().run()
        print('Initialized!')
        # Loop through training steps.
        for step in xrange(int(num_epochs * train_size) // BATCH_SIZE):
            # Compute the offset of the current minibatch in the data.
            # Note that we could use better randomization across epochs.
            offset = (step * BATCH_SIZE) % (train_size - BATCH_SIZE)
            batch_data = train_data[offset:(offset + BATCH_SIZE), ...]
            batch_labels = train_labels[offset:(offset + BATCH_SIZE)]
            # This dictionary maps the batch data (as a numpy array) to the
            # node in the graph it should be fed to.
            feed_dict = {
                train_data_node: batch_data,
                train_labels_node: batch_labels
            }
            # Run the optimizer to update weights.
            sess.run(optimizer, feed_dict=feed_dict)
            # Print some extra information once we reach the evaluation frequency.
            if step % EVAL_FREQUENCY == 0:
                # fetch some extra nodes' data
                l, lr, predictions, flops = sess.run(
                    [loss, learning_rate, train_prediction, flop_cost],
                    feed_dict=feed_dict)
                elapsed_time = time.time() - start_time
                start_time = time.time()
                print('Step %d (epoch %.2f), %.1f ms' %
                      (step, float(step) * BATCH_SIZE / train_size,
                       1000 * elapsed_time / EVAL_FREQUENCY))
                print('Minibatch loss: %.3f, learning rate: %.6f' % (l, lr))
                print('Minibatch error: %.1f%%' %
                      error_rate(predictions, batch_labels))
                print('Validation error: %.1f%%' % error_rate(
                    eval_in_batches(validation_data, sess), validation_labels))
                print('Minibatch FLOPs: %.10f' % (flops))
                sys.stdout.flush()
        # Finally print the result!
        test_error = error_rate(eval_in_batches(test_data, sess), test_labels)
        print('Test error: %.1f%%' % test_error)
        if FLAGS.self_test:
            print('test_error', test_error)
            assert test_error == 0.0, 'expected 0.0 test_error, got %.2f' % (
                test_error, )