Example #1
 def MyNasNet(self, images, training):  # define the base model
     arg_scope = nasnet.nasnet_mobile_arg_scope()  # get the model arg scope
     with slim.arg_scope(arg_scope):  # build the model
         logits, end_points = nasnet.build_nasnet_mobile(
                 images, num_classes=self.num_classes + 1, is_training=training)
     global_step = tf.train.get_or_create_global_step()  # tensor that tracks the global step
     return logits, end_points, global_step
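This variant adds one extra background class on top of num_classes, so the retrained head cannot be restored from the published ImageNet checkpoint. A minimal restore sketch, assuming the nasnet-a_mobile_04_10_2017 checkpoint and the NASNet-Mobile head scopes 'final_layer' and 'aux_7' (the same scopes appear in the fine-tuning example near the end of this page):

import tensorflow as tf
import tensorflow.contrib.slim as slim

# Restore only the ImageNet backbone; the retrained head scopes are excluded.
variables_to_restore = slim.get_variables_to_restore(
    exclude=['final_layer', 'aux_7'])
init_fn = slim.assign_from_checkpoint_fn(
    'nasnet-a_mobile_04_10_2017/model.ckpt',  # assumed checkpoint location
    variables_to_restore,
    ignore_missing_vars=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    init_fn(sess)  # load the pretrained weights before fine-tuning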
Example #2
 def testAllEndPointsShapesMobileModel(self):
   batch_size = 5
   height, width = 224, 224
   num_classes = 1000
   inputs = tf.random_uniform((batch_size, height, width, 3))
   tf.train.create_global_step()
   with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
     _, end_points = nasnet.build_nasnet_mobile(inputs, num_classes)
   endpoints_shapes = {'Stem': [batch_size, 28, 28, 88],
                       'Cell_0': [batch_size, 28, 28, 264],
                       'Cell_1': [batch_size, 28, 28, 264],
                       'Cell_2': [batch_size, 28, 28, 264],
                       'Cell_3': [batch_size, 28, 28, 264],
                       'Cell_4': [batch_size, 14, 14, 528],
                       'Cell_5': [batch_size, 14, 14, 528],
                       'Cell_6': [batch_size, 14, 14, 528],
                       'Cell_7': [batch_size, 14, 14, 528],
                       'Cell_8': [batch_size, 7, 7, 1056],
                       'Cell_9': [batch_size, 7, 7, 1056],
                       'Cell_10': [batch_size, 7, 7, 1056],
                       'Cell_11': [batch_size, 7, 7, 1056],
                       'Reduction_Cell_0': [batch_size, 14, 14, 352],
                       'Reduction_Cell_1': [batch_size, 7, 7, 704],
                       'global_pool': [batch_size, 1056],
                       # Logits and predictions
                       'AuxLogits': [batch_size, num_classes],
                       'Logits': [batch_size, num_classes],
                       'Predictions': [batch_size, num_classes]}
   self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
   for endpoint_name in endpoints_shapes:
     tf.logging.info('Endpoint name: {}'.format(endpoint_name))
     expected_shape = endpoints_shapes[endpoint_name]
     self.assertTrue(endpoint_name in end_points)
     self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
                           expected_shape)
Example #3
    def construct(self, args):
        with self.session.graph.as_default():
            # Inputs
            self.images = tf.placeholder(tf.uint8,
                                         [None, self.HEIGHT, self.WIDTH, 3],
                                         name="images")

            # Computation
            images = 2 * (
                tf.image.convert_image_dtype(self.images, tf.float32) - 0.5)

            with tf.contrib.slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
                self.output_layer, _ = nasnet.build_nasnet_mobile(
                    images, num_classes=self.CLASSES + 1, is_training=False)
            self.nasnet_saver = tf.train.Saver()

            self.predictions = tf.argmax(self.output_layer, axis=1) - 1

            # Image loading
            self.image_file = tf.placeholder(tf.string, [])
            self.image_data = tf.image.decode_image(tf.read_file(
                self.image_file),
                                                    channels=3)
            self.image_data_resized = tf.image.resize_image_with_crop_or_pad(
                self.image_data, self.HEIGHT, self.WIDTH)

            # Initialize variables
            self.session.run(tf.global_variables_initializer())
            self.nasnet_saver.restore(self.session, args.checkpoint)
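Given the graph constructed above, classifying a single image file takes two session.run calls: one to decode and crop/pad the file, one to run the network. A minimal sketch; the method name predict_file is illustrative and not part of the original class:

    def predict_file(self, path):
        # Decode the file and crop/pad it to HEIGHT x WIDTH.
        image = self.session.run(self.image_data_resized,
                                 {self.image_file: path})
        # self.predictions is argmax(logits) - 1, so -1 denotes the extra
        # background class and 0..CLASSES-1 are the real classes.
        return self.session.run(self.predictions, {self.images: [image]})[0]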
Example #4
 def testAllEndPointsShapesMobileModel(self):
   batch_size = 5
   height, width = 224, 224
   num_classes = 1000
   inputs = tf.random_uniform((batch_size, height, width, 3))
   tf.train.create_global_step()
   with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
     _, end_points = nasnet.build_nasnet_mobile(inputs, num_classes)
   endpoints_shapes = {'Stem': [batch_size, 28, 28, 88],
                       'Cell_0': [batch_size, 28, 28, 264],
                       'Cell_1': [batch_size, 28, 28, 264],
                       'Cell_2': [batch_size, 28, 28, 264],
                       'Cell_3': [batch_size, 28, 28, 264],
                       'Cell_4': [batch_size, 14, 14, 528],
                       'Cell_5': [batch_size, 14, 14, 528],
                       'Cell_6': [batch_size, 14, 14, 528],
                       'Cell_7': [batch_size, 14, 14, 528],
                       'Cell_8': [batch_size, 7, 7, 1056],
                       'Cell_9': [batch_size, 7, 7, 1056],
                       'Cell_10': [batch_size, 7, 7, 1056],
                       'Cell_11': [batch_size, 7, 7, 1056],
                       'Reduction_Cell_0': [batch_size, 14, 14, 352],
                       'Reduction_Cell_1': [batch_size, 7, 7, 704],
                       'global_pool': [batch_size, 1056],
                       # Logits and predictions
                       'AuxLogits': [batch_size, num_classes],
                       'Logits': [batch_size, num_classes],
                       'Predictions': [batch_size, num_classes]}
   self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
   for endpoint_name in endpoints_shapes:
     tf.logging.info('Endpoint name: {}'.format(endpoint_name))
     expected_shape = endpoints_shapes[endpoint_name]
     self.assertTrue(endpoint_name in end_points)
     self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
                          expected_shape)
Example #5
 def testVariablesSetDeviceMobileModel(self):
   batch_size = 5
   height, width = 224, 224
   num_classes = 1000
   inputs = tf.random_uniform((batch_size, height, width, 3))
   tf.train.create_global_step()
   # Force all Variables to reside on the device.
   with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
     with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
       nasnet.build_nasnet_mobile(inputs, num_classes)
   with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
     with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
       nasnet.build_nasnet_mobile(inputs, num_classes)
   for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
     self.assertDeviceEqual(v.device, '/cpu:0')
   for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
     self.assertDeviceEqual(v.device, '/gpu:0')
Example #6
 def testVariablesSetDeviceMobileModel(self):
   batch_size = 5
   height, width = 224, 224
   num_classes = 1000
   inputs = tf.random_uniform((batch_size, height, width, 3))
   tf.train.create_global_step()
   # Force all Variables to reside on the device.
   with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
     with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
       nasnet.build_nasnet_mobile(inputs, num_classes)
   with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
     with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
       nasnet.build_nasnet_mobile(inputs, num_classes)
   for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
     self.assertDeviceEqual(v.device, '/cpu:0')
   for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
      self.assertDeviceEqual(v.device, '/gpu:0')
Example #7
 def testNoAuxHeadMobileModel(self):
   batch_size = 5
   height, width = 224, 224
   num_classes = 1000
   for use_aux_head in (True, False):
     tf.reset_default_graph()
     inputs = tf.random_uniform((batch_size, height, width, 3))
     tf.train.create_global_step()
     with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
       _, end_points = nasnet.build_nasnet_mobile(inputs, num_classes,
                                                  use_aux_head=use_aux_head)
     self.assertEqual('AuxLogits' in end_points, use_aux_head)
Example #8
 def testBuildPreLogitsMobileModel(self):
   batch_size = 5
   height, width = 224, 224
   num_classes = None
   inputs = tf.random_uniform((batch_size, height, width, 3))
   tf.train.create_global_step()
   with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
     net, end_points = nasnet.build_nasnet_mobile(inputs, num_classes)
   self.assertFalse('AuxLogits' in end_points)
   self.assertFalse('Predictions' in end_points)
   self.assertTrue(net.op.name.startswith('final_layer/Mean'))
   self.assertListEqual(net.get_shape().as_list(), [batch_size, 1056])
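With num_classes=None the network stops at the 1056-dimensional global-pool features, which makes it easy to attach a custom head. A minimal sketch; the import path and MY_CLASSES are assumptions:

import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets.nasnet import nasnet  # path used in the tensorflow/models slim release

MY_CLASSES = 10  # assumed number of target classes

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
tf.train.create_global_step()
with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
    features, _ = nasnet.build_nasnet_mobile(
        images, num_classes=None, is_training=False)  # [batch, 1056] pre-logits
logits = tf.layers.dense(features, MY_CLASSES)  # custom classifier head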
Example #9
    def MyNASNet(self, images, is_training):
        arg_scope = nasnet.nasnet_mobile_arg_scope()  # get the model arg scope
        with slim.arg_scope(arg_scope):
            # build the NASNet-Mobile model
            logits, end_points = nasnet.build_nasnet_mobile(
                images,
                num_classes=self.num_classes + 1,
                is_training=is_training)

        global_step = tf.train.get_or_create_global_step()  # tensor that records the step count

        return logits, end_points, global_step  # return the useful tensors
Example #10
 def testBuildPreLogitsMobileModel(self):
   batch_size = 5
   height, width = 224, 224
   num_classes = None
   inputs = tf.random_uniform((batch_size, height, width, 3))
   tf.train.create_global_step()
   with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
     net, end_points = nasnet.build_nasnet_mobile(inputs, num_classes)
   self.assertFalse('AuxLogits' in end_points)
   self.assertFalse('Predictions' in end_points)
   self.assertTrue(net.op.name.startswith('final_layer/Mean'))
    self.assertListEqual(net.get_shape().as_list(), [batch_size, 1056])
Example #11
 def testOverrideHParamsMobileModel(self):
   batch_size = 5
   height, width = 224, 224
   num_classes = 1000
   inputs = tf.random_uniform((batch_size, height, width, 3))
   tf.train.create_global_step()
   config = nasnet.mobile_imagenet_config()
   config.set_hparam('data_format', 'NCHW')
   with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
     _, end_points = nasnet.build_nasnet_mobile(
         inputs, num_classes, config=config)
   self.assertListEqual(
       end_points['Stem'].shape.as_list(), [batch_size, 88, 28, 28])
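nasnet.mobile_imagenet_config() returns a tf.contrib.training.HParams object, so the overridable keys can be listed before calling set_hparam. A short sketch:

config = nasnet.mobile_imagenet_config()
print(sorted(config.values()))  # hparam names, e.g. 'data_format', 'use_aux_head'
config.set_hparam('data_format', 'NCHW')  # channels-first, as in the test above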
Example #12
 def testNoAuxHeadMobileModel(self):
   batch_size = 5
   height, width = 224, 224
   num_classes = 1000
   for use_aux_head in (True, False):
     tf.compat.v1.reset_default_graph()
     inputs = tf.random.uniform((batch_size, height, width, 3))
     tf.compat.v1.train.create_global_step()
     config = nasnet.mobile_imagenet_config()
     config.set_hparam('use_aux_head', int(use_aux_head))
     with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
       _, end_points = nasnet.build_nasnet_mobile(inputs, num_classes,
                                                  config=config)
     self.assertEqual('AuxLogits' in end_points, use_aux_head)
Example #13
 def testNoAuxHeadMobileModel(self):
   batch_size = 5
   height, width = 224, 224
   num_classes = 1000
   for use_aux_head in (True, False):
     tf.reset_default_graph()
     inputs = tf.random_uniform((batch_size, height, width, 3))
     tf.train.create_global_step()
     config = nasnet.mobile_imagenet_config()
     config.set_hparam('use_aux_head', int(use_aux_head))
     with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
       _, end_points = nasnet.build_nasnet_mobile(inputs, num_classes,
                                                  config=config)
     self.assertEqual('AuxLogits' in end_points, use_aux_head)
Example #14
 def testUnknownBatchSizeMobileModel(self):
   batch_size = 1
   height, width = 224, 224
   num_classes = 1000
   with self.test_session() as sess:
     inputs = tf.placeholder(tf.float32, (None, height, width, 3))
     with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
       logits, _ = nasnet.build_nasnet_mobile(inputs, num_classes)
     self.assertListEqual(logits.get_shape().as_list(),
                          [None, num_classes])
     images = tf.random_uniform((batch_size, height, width, 3))
     sess.run(tf.global_variables_initializer())
     output = sess.run(logits, {inputs: images.eval()})
      self.assertEqual(output.shape, (batch_size, num_classes))
Example #15
 def testEvaluationMobileModel(self):
   batch_size = 2
   height, width = 224, 224
   num_classes = 1000
   with self.test_session() as sess:
     eval_inputs = tf.random_uniform((batch_size, height, width, 3))
     with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
       logits, _ = nasnet.build_nasnet_mobile(eval_inputs,
                                              num_classes,
                                              is_training=False)
     predictions = tf.argmax(logits, 1)
     sess.run(tf.global_variables_initializer())
     output = sess.run(predictions)
      self.assertEqual(output.shape, (batch_size,))
Example #16
 def testEvaluationMobileModel(self):
   batch_size = 2
   height, width = 224, 224
   num_classes = 1000
   with self.test_session() as sess:
     eval_inputs = tf.random_uniform((batch_size, height, width, 3))
     with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
       logits, _ = nasnet.build_nasnet_mobile(eval_inputs,
                                              num_classes,
                                              is_training=False)
     predictions = tf.argmax(logits, 1)
     sess.run(tf.global_variables_initializer())
     output = sess.run(predictions)
      self.assertEqual(output.shape, (batch_size,))
Example #17
 def testUnknownBatchSizeMobileModel(self):
   batch_size = 1
   height, width = 224, 224
   num_classes = 1000
   with self.test_session() as sess:
     inputs = tf.placeholder(tf.float32, (None, height, width, 3))
     with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
       logits, _ = nasnet.build_nasnet_mobile(inputs, num_classes)
     self.assertListEqual(logits.get_shape().as_list(),
                          [None, num_classes])
     images = tf.random_uniform((batch_size, height, width, 3))
     sess.run(tf.global_variables_initializer())
     output = sess.run(logits, {inputs: images.eval()})
      self.assertEqual(output.shape, (batch_size, num_classes))
Example #18
def nasnet_mobile(inputs, is_training, opts):
    with slim.arg_scope(nasnet.nasnet_mobile_arg_scope(
            weight_decay=opts.weight_decay,
            batch_norm_decay=opts.batch_norm_decay,
            batch_norm_epsilon=opts.batch_norm_epsilon)):

        config = nasnet.mobile_imagenet_config()
        config.set_hparam('dense_dropout_keep_prob', opts.dropout_keep_prob)
        config.set_hparam('use_aux_head', int(opts.create_aux_logits))

        return nasnet.build_nasnet_mobile(
            inputs,
            num_classes=opts.num_classes,
            is_training=is_training,
            config=config)
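A minimal sketch of driving this wrapper, assuming an opts object carrying the fields read above; the values shown mirror the nasnet_mobile_arg_scope defaults and are assumptions:

import tensorflow as tf
from types import SimpleNamespace

opts = SimpleNamespace(
    weight_decay=4e-5,
    batch_norm_decay=0.9997,
    batch_norm_epsilon=1e-3,
    dropout_keep_prob=0.5,
    create_aux_logits=True,
    num_classes=1001)

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
tf.train.create_global_step()
logits, end_points = nasnet_mobile(images, is_training=False, opts=opts)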
Example #19
 def testBuildLogitsMobileModel(self):
   batch_size = 5
   height, width = 224, 224
   num_classes = 1000
   inputs = tf.random_uniform((batch_size, height, width, 3))
   tf.train.create_global_step()
   with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
     logits, end_points = nasnet.build_nasnet_mobile(inputs, num_classes)
   auxlogits = end_points['AuxLogits']
   predictions = end_points['Predictions']
   self.assertListEqual(auxlogits.get_shape().as_list(),
                        [batch_size, num_classes])
   self.assertListEqual(logits.get_shape().as_list(),
                        [batch_size, num_classes])
   self.assertListEqual(predictions.get_shape().as_list(),
                        [batch_size, num_classes])
Example #20
 def testBuildLogitsMobileModel(self):
   batch_size = 5
   height, width = 224, 224
   num_classes = 1000
   inputs = tf.random_uniform((batch_size, height, width, 3))
   tf.train.create_global_step()
   with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
     logits, end_points = nasnet.build_nasnet_mobile(inputs, num_classes)
   auxlogits = end_points['AuxLogits']
   predictions = end_points['Predictions']
   self.assertListEqual(auxlogits.get_shape().as_list(),
                        [batch_size, num_classes])
   self.assertListEqual(logits.get_shape().as_list(),
                        [batch_size, num_classes])
   self.assertListEqual(predictions.get_shape().as_list(),
                         [batch_size, num_classes])
Example #21
    def extract_features(self, preprocessed_inputs):
        """Extract features from preprocessed inputs.

    Args:
      preprocessed_inputs: a [batch, height, width, channels] float tensor
        representing a batch of images.

    Returns:
      feature_maps: a list of tensors where the ith tensor has shape
        [batch, height_i, width_i, depth_i]
    """
        preprocessed_inputs.get_shape().assert_has_rank(4)
        shape_assert = tf.Assert(
            tf.logical_and(
                tf.greater_equal(tf.shape(preprocessed_inputs)[1], 33),
                tf.greater_equal(tf.shape(preprocessed_inputs)[2], 33)),
            ['image size must at least be 33 in both height and width.'])

        feature_map_layout = {
            'from_layer': ['Cell_10', 'Cell_11', '', '', '', ''],
            'layer_depth': [-1, -1, 512, 256, 256, 128],
        }

        with tf.control_dependencies([shape_assert]):
            with slim.arg_scope(self._conv_hyperparams):
                # TODO: `scope` was removed from the call to build_nasnet_mobile, so it has no effect.
                with slim.arg_scope([slim.batch_norm], fused=False):
                    with tf.variable_scope('NasNetMobile',
                                           reuse=self._reuse_weights) as scope:
                        preprocessed_and_padded_inputs = ops.pad_to_multiple(
                            preprocessed_inputs, self._pad_to_multiple)
                        _, image_features = nasnet.build_nasnet_mobile(
                            preprocessed_and_padded_inputs,
                            num_classes=None,
                            final_endpoint='Cell_11')

                        feature_maps = feature_map_generators.multi_resolution_feature_maps(
                            feature_map_layout=feature_map_layout,
                            depth_multiplier=self._depth_multiplier,
                            min_depth=self._min_depth,
                            insert_1x1_conv=True,
                            image_features=image_features)

        print("Image features: ", image_features)
        print("Feature maps:", feature_maps)
        raw_input()
        return feature_maps.values()
  def _extract_proposal_features(self, preprocessed_inputs, scope):
    """Extracts first stage RPN features.

    Extracts features using the first half of the NASNet network.
    We construct the network in `align_feature_maps=True` mode, which means
    that all VALID paddings in the network are changed to SAME padding so that
    the feature maps are aligned.

    Args:
      preprocessed_inputs: A [batch, height, width, channels] float32 tensor
        representing a batch of images.
      scope: A scope name.

    Returns:
      rpn_feature_map: A tensor with shape [batch, height, width, depth]
    Raises:
      ValueError: If the created network is missing the required activation.
    """
    del scope

    if len(preprocessed_inputs.get_shape().as_list()) != 4:
      raise ValueError('`preprocessed_inputs` must be 4 dimensional, got a '
                       'tensor of shape %s' % preprocessed_inputs.get_shape())

    with slim.arg_scope(nasnet_mobile_arg_scope_for_detection(
        is_batch_norm_training=self._train_batch_norm)):
      _, end_points = nasnet.build_nasnet_mobile(
          preprocessed_inputs, num_classes=None,
          is_training=self._is_training,
          final_endpoint='Cell_11')

    # Note that both 'Cell_10' and 'Cell_11' have equal depth = 1056 for the
    # mobile model (see the endpoint shapes in the tests above).
    rpn_feature_map = tf.concat([end_points['Cell_10'],
                                 end_points['Cell_11']], 3)

    # nasnet.py does not maintain the batch size in the first dimension.
    # This workaround lets us retain the batch size for the code below.
    batch = preprocessed_inputs.get_shape().as_list()[0]
    shape_without_batch = rpn_feature_map.get_shape().as_list()[1:]
    rpn_feature_map_shape = [batch] + shape_without_batch
    rpn_feature_map.set_shape(rpn_feature_map_shape)

    return rpn_feature_map
Example #23
    def __init__(self, args, seed=42):
        # Create an empty graph and a session
        graph = tf.Graph()
        graph.seed = seed
        self.session = tf.Session(
            graph=graph,
            config=tf.ConfigProto(inter_op_parallelism_threads=args.threads,
                                  intra_op_parallelism_threads=args.threads))

        with self.session.graph.as_default():
            # Inputs
            self.images = tf.placeholder(tf.uint8,
                                         [None, self.HEIGHT, self.WIDTH, 1],
                                         name="images")
            self.labels = tf.placeholder(tf.int64, [None], name="labels")
            self.is_training = tf.placeholder(tf.bool, [], name="is_training")
            self.learning_rate = tf.placeholder_with_default(0.01, None)

            images = 2 * (
                tf.tile(tf.image.convert_image_dtype(self.images, tf.float32),
                        [1, 1, 1, 3]) - 0.5)

            if args.pretrained == 'inception_v3':
                with tf.contrib.slim.arg_scope(
                        inception_v3.inception_v3_arg_scope()):
                    features, _ = inception_v3.inception_v3(images,
                                                            num_classes=None,
                                                            is_training=True)
                    features = tf.squeeze(features, [1, 2])
            else:
                with tf.contrib.slim.arg_scope(
                        nasnet.nasnet_mobile_arg_scope()):
                    features, _ = nasnet.build_nasnet_mobile(images,
                                                             num_classes=None,
                                                             is_training=True)
            self.nasnet_saver = tf.train.Saver()

            nasnet_features = features

            # Computation and training.
            #
            # The code below assumes that:
            # - loss is stored in `self.loss`
            # - training is stored in `self.training`
            # - label predictions are stored in `self.predictions`
            with tf.variable_scope('our_beloved_vars'):
                for layer_def in args.model.strip().split(';'):
                    features = get_layer(layer_def, features, self.is_training)
                output = tf.layers.dense(features,
                                         self.LABELS,
                                         activation=None)

                self.predictions = tf.argmax(output, axis=1)
                tf.losses.sparse_softmax_cross_entropy(self.labels, output)

            self.loss = tf.losses.get_total_loss()
            optimizer = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate)
            self.training = tf.contrib.slim.learning.create_train_op(
                self.loss,
                optimizer,
                clip_gradient_norm=args.clip_gradient,
                variables_to_train=tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, 'our_beloved_vars'))
            self.training_all = tf.contrib.slim.learning.create_train_op(
                self.loss,
                optimizer,
                clip_gradient_norm=args.clip_gradient,
                variables_to_train=None)

            # Summaries
            self.accuracy = tf.reduce_mean(
                tf.cast(tf.equal(self.labels, self.predictions), tf.float32))
            summary_writer = tf.contrib.summary.create_file_writer(
                args.logdir, flush_millis=10 * 1000)
            self.summaries = {}
            with summary_writer.as_default(
            ), tf.contrib.summary.record_summaries_every_n_global_steps(10):
                self.summaries["train"] = [tf.contrib.summary.scalar("train/loss", self.loss),
                                           tf.contrib.summary.scalar("train/lr", self.learning_rate),
                                           tf.contrib.summary.scalar("train/accuracy", self.accuracy)]\
                                          + variable_summaries(nasnet_features, 'pretrained')\
                                          + variable_summaries(features, 'near_output')
            with summary_writer.as_default(
            ), tf.contrib.summary.always_record_summaries():
                self.given_loss = tf.placeholder(tf.float32, [],
                                                 name="given_loss")
                self.given_accuracy = tf.placeholder(tf.float32, [],
                                                     name="given_accuracy")
                for dataset in ["dev", "test"]:
                    self.summaries[dataset] = [
                        tf.contrib.summary.scalar(dataset + "/loss",
                                                  self.given_loss),
                        tf.contrib.summary.scalar(dataset + "/accuracy",
                                                  self.given_accuracy)
                    ]

            # Initialize variables
            self.session.run(tf.global_variables_initializer())
            with summary_writer.as_default():
                tf.contrib.summary.initialize(session=self.session,
                                              graph=self.session.graph)

            self.nasnet_saver.restore(self.session,
                                      self.CHECKPOINTS[args.pretrained])
Example #24
def batch_prediction(frame_id_to_path, frame_id_to_image_ids, image_id_to_coordinates, model, image_size, sess, \
                    debug=_prediction_debug):
    print "batch processing: " + str(len(image_id_to_coordinates))
    if model in ('inception_v1', 'inception_v2', 'inception_v3', 'inception_v4',
                 'mobilenet_v1_0.25_128', 'mobilenet_v1_0.50_160',
                 'mobilenet_v1_1.0_224', 'inception_resnet_v2',
                 'nasnet_mobile', 'nasnet_large'):
        preprocessing_type = 'inception'
    elif model in ('vgg_16', 'resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152'):
        preprocessing_type = 'vgg'
    image_id_to_predictions = {}
    image_ids = []
    count = 0
    start_time_1 = time.time()
    for frame_id, path in frame_id_to_path.iteritems():
        frame_string = open(path, 'rb').read()
        frame = tf.image.decode_jpeg(frame_string, channels=3)
        #plt.imshow(PIL.Image.open(StringIO.StringIO(sess.run(tf.image.encode_jpeg(frame)))))
        #plt.show()
        frame_np = cv2.imread(path, cv2.IMREAD_COLOR)
        frame_height, frame_width = frame_np.shape[:2]
        #print frame_np.shape
        if preprocessing_type == 'inception':
            processed_frame = preprocess_for_inception(frame,
                                                       frame_height,
                                                       frame_width,
                                                       sess,
                                                       central_fraction=1.0,
                                                       debug=_prediction_debug)
        elif preprocessing_type == 'vgg':
            processed_frame = preprocess_for_vgg(frame,
                                                 frame_height,
                                                 frame_width,
                                                 frame_height,
                                                 sess,
                                                 debug=_prediction_debug)
        start_time = time.time()
        height, width = processed_frame.shape[:2].as_list()
        #print "Size: "+str(width)+", "+str(height)
        #plt.imshow(PIL.Image.open(StringIO.StringIO(sess.run(tf.image.encode_jpeg(tf.cast(processed_frame, tf.uint8))))))
        #plt.show()
        for image_id in frame_id_to_image_ids[frame_id]:
            fields = image_id_to_coordinates[image_id].split('\t')
            x = int(width * float(fields[0]))
            y = int(height * float(fields[1]))
            w = int(width * float(fields[2]))
            h = int(height * float(fields[3]))
            processed_image = tf.image.crop_to_bounding_box(
                processed_frame, y, x, h, w)
            if debug:
                print "object at " + str(fields)
                print str(x) + ", " + str(y) + ", " + str(w) + ", " + str(
                    h) + ", " + str(frame_height - y - h)
                if preprocessing_type == 'vgg':
                    plt.imshow(
                        PIL.Image.open(
                            StringIO.StringIO(
                                sess.run(
                                    tf.image.encode_jpeg(
                                        tf.cast(processed_image, tf.uint8))))))
                elif preprocessing_type == 'inception':
                    plt.imshow(
                        PIL.Image.open(
                            StringIO.StringIO(
                                sess.run(
                                    tf.image.encode_jpeg(
                                        tf.cast(
                                            tf.multiply(processed_image, 255),
                                            tf.uint8))))))
                plt.show()
            processed_image = tf.image.resize_images(processed_image,
                                                     (image_size, image_size))
            if debug:
                print "resized"
                if preprocessing_type == 'vgg':
                    plt.imshow(
                        PIL.Image.open(
                            StringIO.StringIO(
                                sess.run(
                                    tf.image.encode_jpeg(
                                        tf.cast(processed_image, tf.uint8))))))
                elif preprocessing_type == 'inception':
                    plt.imshow(
                        PIL.Image.open(
                            StringIO.StringIO(
                                sess.run(
                                    tf.image.encode_jpeg(
                                        tf.cast(
                                            tf.multiply(processed_image, 255),
                                            tf.uint8))))))
                plt.show()
            if count == 0:
                processed_images = tf.expand_dims(processed_image, 0)
            else:
                local_matrix = tf.expand_dims(processed_image, 0)
                processed_images = tf.concat([processed_images, local_matrix],
                                             0)
            image_ids.append(image_id)
            count = count + 1
    print "Preparation: " + str(time.time() - start_time_1) + " seconds"
    start_time = time.time()
    if model == 'inception_v1':
        logits, _ = inception.inception_v1(processed_images,
                                           num_classes=1001,
                                           is_training=False)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'inception_v1.ckpt'),
            slim.get_model_variables('InceptionV1'))
    elif model == 'inception_v2':
        logits, _ = inception.inception_v2(processed_images,
                                           num_classes=1001,
                                           is_training=False)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'inception_v2.ckpt'),
            slim.get_model_variables('InceptionV2'))
    elif model == 'inception_v3':
        logits, _ = inception.inception_v3(processed_images,
                                           num_classes=1001,
                                           is_training=False)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'inception_v3.ckpt'),
            slim.get_model_variables('InceptionV3'))
    elif model == 'inception_v4':
        logits, _ = inception.inception_v4(processed_images,
                                           num_classes=1001,
                                           is_training=False)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'inception_v4.ckpt'),
            slim.get_model_variables('InceptionV4'))
    elif model == 'resnet_v1_50':
        logits, _ = resnet_v1.resnet_v1_50(processed_images,
                                           num_classes=1000,
                                           is_training=False)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'resnet_v1_50.ckpt'),
            slim.get_model_variables('resnet_v1_50'))
    elif model == 'resnet_v1_101':
        logits, _ = resnet_v1.resnet_v1_101(processed_images,
                                            num_classes=1000,
                                            is_training=False)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'resnet_v1_101.ckpt'),
            slim.get_model_variables('resnet_v1_101'))
    elif model == 'resnet_v1_152':
        logits, _ = resnet_v1.resnet_v1_152(processed_images,
                                            num_classes=1000,
                                            is_training=False)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'resnet_v1_152.ckpt'),
            slim.get_model_variables('resnet_v1_152'))
    elif model == 'mobilenet_v1_0.25_128':
        logits, _ = mobilenet_v1.mobilenet_v1(processed_images, num_classes=1001, is_training=False, \
                                              depth_multiplier=0.25)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'mobilenet_v1_0.25_128.ckpt'),
            slim.get_model_variables('MobilenetV1'))
    elif model == 'mobilenet_v1_0.50_160':
        logits, _ = mobilenet_v1.mobilenet_v1(processed_images, num_classes=1001, is_training=False, \
                                              depth_multiplier=0.50)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'mobilenet_v1_0.50_160.ckpt'),
            slim.get_model_variables('MobilenetV1'))
    elif model == 'mobilenet_v1_1.0_224':
        logits, _ = mobilenet_v1.mobilenet_v1(processed_images, num_classes=1001, is_training=False, \
                                              depth_multiplier=1.0)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'mobilenet_v1_1.0_224.ckpt'),
            slim.get_model_variables('MobilenetV1'))
    elif model == 'inception_resnet_v2':
        logits, _ = inception_resnet_v2.inception_resnet_v2(processed_images,
                                                            num_classes=1001,
                                                            is_training=False)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir,
                         'inception_resnet_v2_2016_08_30.ckpt'),
            slim.get_model_variables('InceptionResnetV2'))
    elif model == 'nasnet_mobile':
        logits, _ = nasnet.build_nasnet_mobile(processed_images,
                                               num_classes=1001,
                                               is_training=False)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'model.ckpt'),
            slim.get_model_variables())
    elif model == 'nasnet_large':
        logits, _ = nasnet.build_nasnet_large(processed_images,
                                              num_classes=1001,
                                              is_training=False)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'model.ckpt'),
            slim.get_model_variables())
    elif model == 'vgg_16':
        logits, _ = vgg.vgg_16(processed_images,
                               num_classes=1000,
                               is_training=False)
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(checkpoints_dir, 'vgg_16.ckpt'),
            slim.get_model_variables('vgg_16'))
    print "Prediction2.1: " + str(time.time() - start_time) + " seconds"
    start_time = time.time()
    init_fn(sess)
    print "Prediction2.2: " + str(time.time() - start_time) + " seconds"
    probabilities = tf.nn.softmax(logits)

    start_time = time.time()
    np_image, probabilities = sess.run([frame, probabilities])
    runtime = time.time() - start_time
    print "Prediction: " + str(runtime) + " seconds"
    for k in range(len(image_ids)):
        image_id = image_ids[k]
        predictions = []
        prob = probabilities[k, 0:]
        sorted_inds = [
            i[0] for i in sorted(enumerate(-prob), key=lambda x: x[1])
        ]
        for i in range(5):
            index = sorted_inds[i]
            if model in ('inception_v1', 'inception_v2', 'inception_v3',
                         'inception_v4', 'mobilenet_v1_0.25_128',
                         'mobilenet_v1_0.50_160', 'mobilenet_v1_1.0_224',
                         'inception_resnet_v2', 'nasnet_mobile',
                         'nasnet_large'):
                name = names[index]
            elif model in ('vgg_16', 'resnet_v1_50', 'resnet_v1_101',
                           'resnet_v1_152'):
                name = names[index + 1]
            pr = prob[index]
            pair = (name, pr)
            predictions.append(pair)
        image_id_to_predictions[image_id] = predictions
    return image_id_to_predictions, runtime, sess
Example #25
def run_training(path_db,
                 pid,
                 category,
                 task_id,
                 path_unknown,
                 pretrained_dir,
                 tensorflow_dir,
                 path_save,
                 num_epochs=1000,
                 batch_size=32,
                 finetune_last_layer=False,
                 data_augmentation=True,
                 mix_up=False,
                 network_model='inception-v3',
                 restore_all_parameters=False,
                 initial_learning_rate=0.0002,
                 learning_rate_decay_factor=0.7,
                 num_epochs_before_decay=2):
    ##### start parameters for creating TFRecord files #####
    #validation_size = 0.1
    validation_size = 0.0
    num_shards = 2
    random_seed = 0
    ##### end parameters for creating TFRecord files #####

    dataset_dir = os.path.join(path_db, pid, category)
    log_dir = path_save
    tfrecord_filename = pid + '_' + category

    if _dataset_exists(dataset_dir=dataset_dir,
                       _NUM_SHARDS=num_shards,
                       output_filename=tfrecord_filename):
        print('Dataset files already exist; they will be overwritten.')

    photo_filenames, class_names = _get_filenames_and_classes(
        dataset_dir, path_unknown)

    # dictionary for class name and class ID
    class_names_to_ids = dict(zip(class_names, range(len(class_names))))

    # number of validation examples
    num_validation = int(validation_size * len(photo_filenames))

    # divide to training and validation data
    random.seed(random_seed)
    random.shuffle(photo_filenames)
    training_filenames = photo_filenames[num_validation:]
    validation_filenames = photo_filenames[:num_validation]

    # find available GPU ID
    gpu_id = gpu_utils.pick_gpu_lowest_memory()

    # if log directory does not exist, create log directory and dataset
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

        print('found lowest memory gpu id : ' + str(gpu_id))
        _convert_dataset(gpu_id,
                         'train',
                         training_filenames,
                         class_names_to_ids,
                         dataset_dir=dataset_dir,
                         tfrecord_filename=tfrecord_filename,
                         _NUM_SHARDS=num_shards)
        _convert_dataset(gpu_id,
                         'validation',
                         validation_filenames,
                         class_names_to_ids,
                         dataset_dir=dataset_dir,
                         tfrecord_filename=tfrecord_filename,
                         _NUM_SHARDS=num_shards)

        labels_to_class_names = dict(zip(range(len(class_names)), class_names))
        write_label_file(labels_to_class_names, dataset_dir)

    print('finished creating dataset ' + tfrecord_filename)

    # start training
    output_label_filepath = os.path.join(dataset_dir, 'labels.txt')

    if network_model not in ('inception-v4', 'inception-v3', 'resnet-v2-50',
                             'resnet-v2-152', 'vgg-16', 'mobilenet-v1',
                             'nasnet-large', 'nasnet-mobile'):
        print("invalid network model : " + network_model)
        sys.exit()

    # find pretrained model
    if os.path.exists(os.path.join(log_dir, 'model.ckpt')):
        checkpoint_file = os.path.join(log_dir, 'model.ckpt')
    else:
        if network_model == 'inception-v4':
            checkpoint_file = os.path.join(
                pretrained_dir, 'inception_resnet_v2_2016_08_30.ckpt')
        elif network_model == 'inception-v3':
            checkpoint_file = os.path.join(pretrained_dir, 'inception_v3.ckpt')
        elif network_model == 'resnet-v2-50':
            checkpoint_file = os.path.join(pretrained_dir, 'resnet_v2_50.ckpt')
        elif network_model == 'resnet-v2-152':
            checkpoint_file = os.path.join(pretrained_dir,
                                           'resnet_v2_152.ckpt')
        elif network_model == 'vgg-16':
            checkpoint_file = os.path.join(pretrained_dir, 'vgg_16.ckpt')
        elif network_model == 'mobilenet-v1':
            checkpoint_file = os.path.join(pretrained_dir,
                                           'mobilenet_v1_1.0_224.ckpt')
        elif network_model == 'nasnet-large':
            checkpoint_file = os.path.join(pretrained_dir,
                                           'nasnet-a_large_04_10_2017',
                                           'model.ckpt')
        elif network_model == 'nasnet-mobile':
            checkpoint_file = os.path.join(pretrained_dir,
                                           'nasnet-a_mobile_04_10_2017',
                                           'model.ckpt')
        else:
            print("invalid network model : " + network_model)
            sys.exit()

    # set image size
    if network_model in ('inception-v4', 'inception-v3', 'resnet-v2-50', 'resnet-v2-152'):
        image_size = 299
    elif network_model in ('vgg-16', 'mobilenet-v1', 'nasnet-mobile'):
        image_size = 224
    elif network_model == 'nasnet-large':
        image_size = 331
    else:
        print("invalid network model : " + network_model)
        sys.exit()

    # create the file pattern of TFRecord files
    file_pattern = tfrecord_filename + '_%s_*.tfrecord'
    file_pattern_for_counting = tfrecord_filename

    labels_to_name, label_list = load_labels(output_label_filepath)
    num_classes = len(label_list)

    # create the dataset description
    items_to_descriptions = {
        'image':
        'A 3-channel RGB coloured image that is either ' +
        ','.join(label_list),
        'label':
        'A label that is as such -- ' + ','.join([
            str(key) + ':' + labels_to_name[key]
            for key in labels_to_name.keys()
        ])
    }

    # start training
    with tf.Graph().as_default() as graph:
        tf.logging.set_verbosity(tf.logging.INFO)

        # create dataset and load one batch
        dataset = get_split('train', dataset_dir, file_pattern,
                            file_pattern_for_counting, labels_to_name,
                            num_classes, items_to_descriptions)
        images, _, labels = load_batch(dataset,
                                       batch_size=batch_size,
                                       data_augmentation=data_augmentation,
                                       mix_up=mix_up,
                                       height=image_size,
                                       width=image_size)

        # number of steps to take before decaying the learning rate and batches per epoch
        num_batches_per_epoch = int(dataset.num_samples / batch_size)
        num_steps_per_epoch = num_batches_per_epoch  # because one step is one batch processed
        decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)

        # create model for inference
        finetune_vars = []
        if network_model == 'inception-v4':
            with slim.arg_scope(inception_resnet_v2_arg_scope()):
                logits, end_points = inception_resnet_v2(
                    images, num_classes=dataset.num_classes, is_training=True)

            finetune_vars = [
                'InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits'
            ]
        elif network_model == 'inception-v3':
            with slim.arg_scope(inception_v3_arg_scope()):
                logits, end_points = inception_v3(
                    images, num_classes=dataset.num_classes, is_training=True)

            finetune_vars = ['InceptionV3/Logits', 'InceptionV3/AuxLogits']
        elif network_model == 'resnet-v2-50':
            with slim.arg_scope(resnet_arg_scope()):
                logits, end_points = resnet_v2_50(
                    images, num_classes=dataset.num_classes, is_training=True)

            finetune_vars = ['resnet_v2_50/logits']
        elif network_model == 'resnet-v2-152':
            with slim.arg_scope(resnet_arg_scope()):
                logits, end_points = resnet_v2_152(
                    images, num_classes=dataset.num_classes, is_training=True)

            finetune_vars = ['resnet_v2_152/logits']
        elif network_model == 'vgg-16':
            with slim.arg_scope(vgg_arg_scope()):
                logits, _ = vgg_16(images,
                                   num_classes=dataset.num_classes,
                                   is_training=True)

            finetune_vars = ['vgg_16/fc8']
        elif network_model == 'mobilenet-v1':
            with slim.arg_scope(mobilenet_v1_arg_scope()):
                logits, end_points = mobilenet_v1(
                    images, num_classes=dataset.num_classes, is_training=True)

            finetune_vars = ['MobilenetV1/Logits']
        elif network_model == 'nasnet-large':
            with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
                logits, end_points = nasnet.build_nasnet_large(
                    images, dataset.num_classes)

            finetune_vars = [
                'final_layer', 'aux_11',
                'cell_stem_0/comb_iter_0/left/global_step'
            ]
        elif network_model == 'nasnet-mobile':
            with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
                logits, end_points = nasnet.build_nasnet_mobile(
                    images, dataset.num_classes)

            finetune_vars = ['final_layer', 'aux_7']
        else:
            print("Invalid network model : " + network_model)
            sys.exit()

        # define the scopes that you want to exclude for restoration
        exclude = []
        if not restore_all_parameters:
            exclude = finetune_vars
        variables_to_restore = slim.get_variables_to_restore(exclude=exclude)

        if mix_up:
            labels.set_shape([batch_size, dataset.num_classes])
            logits.set_shape([batch_size, dataset.num_classes])
            loss = tf.losses.sigmoid_cross_entropy(labels, logits)
        else:
            # perform one-hot-encoding of the labels (Try one-hot-encoding within the load_batch function!)
            one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)

            # performs the equivalent to tf.nn.sparse_softmax_cross_entropy_with_logits but enhanced with checks
            loss = tf.losses.softmax_cross_entropy(
                onehot_labels=one_hot_labels, logits=logits)
        total_loss = tf.losses.get_total_loss(
        )  #obtain the regularization losses as well

        # create the global step for monitoring the learning_rate and training.
        global_step = tf.train.get_or_create_global_step()

        # define your exponentially decaying learning rate
        lr = tf.train.exponential_decay(learning_rate=initial_learning_rate,
                                        global_step=global_step,
                                        decay_steps=decay_steps,
                                        decay_rate=learning_rate_decay_factor,
                                        staircase=True)

        # define optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)

        # create train_op
        if finetune_last_layer:
            variables_to_train = get_variables_to_train_by_scopes(
                finetune_vars)
            print("finetune variables : " + str(variables_to_train))
            train_op = slim.learning.create_train_op(
                total_loss, optimizer, variables_to_train=variables_to_train)
        else:
            train_op = slim.learning.create_train_op(total_loss, optimizer)

        # define prediction matrix
        if network_model in ('inception-v4', 'inception-v3', 'mobilenet-v1',
                             'nasnet-large', 'nasnet-mobile'):
            predictions = tf.argmax(end_points['Predictions'], 1)
            probabilities = end_points['Predictions']
        elif network_model == 'resnet-v2-50' or network_model == 'resnet-v2-152':
            predictions = tf.argmax(end_points['predictions'], 1)
            probabilities = end_points['predictions']
        elif network_model == 'vgg-16':
            predictions = tf.argmax(logits, 1)
            probabilities = tf.nn.softmax(logits)
        else:
            print("Invalid network model : " + network_model)
            sys.exit()
        if mix_up:
            argmax_labels = tf.argmax(labels, 1)
            accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(
                predictions, argmax_labels)
        else:
            accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(
                predictions, labels)
        metrics_op = tf.group(accuracy_update, probabilities)

        # create summaries
        tf.summary.scalar('losses/Total_Loss', total_loss)
        tf.summary.scalar('accuracy', accuracy)
        tf.summary.scalar('learning_rate', lr)
        my_summary_op = tf.summary.merge_all()

        # define the training step function that runs the train_op and metrics_op and updates the global_step concurrently
        def train_step(sess, train_op, global_step):
            # check the time for each sess run
            start_time = time.time()
            total_loss, global_step_count, _ = sess.run(
                [train_op, global_step, metrics_op])
            time_elapsed = time.time() - start_time

            # run the logging to print some results
            logging.info('global step %s: loss: %.4f (%.2f sec/step)',
                         global_step_count, total_loss, time_elapsed)

            return total_loss, int(global_step_count)

        # create a saver function that actually restores the variables from a checkpoint file
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        # define your supervisor for running a managed session
        sv = tf.train.Supervisor(logdir=log_dir,
                                 summary_op=None,
                                 init_fn=restore_fn)

        # run the managed session
        start_train_time = time.time()

        gpu_options = tf.ConfigProto(
            gpu_options=tf.GPUOptions(visible_device_list=str(gpu_id),
                                      per_process_gpu_memory_fraction=0.4))

        with sv.prepare_or_wait_for_session(config=gpu_options) as sess:
            for step in range(num_steps_per_epoch * num_epochs):
                # check if training task is not canceled
                if not controller.check_train_task_alive(
                        pid, category, task_id):
                    print('Training task is canceled.')
                    sv.stop()
                    return False, "", "", output_label_filepath, global_step_count

                # at the start of every epoch, show the vital information:
                if step % num_batches_per_epoch == 0:
                    logging.info('Epoch %s/%s',
                                 step / num_batches_per_epoch + 1, num_epochs)
                    learning_rate_value, accuracy_value = sess.run(
                        [lr, accuracy])
                    logging.info('Current Learning Rate: %s',
                                 learning_rate_value)
                    logging.info('Current Streaming Accuracy: %s',
                                 accuracy_value)

                    # optionally, print your logits and predictions for a sanity check that things are going fine.
                    logits_value, probabilities_value, predictions_value, labels_value = sess.run(
                        [logits, probabilities, predictions, labels])
                    print('logits: \n', logits_value)
                    print('Probabilities: \n', probabilities_value)
                    print('predictions: \n', predictions_value)
                    print('Labels:\n:', labels_value)

                # log the summaries every 10 step.
                if step % 10 == 0:
                    loss, global_step_count = train_step(
                        sess, train_op, sv.global_step)
                    summaries = sess.run(my_summary_op)
                    sv.summary_computed(sess, summaries)

                # if not, simply run the training step
                else:
                    loss, global_step_count = train_step(
                        sess, train_op, sv.global_step)

                # if specific time passes, save model for evaluation
                time_elapsed_train = time.time() - start_train_time
                print('training time : ' + str(time_elapsed_train))

            # log the final training loss and accuracy
            logging.info(
                'Training Progress : %.2f %% ',
                100.0 * step / float(num_steps_per_epoch * num_epochs))
            logging.info('Final Loss: %s', loss)
            logging.info('Global Step: %s', global_step_count)
            logging.info('Final Accuracy: %s', sess.run(accuracy))

            # after all the training has been done, save the log files and checkpoint model
            logging.info('Finished training! Saving model to disk now.')
            sv.saver.save(sess, sv.save_path, global_step=sv.global_step)

            # save graph definition file
            output_graph_filepath = os.path.join(log_dir, 'graph.pb')
            export_graph_command_exec = "./network/export_slim_graph.py"
            if not os.path.exists(export_graph_command_exec):
                print("fatal error, cannot find command : " +
                      export_graph_command_exec)
                sys.exit()
            export_graph_command_env = os.environ.copy()
            export_graph_command_env["CUDA_VISIBLE_DEVICES"] = ''
            export_graph_command = []
            export_graph_command.append(sys.executable)
            export_graph_command.append(export_graph_command_exec)
            export_graph_command.append(network_model)
            export_graph_command.append(str(dataset.num_classes))
            export_graph_command.append(output_graph_filepath)
            print("start exec:" + " ".join(export_graph_command))
            proc = subprocess.Popen(export_graph_command,
                                    env=export_graph_command_env)
            print("export graph process ID=" + str(proc.pid))
            controller.upsert_train_child_process(task_id, proc.pid)
            proc.communicate()
            controller.delete_train_child_process(task_id, proc.pid)
            print("finish exec:" + " ".join(export_graph_command))
            if not controller.check_train_task_alive(pid, category, task_id):
                print('Training task is canceled.')
                sv.stop()
                return False, "", "", output_label_filepath, global_step_count
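            # For reference, export_slim_graph.py is expected to rebuild the
            # selected network and write out its GraphDef, roughly:
            #   tf.train.write_graph(sess.graph_def, log_dir, 'graph.pb',
            #                        as_text=False)
            # (a sketch of the expected behavior, not the actual script)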

            # Save the frozen graph, optimized graph, and quantized graph from the graph definition and checkpoint.
            latest_checkpoint_filepath = tf.train.latest_checkpoint(log_dir)

            # You can check the output node names with tensorflow/tools/graph_transforms::summarize_graph.
            # https://github.com/tensorflow/models/tree/master/research/slim#Export
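            # The node names of the exported GraphDef can also be listed
            # directly in Python (a minimal sketch using the graph.pb written
            # above):
            #   graph_def = tf.GraphDef()
            #   with tf.gfile.GFile(output_graph_filepath, 'rb') as f:
            #       graph_def.ParseFromString(f.read())
            #   for node in graph_def.node:
            #       print(node.name, node.op)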
            output_node_names = ""
            if network_model == 'inception-v4':
                # Note: this node name belongs to the InceptionResnetV2 graph;
                # a plain InceptionV4 build would expose
                # "InceptionV4/Logits/Predictions" instead.
                output_node_names = "InceptionResnetV2/Logits/Predictions"
            elif network_model == 'inception-v3':
                output_node_names = "InceptionV3/AuxLogits/SpatialSqueeze,InceptionV3/Predictions/Reshape_1"
            elif network_model == 'resnet-v2-50':
                output_node_names = "resnet_v2_50/predictions/Reshape_1"
            elif network_model == 'resnet-v2-152':
                output_node_names = "resnet_v2_152/predictions/Reshape_1"
            elif network_model == 'vgg-16':
                output_node_names = "vgg_16/fc8/squeezed"
            elif network_model == 'mobilenet-v1':
                output_node_names = "MobilenetV1/Predictions/Reshape_1"
            elif network_model == 'nasnet-large' or network_model == 'nasnet-mobile':
                output_node_names = "final_layer/predictions"
            else:
                print("Invalid network model : " + network_model)
                sys.exit()

            output_frozen_graph_filepath = os.path.join(
                log_dir, 'frozen_graph.pb')
            freeze_graph_command_exec = os.path.join(
                tensorflow_dir,
                "bazel-bin/tensorflow/python/tools/freeze_graph")
            if not os.path.exists(freeze_graph_command_exec):
                print("fatal error, cannot find command : " +
                      freeze_graph_command_exec)
                sys.exit()
            freeze_graph_command_env = os.environ.copy()
            freeze_graph_command_env["CUDA_VISIBLE_DEVICES"] = ''
            freeze_graph_command = []
            freeze_graph_command.append(freeze_graph_command_exec)
            freeze_graph_command.append("--input_graph=" +
                                        output_graph_filepath)
            freeze_graph_command.append("--input_checkpoint=" +
                                        latest_checkpoint_filepath)
            freeze_graph_command.append("--input_binary=true")
            freeze_graph_command.append("--output_graph=" +
                                        output_frozen_graph_filepath)
            freeze_graph_command.append("--output_node_names=" +
                                        output_node_names)
            print("start exec:" + " ".join(freeze_graph_command))
            proc = subprocess.Popen(freeze_graph_command,
                                    env=freeze_graph_command_env)
            print("freeze graph process ID=" + str(proc.pid))
            controller.upsert_train_child_process(task_id, proc.pid)
            proc.communicate()
            controller.delete_train_child_process(task_id, proc.pid)
            print("finish exec:" + " ".join(freeze_graph_command))
            if not controller.check_train_task_alive(pid, category, task_id):
                print('Training task is canceled.')
                sv.stop()
                return False, "", "", output_label_filepath, global_step_count
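            # Alternatively, freezing can be done in-process via the Python
            # API (a sketch using tensorflow.python.tools.freeze_graph with
            # the same inputs as the CLI call above):
            #   from tensorflow.python.tools import freeze_graph
            #   freeze_graph.freeze_graph(
            #       input_graph=output_graph_filepath, input_saver='',
            #       input_binary=True,
            #       input_checkpoint=latest_checkpoint_filepath,
            #       output_node_names=output_node_names,
            #       restore_op_name='save/restore_all',
            #       filename_tensor_name='save/Const:0',
            #       output_graph=output_frozen_graph_filepath,
            #       clear_devices=True, initializer_nodes='')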

            output_optimized_graph_filepath = os.path.join(
                log_dir, 'optimized_graph.pb')
            optimize_graph_command_exec = os.path.join(
                tensorflow_dir,
                "bazel-bin/tensorflow/python/tools/optimize_for_inference")
            if not os.path.exists(optimize_graph_command_exec):
                print("fatal error, cannot find command : " +
                      optimize_graph_command_exec)
                sys.exit()
            optimize_graph_command_env = os.environ.copy()
            optimize_graph_command_env["CUDA_VISIBLE_DEVICES"] = ''
            optimize_graph_command = []
            optimize_graph_command.append(optimize_graph_command_exec)
            optimize_graph_command.append("--input=" +
                                          output_frozen_graph_filepath)
            optimize_graph_command.append("--output=" +
                                          output_optimized_graph_filepath)
            optimize_graph_command.append("--input_names=input")
            optimize_graph_command.append("--output_names=" +
                                          output_node_names)
            optimize_graph_command.append("--frozen_graph=true")
            print("start exec:" + " ".join(optimize_graph_command))
            proc = subprocess.Popen(optimize_graph_command,
                                    env=optimize_graph_command_env)
            print("optimize graph process ID=" + str(proc.pid))
            controller.upsert_train_child_process(task_id, proc.pid)
            proc.communicate()
            controller.delete_train_child_process(task_id, proc.pid)
            print("finish exec:" + " ".join(optimize_graph_command))
            if not controller.check_train_task_alive(pid, category, task_id):
                print('Training task is canceled.')
                sv.stop()
                return False, "", "", output_label_filepath, global_step_count
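            # optimize_for_inference strips nodes that are only needed during
            # training (e.g. Identity and CheckNumerics ops) and folds batch
            # normalization into the preceding convolution weights, so the
            # optimized graph loads and runs faster at inference time.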

            output_quantized_graph_filepath = os.path.join(
                log_dir, 'quantized_graph.pb')
            quantize_graph_command_exec = os.path.join(
                tensorflow_dir,
                "bazel-bin/tensorflow/tools/quantization/quantize_graph")
            if not os.path.exists(quantize_graph_command_exec):
                print("fatal error, cannot find command : " +
                      quantize_graph_command_exec)
                sys.exit()
            quantize_graph_command_env = os.environ.copy()
            quantize_graph_command_env["CUDA_VISIBLE_DEVICES"] = ''
            quantize_graph_command = []
            quantize_graph_command.append(quantize_graph_command_exec)
            quantize_graph_command.append("--input=" +
                                          output_optimized_graph_filepath)
            quantize_graph_command.append("--output=" +
                                          output_quantized_graph_filepath)
            quantize_graph_command.append("--input_node_names=input")
            quantize_graph_command.append("--output_node_names=" +
                                          output_node_names)
            quantize_graph_command.append("--mode=eightbit")
            print("start exec:" + " ".join(quantize_graph_command))
            proc = subprocess.Popen(quantize_graph_command,
                                    env=quantize_graph_command_env)
            print("quantize graph process ID=" + str(proc.pid))
            controller.upsert_train_child_process(task_id, proc.pid)
            proc.communicate()
            controller.delete_train_child_process(task_id, proc.pid)
            print("finish exec:" + " ".join(quantize_graph_command))
            if not controller.check_train_task_alive(pid, category, task_id):
                print('Training task is canceled.')
                sv.stop()
                return False, "", "", output_label_filepath, global_step_count
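            # --mode=eightbit rewrites weights and activations as 8-bit
            # quantized ops, which typically shrinks the graph file to roughly
            # a quarter of its float size at some cost in accuracy.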

    return True, output_optimized_graph_filepath, output_quantized_graph_filepath, output_label_filepath, global_step_count
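
# The exported graphs can be loaded for inference without any of the training
# code above. A minimal sketch (the file path and node names are assumptions
# that must match the export steps):
import tensorflow as tf

def load_graph(graph_filepath):
    # Parse the serialized GraphDef and import it into a fresh graph.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(graph_filepath, 'rb') as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph

# Usage: graph = load_graph('optimized_graph.pb'), then fetch tensors by name,
# e.g. graph.get_tensor_by_name(output_node_names.split(',')[0] + ':0').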
Beispiel #26
0
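# Assumed preamble for this example (the original snippet omits it): the
# imports, the input image size, and a getone helper that extracts one label
# per CSV line.
import numpy as np
import tensorflow as tf
from PIL import Image  # used by the image-loading code below
from nets.nasnet import nasnet  # from the tensorflow/models research/slim nets

slim = tf.contrib.slim
image_size = 224  # NASNet-Mobile default input size (assumed)

def getone(onestr):
    return onestr.strip()  # assumed: strip the newline from one label line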
with open('中文标签.csv', 'r+') as f:  # open the Chinese label file
    labels = list(map(getone, list(f)))
    print(len(labels), type(labels), labels[:5])

checkpoint_file = r'./nasnet-a_mobile_04_10_2017/model.ckpt'  # path of the pretrained checkpoint
sample_images = ['hy.jpg', 'ps.jpg', '72.jpg']  # images to classify
input_imgs = tf.placeholder(tf.float32,
                            [None, image_size, image_size, 3])  # input placeholder

x1 = 2 * (input_imgs / 255.0) - 1.0  # normalize pixels from [0, 255] to [-1, 1]

arg_scope = nasnet.nasnet_mobile_arg_scope()  # get the model arg scope
with slim.arg_scope(arg_scope):
    logits, end_points = nasnet.build_nasnet_mobile(x1,
                                                    num_classes=1001,
                                                    is_training=False)
    prob = end_points['Predictions']
    y = tf.argmax(prob, axis=1)  # output node: index of the top-scoring class
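    # Note: in the 1001-class slim checkpoints, index 0 is a background
    # class, so the label list must account for this offset.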

saver = tf.train.Saver()  # saver for restoring the pretrained weights
with tf.Session() as sess:  # create a session
    saver.restore(sess, checkpoint_file)  # restore the model

    def preimg(img):  # image preprocessing helper
        ch = 3
        if img.mode == 'RGBA':  # handle RGBA images
            ch = 4

        imgnp = np.asarray(img.resize((image_size, image_size)),
                           dtype=np.float32).reshape(image_size, image_size,