def get_style_features(FLAGS):
    """
    For the "style_image", the preprocessing step is:
    1. Resize the shorter side to FLAGS.image_size
    2. Apply central crop
    """
    with tf.Graph().as_default():
        network_fn = nets_factory.get_network_fn(
            FLAGS.loss_model,
            num_classes=1,
            is_training=False)
        image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model,
            is_training=False)

        # Get the style image data
        size = FLAGS.image_size
        img_bytes = tf.read_file(FLAGS.style_image)
        if FLAGS.style_image.lower().endswith('png'):
            image = tf.image.decode_png(img_bytes)
        else:
            image = tf.image.decode_jpeg(img_bytes)
        # image = _aspect_preserving_resize(image, size)

        # Add the batch dimension
        images = tf.expand_dims(image_preprocessing_fn(image, size, size), 0)
        # images = tf.stack([image_preprocessing_fn(image, size, size)])

        _, endpoints_dict = network_fn(images, spatial_squeeze=False)
        features = []
        for layer in FLAGS.style_layers:
            feature = endpoints_dict[layer]
            feature = tf.squeeze(gram(feature), [0])  # remove the batch dimension
            features.append(feature)

        with tf.Session() as sess:
            # Restore variables for loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Make sure the 'generated' directory exists.
            if not os.path.exists('generated'):
                os.makedirs('generated')
            # Path where the cropped style image will be saved
            save_file = 'generated/target_style_' + FLAGS.naming + '.jpg'
            # Write the preprocessed style image to that path
            with open(save_file, 'wb') as f:
                target_image = image_unprocessing_fn(images[0, :])
                value = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))
                f.write(sess.run(value))
                tf.logging.info('Target style pattern is saved to: %s.' % save_file)

            # Return the features of the layers used for measuring style loss.
            return sess.run(features)
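
The gram helper used above is not defined in this snippet; a minimal sketch, assuming the usual normalized Gram matrix used for style losses:

def gram(layer):
    # Gram matrix of a [batch, height, width, channels] activation tensor,
    # normalized by the feature-map size (assumed definition; the actual
    # helper is not shown in this snippet).
    shape = tf.shape(layer)
    num_images, width, height, num_filters = shape[0], shape[1], shape[2], shape[3]
    filters = tf.reshape(layer, tf.stack([num_images, -1, num_filters]))
    grams = tf.matmul(filters, filters, transpose_a=True)
    return grams / tf.to_float(width * height * num_filters)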
 def testGetNetworkFnArgScope(self):
   batch_size = 5
   num_classes = 10
   net = 'cifarnet'
   with self.test_session(use_gpu=True):
     net_fn = nets_factory.get_network_fn(net, num_classes)
     image_size = getattr(net_fn, 'default_image_size', 224)
     with slim.arg_scope([slim.model_variable, slim.variable],
                         device='/CPU:0'):
       inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
       net_fn(inputs)
     weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'CifarNet/conv1')[0]
     self.assertDeviceEqual('/CPU:0', weights.device)
 def testGetNetworkFn(self):
   batch_size = 5
   num_classes = 1000
   for net in nets_factory.networks_map:
     with self.test_session():
       net_fn = nets_factory.get_network_fn(net, num_classes)
       # Most networks use 224 as their default_image_size
       image_size = getattr(net_fn, 'default_image_size', 224)
       inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
       logits, end_points = net_fn(inputs)
       self.assertTrue(isinstance(logits, tf.Tensor))
       self.assertTrue(isinstance(end_points, dict))
       self.assertEqual(logits.get_shape().as_list()[0], batch_size)
       self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
Example #4
 def testGetNetworkFnVideoModels(self):
   batch_size = 5
   num_classes = 400
   for net in ['i3d', 's3dg']:
     with tf.Graph().as_default() as g, self.test_session(g):
       net_fn = nets_factory.get_network_fn(net, num_classes=num_classes)
        # Most networks use 224 as their default_image_size; use half of it here
       image_size = getattr(net_fn, 'default_image_size', 224) // 2
       inputs = tf.random_uniform(
           (batch_size, 10, image_size, image_size, 3))
       logits, end_points = net_fn(inputs)
       self.assertTrue(isinstance(logits, tf.Tensor))
       self.assertTrue(isinstance(end_points, dict))
       self.assertEqual(logits.get_shape().as_list()[0], batch_size)
       self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
Example #5
def main(_):
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as graph:
    dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train',
                                          FLAGS.dataset_dir)
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=FLAGS.is_training)
    image_size = FLAGS.image_size or network_fn.default_image_size
    placeholder = tf.placeholder(name='input', dtype=tf.float32,
                                 shape=[1, image_size, image_size, 3])
    network_fn(placeholder)
    graph_def = graph.as_graph_def()
    with gfile.GFile(FLAGS.output_file, 'wb') as f:
      f.write(graph_def.SerializeToString())
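
A hedged sketch of loading the exported GraphDef back for inspection; the file name is a placeholder:

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Graph().as_default():
    graph_def = tf.GraphDef()
    with gfile.GFile('/tmp/exported_inf_graph.pb', 'rb') as f:  # placeholder path
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')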
Example #6
def get_style_features(FLAGS):
    """
    For the style image, the preprocessing steps are:
    1. Resize the shorter side to FLAGS.image_size
    2. Apply central crop
    """
    with tf.Graph().as_default():
        network_fn = nets_factory.get_network_fn(
            FLAGS.loss_model,
            num_classes=1,
            is_training=False)
        image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model,
            is_training=False)

        size = FLAGS.image_size
        img_bytes = tf.read_file(FLAGS.style_image)
        if FLAGS.style_image.lower().endswith('png'):
            image = tf.image.decode_png(img_bytes)
        else:
            image = tf.image.decode_jpeg(img_bytes)
        # image = _aspect_preserving_resize(image, size)
        images = tf.stack([image_preprocessing_fn(image, size, size)])
        _, endpoints_dict = network_fn(images, spatial_squeeze=False)
        features = []
        for layer in FLAGS.style_layers:
            feature = endpoints_dict[layer]
            feature = tf.squeeze(gram(feature), [0])  # remove the batch dimension
            features.append(feature)

        with tf.Session() as sess:
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)
            if not os.path.exists('static/img/generated'):
                os.makedirs('static/img/generated')
            save_file = 'static/img/generated/target_style_' + FLAGS.naming + '.jpg'
            with open(save_file, 'wb') as f:
                target_image = image_unprocessing_fn(images[0, :])
                value = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))
                f.write(sess.run(value))
                tf.logging.info('Target style pattern is saved to: %s.' % save_file)
            return sess.run(features)
def main(_):
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  if FLAGS.is_video_model and not FLAGS.num_frames:
    raise ValueError(
        'Number of frames must be specified for video models with --num_frames')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as graph:
    dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train',
                                          FLAGS.dataset_dir)
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=FLAGS.is_training)
    image_size = FLAGS.image_size or network_fn.default_image_size
    if FLAGS.is_video_model:
      input_shape = [FLAGS.batch_size, FLAGS.num_frames,
                     image_size, image_size, 3]
    else:
      input_shape = [FLAGS.batch_size, image_size, image_size, 3]
    placeholder = tf.placeholder(name='input', dtype=tf.float32,
                                 shape=input_shape)
    network_fn(placeholder)

    if FLAGS.quantize:
      tf.contrib.quantize.create_eval_graph()

    graph_def = graph.as_graph_def()
    if FLAGS.write_text_graphdef:
      tf.io.write_graph(
          graph_def,
          os.path.dirname(FLAGS.output_file),
          os.path.basename(FLAGS.output_file),
          as_text=True)
    else:
      with gfile.GFile(FLAGS.output_file, 'wb') as f:
        f.write(graph_def.SerializeToString())
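
The exported GraphDef holds the architecture only, with no trained weights; a common follow-up (a sketch with placeholder paths and a model-dependent output node name) is to bake checkpoint weights in with TensorFlow's freeze_graph tool:

from tensorflow.python.tools import freeze_graph

freeze_graph.freeze_graph(
    input_graph='/tmp/exported_inf_graph.pb',  # placeholder: the --output_file above
    input_saver='',
    input_binary=True,
    input_checkpoint='/tmp/model.ckpt-12345',  # placeholder checkpoint
    output_node_names='InceptionV3/Predictions/Reshape_1',  # model dependent
    restore_op_name='save/restore_all',
    filename_tensor_name='save/Const:0',
    output_graph='/tmp/frozen_graph.pb',
    clear_devices=True,
    initializer_nodes='')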
def train():
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=FLAGS.NUM_CLASSES,
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        if hasattr(network_fn, 'rnn_part'):
            load_batch_size = FLAGS.batch_size * deploy_config.num_clones
        else:
            load_batch_size = FLAGS.batch_size
        with tf.device(deploy_config.inputs_device()):
            dataset_size, images, labels, video_name = async_loader.video_inputs(
                FLAGS.dataset_list,
                FLAGS.dataset_dir,
                FLAGS.resize_image_size,
                FLAGS.train_image_size,
                load_batch_size,
                FLAGS.n_steps,
                FLAGS.modality,
                FLAGS.read_stride,
                image_preprocessing_fn,
                shuffle=True,
                label_from_one=(FLAGS.labels_offset > 0),
                length1=FLAGS.length,
                crop=2,
                merge_label=FLAGS.merge_label)
            labels = slim.one_hot_encoding(labels, FLAGS.NUM_CLASSES)
            if hasattr(network_fn, 'rnn_part'):
                assert load_batch_size % FLAGS.n_steps == 0
                total_video_num = int(load_batch_size / FLAGS.n_steps)
                # Split images and labels for cnn
                split_images = tf.split(images, deploy_config.num_clones, 0)
                cnn_labels = labels
                if FLAGS.merge_label:
                    cnn_labels = tf.reshape(cnn_labels,
                                            [1, -1, FLAGS.NUM_CLASSES])
                    cnn_labels = tf.tile(cnn_labels, [FLAGS.n_steps, 1, 1])
                    cnn_labels = tf.reshape(cnn_labels,
                                            [-1, FLAGS.NUM_CLASSES])
                split_cnn_labels = tf.split(cnn_labels,
                                            deploy_config.num_clones, 0)
                # Split labels for rnn
                if not FLAGS.merge_label:
                    split_rnn_labels = tf.reshape(
                        labels,
                        [FLAGS.n_steps, total_video_num, FLAGS.NUM_CLASSES])
                    assert total_video_num % deploy_config.num_clones == 0
                    split_rnn_labels = tf.split(split_rnn_labels,
                                                deploy_config.num_clones, 1)
                    each_video_num = int(total_video_num /
                                         deploy_config.num_clones)
                    split_rnn_labels = [
                        tf.reshape(label, [
                            FLAGS.n_steps * each_video_num, FLAGS.NUM_CLASSES
                        ]) for label in split_rnn_labels
                    ]
                else:
                    split_rnn_labels = tf.split(labels,
                                                deploy_config.num_clones, 0)
            else:
                batch_queue = slim.prefetch_queue.prefetch_queue(
                    [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        if hasattr(network_fn, 'rnn_part'):
            cnn_outputs = []
            end_point_outputs = []

            def clone_bn_part(split_batchs, split_cnn_labels, cnn_outputs,
                              end_point_outputs):
                batch = split_batchs[0]
                split_batchs.remove(batch)
                logits, end_points = network_fn(batch)
                cnn_outputs.append(logits)
                end_point_outputs.append(end_points)
                labels = split_cnn_labels[0]
                split_cnn_labels.remove(labels)
                #############################
                # Specify the loss function #
                #############################
                if 'AuxLogits' in end_points:
                    tf.losses.softmax_cross_entropy(
                        logits=end_points['AuxLogits'],
                        onehot_labels=labels,
                        label_smoothing=FLAGS.label_smoothing,
                        weights=0.4,
                        scope='aux_loss')
                return end_points

            def clone_rnn(cnn_outputs, split_rnn_labels, end_point_outputs):
                cnn_output = cnn_outputs[0]
                cnn_outputs.remove(cnn_output)
                end_point_output = end_point_outputs[0]
                end_point_outputs.remove(end_point_output)
                labels = split_rnn_labels[0]
                split_rnn_labels.remove(labels)
                logits, end_points = network_fn.rnn_part(cnn_output)
                end_points.update(end_point_output)
                #############################
                # Specify the loss function #
                #############################
                tf.losses.softmax_cross_entropy(
                    logits=logits,
                    onehot_labels=labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=1.0)
                return end_points

            # Run the CNN (BN) part; the CNN and RNN need different labels because their sample orders differ
            model_deploy.create_clones(deploy_config,
                                       clone_bn_part, [
                                           split_images, split_cnn_labels,
                                           cnn_outputs, end_point_outputs
                                       ],
                                       gpu_offset=1)
            # Merge on another GPU to avoid transporting data back to the original GPUs
            assert len(
                model_deploy.get_available_gpus()) > deploy_config.num_clones
            with tf.device(deploy_config.clone_device(0)):
                # Concat all clones to one tensor
                cnn_outputs = tf.concat(values=cnn_outputs, axis=0)
                output_shape = cnn_outputs.get_shape().as_list()
                # Reshape to expose the video number dimension
                cnn_outputs = tf.reshape(cnn_outputs,
                                         [FLAGS.n_steps, total_video_num] +
                                         output_shape[1:])
                # Split in the video number dimension, so that each clone has an input for lstm
                cnn_outputs = tf.split(cnn_outputs, deploy_config.num_clones,
                                       1)
                # Merge n_steps and video number dimension
                cnn_outputs = [
                    tf.reshape(output, [-1] + output_shape[1:])
                    for output in cnn_outputs
                ]
            # Run the RNN part on other GPUs (deploy_config.num_clones of them)
            #  clones = model_deploy.create_extra_clones_on_another_gpu(deploy_config, clone_rnn,
            #  [cnn_outputs, split_rnn_labels, end_point_outputs])
            clones = model_deploy.create_clones(
                deploy_config,
                clone_rnn, [cnn_outputs, split_rnn_labels, end_point_outputs],
                gpu_offset=1)
        else:

            def clone_fn(batch_queue):
                """Allows data parallelism by creating multiple clones of network_fn."""
                images, labels = batch_queue.dequeue()
                logits, end_points = network_fn(images)
                #############################
                # Specify the loss function #
                #############################
                if 'AuxLogits' in end_points:
                    tf.losses.softmax_cross_entropy(
                        logits=end_points['AuxLogits'],
                        onehot_labels=labels,
                        label_smoothing=FLAGS.label_smoothing,
                        weights=0.4,
                        scope='aux_loss')
                tf.losses.softmax_cross_entropy(
                    logits=logits,
                    onehot_labels=labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=1.0)
                return end_points

            clones = model_deploy.create_clones(deploy_config, clone_fn,
                                                [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = train_util._configure_learning_rate(global_step)
            optimizer = train_util._configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = train_util._get_variables_to_train()
        # Variables to restore and decay
        variables_to_restore = train_util._get_variables_to_restore()

        # Compute the total loss and the gradients across all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones,
            optimizer,
            var_list=variables_to_train,
            gate_gradients=optimizer.GATE_OP,
            colocate_gradients_with_ops=True)

        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Per-variable learning-rate scaling and gradient clipping
        if not FLAGS.no_decay:
            # Train the restored (pretrained) variables with a 10x smaller learning rate
            lr_mul = {var: 0.1 for var in variables_to_restore}
            clones_gradients = tf.contrib.slim.learning.multiply_gradients(
                clones_gradients, lr_mul)
        if FLAGS.grad_clipping is not None:
            clones_gradients = tf.contrib.slim.learning.clip_gradient_norms(
                clones_gradients, FLAGS.grad_clipping)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=train_util._get_init_fn(variables_to_restore),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None,
            trace_every_n_steps=FLAGS.trace_every_n_steps,
            session_config=sess_config)
Example #9
# Assumed reader call (the original snippet begins mid-statement);
# read_and_decode is a hypothetical helper that parses the image tensors and
# the four digit labels from the TFRecord file:
image, image_raw, label0, label1, label2, label3 = read_and_decode(
    TFRECORD_FILE)

# Shuffle the examples into batches
image_batch, image_batch_raw, label0_batch, label1_batch, label2_batch, label3_batch = tf.train.shuffle_batch(
    [image, image_raw, label0, label1, label2, label3],
    batch_size=BATCH_SIZE,  # batch size
    capacity=10000,  # queue capacity
    min_after_dequeue=2000,  # minimum number of elements left after a dequeue
    num_threads=1,  # number of reader threads
)

# 'alexnet_v2_captcha_multi'
# Define the network architecture
train_network_fn = nets_factory.get_network_fn(
    'alexnet_v2',
    num_classes=CHAR_SET_LEN,  # number of classes per character position (10 digits)
    weight_decay=0.0005,
    is_training=False,  # whether the network is in training mode
)

with tf.Session() as sess:

    # Run the network; it returns one set of logits per character position plus end_points
    X = tf.reshape(x, [BATCH_SIZE, 224, 224, 1])
    logits0, logits1, logits2, logits3, end_points = train_network_fn(X)

    # Predictions for each character position
    predict0 = tf.reshape(logits0, [-1, CHAR_SET_LEN])
    predict0 = tf.argmax(predict0, 1)

    predict1 = tf.reshape(logits1, [-1, CHAR_SET_LEN])
    predict1 = tf.argmax(predict1, 1)
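
    # Assumed continuation (the snippet is truncated here): the remaining two
    # character positions follow the same reshape + argmax pattern.
    predict2 = tf.argmax(tf.reshape(logits2, [-1, CHAR_SET_LEN]), 1)
    predict3 = tf.argmax(tf.reshape(logits3, [-1, CHAR_SET_LEN]), 1)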
def mobilenet_v1(inputs,
                 alpha,
                 sigma,
                 bayer_mask,
                 num_layers=17,
                 num_iters=1,
                 use_anscombe=True,
                 noise_channel=True,
                 isp_model_name=None,
                 num_classes=1001,
                 dropout_keep_prob=0.999,
                 is_training=True,
                 min_depth=8,
                 depth_multiplier=1.0,
                 conv_defs=None,
                 prediction_fn=tf.contrib.layers.softmax,
                 spatial_squeeze=True,
                 reuse=None,
                 scope='MobilenetV1',
                 is_real_data=False):
    """Joint ISP + Mobilenet v1 model for classification.

  Args:
    inputs: a tensor of shape [batch_size, height, width, channels].
    alpha, sigma: noise parameters forwarded to the ISP network.
    bayer_mask: Bayer-pattern mask forwarded to the ISP network.
    isp_model_name: name of the ISP network to fetch from nets_factory.
    num_classes: number of predicted classes.
    dropout_keep_prob: the percentage of activation values that are retained.
    is_training: whether is training or not.
    min_depth: Minimum depth value (number of channels) for all convolution ops.
      Enforced when depth_multiplier < 1, and not an active constraint when
      depth_multiplier >= 1.
    depth_multiplier: Float multiplier for the depth (number of channels)
      for all convolution ops. The value must be greater than zero. Typical
      usage will be to set this value in (0, 1) to reduce the number of
      parameters or computation cost of the model.
    conv_defs: A list of ConvDef namedtuples specifying the net architecture.
    prediction_fn: a function to get predictions out of logits.
    spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
        of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse 'scope' must be given.
    scope: Optional variable_scope.

  Returns:
    logits: the pre-softmax activations, a tensor of size
      [batch_size, num_classes]
    end_points: a dictionary from components of the network to the corresponding
      activation.
    clean_image: the restored image produced by the ISP network and fed into
      the MobilenetV1 backbone.

  Raises:
    ValueError: Input rank is invalid.
  """
    input_shape = inputs.get_shape().as_list()
    if len(input_shape) != 4:
        raise ValueError('Invalid input tensor rank, expected 4, was: %d' %
                         len(input_shape))

    end_points = {}
    # IMPORT HACK.
    from nets import nets_factory
    network_fn = nets_factory.get_network_fn(isp_model_name,
                                             num_classes=num_classes,
                                             weight_decay=0.0,
                                             batch_norm_decay=0.95,
                                             is_training=is_training)

    clean_image, isp_end_points = network_fn(inputs,
                                             alpha=alpha,
                                             sigma=sigma,
                                             bayer_mask=bayer_mask,
                                             use_anscombe=use_anscombe,
                                             noise_channel=noise_channel,
                                             num_layers=num_layers,
                                             num_iters=num_iters,
                                             is_real_data=is_real_data)
    end_points.update(isp_end_points)
    clean_image = tf.maximum(clean_image, 0.0)
    clean_image = tf.minimum(clean_image, 1.0)
    end_points['mobilenet_input'] = clean_image
    clean_image -= 0.5
    clean_image *= 2.0

    with tf.variable_scope(scope,
                           'MobilenetV1', [inputs, num_classes],
                           reuse=reuse) as scope:
        with slim.arg_scope([slim.batch_norm, slim.dropout],
                            is_training=is_training):
            net, mobilenet_end_points = mobilenet_v1_base(
                clean_image,
                scope=scope,
                min_depth=min_depth,
                depth_multiplier=depth_multiplier,
                conv_defs=conv_defs)
            end_points.update(mobilenet_end_points)
            with tf.variable_scope('Logits'):
                kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7])
                net = slim.avg_pool2d(net,
                                      kernel_size,
                                      padding='VALID',
                                      scope='AvgPool_1a')
                end_points['AvgPool_1a'] = net
                # 1 x 1 x 1024
                net = slim.dropout(net,
                                   keep_prob=dropout_keep_prob,
                                   scope='Dropout_1b')
                logits = slim.conv2d(net,
                                     num_classes, [1, 1],
                                     activation_fn=None,
                                     normalizer_fn=None,
                                     scope='Conv2d_1c_1x1')
                if spatial_squeeze:
                    logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
            end_points['Logits'] = logits
            if prediction_fn:
                end_points['Predictions'] = prediction_fn(logits,
                                                          scope='Predictions')
    return logits, end_points, clean_image
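
A minimal call sketch for the function above; the noise parameters, Bayer mask, and ISP network name below are placeholders not taken from the original:

images = tf.random_uniform([8, 224, 224, 3])
logits, end_points, clean_image = mobilenet_v1(
    images,
    alpha=0.01,                             # placeholder noise parameter
    sigma=0.001,                            # placeholder noise parameter
    bayer_mask=tf.ones([8, 224, 224, 3]),   # placeholder Bayer mask
    isp_model_name='isp',                   # hypothetical nets_factory entry
    num_classes=1001,
    is_training=False)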
Example #11
File: train.py  Project: pschang-phy/LIYS
def main(FLAGS):
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            """Build Network"""
            network_fn = nets_factory.get_network_fn(FLAGS.loss_model,
                                                     num_classes=1,
                                                     is_training=False)

            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)
            processed_images = reader.image(FLAGS.batch_size,
                                            FLAGS.image_size,
                                            FLAGS.image_size,
                                            'train2014/',
                                            image_preprocessing_fn,
                                            epochs=FLAGS.epoch)
            generated = model.net(processed_images, training=True)
            processed_generated = [
                image_preprocessing_fn(image, FLAGS.image_size,
                                       FLAGS.image_size) for image in
                tf.unstack(generated, axis=0, num=FLAGS.batch_size)
            ]
            processed_generated = tf.stack(processed_generated)
            _, endpoints_dict = network_fn(tf.concat(
                [processed_generated, processed_images], 0),
                                           spatial_squeeze=False)

            # Log the structure of the loss network
            tf.logging.info(
                'Loss network layers (you can define them in "content_layers" and "style_layers"):'
            )
            for key in endpoints_dict:
                tf.logging.info(key)
            """Build Losses"""
            content_loss = losses.content_loss(endpoints_dict,
                                               FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(
                endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(
                generated)  # use the unprocessed image

            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            # Add Summary for visualization in tensorboard.
            """Add Summary"""
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer,
                                  style_loss_summary[layer])
            tf.summary.image('generated', generated)
            # tf.image_summary('processed_generated', processed_generated)  # May be better?
            tf.summary.image(
                'origin',
                tf.stack([
                    image_unprocessing_fn(image) for image in tf.unstack(
                        processed_images, axis=0, num=FLAGS.batch_size)
                ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)
            """Prepare to Train"""
            global_step = tf.Variable(0, name="global_step", trainable=False)

            variable_to_train = []
            for variable in tf.trainable_variables():
                if not variable.name.startswith(FLAGS.loss_model):
                    variable_to_train.append(variable)
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step, var_list=variable_to_train)

            variables_to_restore = []
            for v in tf.global_variables():
                if not v.name.startswith(FLAGS.loss_model):
                    variables_to_restore.append(v)
            saver = tf.train.Saver(variables_to_restore)

            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])

            # Restore variables for loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables for training model if the checkpoint file exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)
            """Start Training"""
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    """logging"""
                    # print(step)
                    if step % 10 == 0:
                        tf.logging.info(
                            'step: %d, total loss: %f, secs/step: %f' %
                            (step, loss_t, elapsed_time))
                    """summary"""
                    if step % 200 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    """checkpoint"""
                    if step % 1000 == 0:
                        saver.save(sess,
                                   os.path.join(training_path,
                                                'fast-style-model.ckpt'),
                                   global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(
                    sess,
                    os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
            coord.join(threads)
    def __init__(self, network_name, checkpoint_path, batch_size, num_classes,
                 image_size=None, preproc_func_name=None, preproc_threads=2):

        '''
        TensorFlow feature extractor using tf.slim and models/slim.
        Core functionality includes loading the network architecture and
        pretrained weights, setting up an image pre-processing function, and
        creating queues for fast input reading. The main workflow after
        initialization is to first load a list of image files using the
        `enqueue_image_files` function and then push them
        through the network with `feed_forward_batch`.

        For pre-trained networks and some more explanation, check out:
          https://github.com/tensorflow/models/tree/master/slim

        :param network_name: str, network name (e.g. resnet_v1_101)
        :param checkpoint_path: str, full path to checkpoint file to load
        :param batch_size: int, batch size
        :param num_classes: int, number of output classes
        :param image_size: int, width and height to overrule default_image_size (default=None)
        :param preproc_func_name: func, optional override of the default preprocessing (default=None)
        :param preproc_threads: int, number of input threads (default=2)

        '''

        self._network_name = network_name
        self._checkpoint_path = checkpoint_path
        self._batch_size = batch_size
        self._num_classes = num_classes
        self._image_size = image_size
        self._preproc_func_name = preproc_func_name
        self._num_preproc_threads = preproc_threads

        self._global_step = tf.train.get_or_create_global_step()

        # Retrieve the function that returns logits and endpoints
        self._network_fn = nets_factory.get_network_fn(
            self._network_name, num_classes=num_classes, is_training=False)

        # Retrieve the model scope from network factory
        self._model_scope = nets_factory.arg_scopes_map[self._network_name]

        # Fall back to the network's default image size if none was given
        if self._image_size is None:
            self._image_size = self._network_fn.default_image_size

        # Setup the input pipeline with a queue of filenames
        self._filename_queue = tf.FIFOQueue(100000, [tf.string], shapes=[[]], name="filename_queue")
        self._pl_image_files = tf.placeholder(tf.string, shape=[None], name="image_file_list")
        self._enqueue_op = self._filename_queue.enqueue_many([self._pl_image_files])
        self._num_in_queue = self._filename_queue.size()

        # Image reader and preprocessing
        self._batch_from_queue, self._batch_filenames = \
            self._preproc_image_batch(self._batch_size, num_threads=preproc_threads)

        # Either use the placeholder as inputs or feed from queue
        self._image_batch = tf.placeholder_with_default(
            self._batch_from_queue, shape=[None, self._image_size, self._image_size, 3])

        # Retrieve the logits and network endpoints (for extracting activations)
        # Note: endpoints is a dictionary with endpoints[name] = tf.Tensor
        self._logits, self._endpoints = self._network_fn(self._image_batch)

        # Find the checkpoint file
        checkpoint_path = self._checkpoint_path
        if tf.gfile.IsDirectory(self._checkpoint_path):
          checkpoint_path = tf.train.latest_checkpoint(self._checkpoint_path)

        # Load pre-trained weights into the model
        variables_to_restore = slim.get_variables_to_restore()
        restore_fn = slim.assign_from_checkpoint_fn(
            checkpoint_path, variables_to_restore)

        # Start the session and load the pre-trained weights
        self._sess = tf.Session()
        restore_fn(self._sess)

        # Local variables initializer, needed for queues etc.
        self._sess.run(tf.local_variables_initializer())

        # Managing the queues and threads
        self._coord = tf.train.Coordinator()
        self._threads = tf.train.start_queue_runners(coord=self._coord, sess=self._sess)
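
A hedged usage sketch; the class name FeatureExtractor, the checkpoint path, and the file list are placeholders, and enqueue_image_files / feed_forward_batch are the methods named in the docstring but not shown in this snippet:

extractor = FeatureExtractor(          # hypothetical class name for this __init__
    network_name='resnet_v1_101',
    checkpoint_path='/path/to/resnet_v1_101.ckpt',  # placeholder
    batch_size=32,
    num_classes=1000)
extractor.enqueue_image_files(['a.jpg', 'b.jpg'])   # placeholder file list
# feed_forward_batch pushes the queued images through the network; its exact
# signature is not shown in this snippet.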
Example #13
def main(FLAGS):
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            """Build Network"""
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model,
                num_classes=1,
                is_training=False)

            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
            processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                            'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
            generated = model.net(processed_images, training=True)
            processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                                   for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
                                   ]
            processed_generated = tf.stack(processed_generated)
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False)

            # Log the structure of loss network
            tf.logging.info('Loss network layers (you can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)

            """Build Losses"""
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image

            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            # Add Summary for visualization in tensorboard.
            """Add Summary"""
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            # tf.image_summary('processed_generated', processed_generated)  # May be better?
            tf.summary.image('origin', tf.stack([
                image_unprocessing_fn(image) for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
            ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            """Prepare to Train"""
            global_step = tf.Variable(0, name="global_step", trainable=False)

            variable_to_train = []
            for variable in tf.trainable_variables():
                if not variable.name.startswith(FLAGS.loss_model):
                    variable_to_train.append(variable)
            train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)

            variables_to_restore = []
            for v in tf.global_variables():
                if not v.name.startswith(FLAGS.loss_model):
                    variables_to_restore.append(v)
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

            # Restore variables for loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables for training model if the checkpoint file exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            """Start Training"""
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    """logging"""
                    # print(step)
                    if step % 10 == 0:
                        tf.logging.info('step: %d, total loss: %f, secs/step: %f' % (step, loss_t, elapsed_time))
                    """summary"""
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    """checkpoint"""
                    if step % 1000 == 0:
                        saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt'), global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
            coord.join(threads)
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(num_clones=1,
                                                      clone_on_cpu=False,
                                                      replica_id=0,
                                                      num_replicas=1,
                                                      num_ps_tasks=0)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # TODO: integrate data
        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            'resnet_v1_50',
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        # TODO: should write own preprocessing
        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = 'resnet_v1_50'
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        # TODO: data provider needed
        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                tf.losses.softmax_cross_entropy(logits=end_points['AuxLogits'],
                                                onehot_labels=labels,
                                                label_smoothing=0,
                                                weights=0.4,
                                                scope='aux_loss')
            tf.losses.softmax_cross_entropy(logits=logits,
                                            onehot_labels=labels,
                                            label_smoothing=0,
                                            weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            # TODO: may need to add flexibility in optimizer
            optimizer = tf.train.AdamOptimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients across all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master=FLAGS.master,
                            is_chief=(FLAGS.task == 0),
                            init_fn=_get_init_fn(),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs,
                            sync_optimizer=None)
Example #15
def main(_):
    global PIXEL_MEANS

    logging.basicConfig(
        filename='train-%s-%s.log' %
        (FLAGS.net, datetime.datetime.now().strftime('%Y%m%d-%H%M%S')),
        level=logging.DEBUG,
        format='%(asctime)s %(message)s')

    if FLAGS.model:
        try:
            os.makedirs(FLAGS.model)
        except:
            pass

    if FLAGS.finetune:
        print_red("finetune, using RGB with vgg pixel means")
        COLORSPACE = 'RGB'
        PIXEL_MEANS = VGG_PIXEL_MEANS

    X = tf.placeholder(tf.float32, shape=(None, None, None, 3), name="images")
    # ground truth labels
    Y = tf.placeholder(tf.int32, shape=(None, ), name="labels")
    is_training = tf.placeholder(tf.bool, name="is_training")

    if not FLAGS.finetune:
        patch_arg_scopes()
    #with \
    #     slim.arg_scope([slim.batch_norm], decay=0.9, epsilon=5e-4):
    network_fn = nets_factory.get_network_fn(FLAGS.net,
                                             num_classes=FLAGS.classes,
                                             weight_decay=FLAGS.weight_decay,
                                             is_training=is_training)

    logits, _ = network_fn(X - PIXEL_MEANS)
    logits = tf.identity(logits, name='logits')
    # probability of class 1 -- not very useful if FLAGS.classes > 2
    probs = tf.squeeze(tf.slice(tf.nn.softmax(logits), [0, 1], [-1, 1]), 1)

    loss, metrics = cls_loss(logits, Y)
    metric_names = [x.name[:-2] for x in metrics]

    def format_metrics(avg):
        return ' '.join(
            ['%s=%.3f' % (a, b) for a, b in zip(metric_names, list(avg))])

    init_finetune, variables_to_train = None, None
    if FLAGS.finetune:
        print_red("finetune, using RGB with vgg pixel means")
        init_finetune, variables_to_train = setup_finetune(
            FLAGS.finetune, [FLAGS.net + '/logits'])

    global_step = tf.train.create_global_step()
    LR = tf.train.exponential_decay(FLAGS.lr,
                                    global_step,
                                    FLAGS.decay_steps,
                                    FLAGS.decay_rate,
                                    staircase=True)
    if FLAGS.adam:
        print("Using Adam optimizer, reducing LR by 100x")
        optimizer = tf.train.AdamOptimizer(LR / 100)
    else:
        optimizer = tf.train.MomentumOptimizer(learning_rate=LR, momentum=0.9)

    train_op = slim.learning.create_train_op(
        loss,
        optimizer,
        global_step=global_step,
        variables_to_train=variables_to_train)
    saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)

    if FLAGS.size is None:
        FLAGS.size = network_fn.default_image_size
    stream = create_picpac_stream(FLAGS.db, True, FLAGS.size)
    # load validation db
    val_stream = None
    if FLAGS.val_db:
        val_stream = create_picpac_stream(FLAGS.val_db, False, FLAGS.size)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    epoch_steps = FLAGS.epoch_steps
    if epoch_steps is None:
        epoch_steps = (stream.size() + FLAGS.batch - 1) // FLAGS.batch
    best = 0
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        if init_finetune:
            init_finetune(sess)
        if FLAGS.resume:
            saver.restore(sess, FLAGS.resume)

        global_start_time = time.time()
        epoch = 0
        step = 0
        while epoch < FLAGS.max_epochs:
            start_time = time.time()
            cnt, metrics_sum = 0, np.array([0] * len(metrics),
                                           dtype=np.float32)
            progress = tqdm(range(epoch_steps), leave=False)
            for _ in progress:
                meta, images = stream.next()
                feed_dict = {X: images, Y: meta.labels, is_training: True}
                mm, _ = sess.run([metrics, train_op], feed_dict=feed_dict)
                metrics_sum += np.array(mm) * images.shape[0]
                cnt += images.shape[0]
                metrics_txt = format_metrics(metrics_sum / cnt)
                progress.set_description(metrics_txt)
                step += 1
                pass
            stop = time.time()
            msg = 'train epoch=%d step=%d ' % (epoch, step)
            msg += metrics_txt
            msg += ' elapsed=%.3f time=%.3f ' % (stop - global_start_time,
                                                 stop - start_time)
            print_green(msg)
            logging.info(msg)

            epoch += 1

            if (epoch % FLAGS.val_epochs == 0) and val_stream:
                lr = sess.run(LR)
                # evaluation
                Ys, Ps = [], []
                cnt, metrics_sum = 0, np.array([0] * len(metrics),
                                               dtype=np.float32)
                val_stream.reset()
                progress = tqdm(val_stream, leave=False)
                for meta, images in progress:
                    feed_dict = {X: images, Y: meta.labels, is_training: False}
                    p, mm = sess.run([probs, metrics], feed_dict=feed_dict)
                    metrics_sum += np.array(mm) * images.shape[0]
                    cnt += images.shape[0]
                    Ys.extend(list(meta.labels))
                    Ps.extend(list(p))
                    metrics_txt = format_metrics(metrics_sum / cnt)
                    progress.set_description(metrics_txt)
                    pass
                assert cnt == val_stream.size()
                avg = metrics_sum / cnt
                if avg[0] > best:
                    best = avg[0]
                msg = 'valid epoch=%d step=%d ' % (epoch - 1, step)
                msg += metrics_txt
                if FLAGS.classes == 2:
                    # display scikit-learn metrics
                    Ys = np.array(Ys, dtype=np.int32)
                    Ps = np.array(Ps, dtype=np.float32)
                    msg += ' sk_acc=%.3f auc=%.3f' % (accuracy_score(
                        Ys, Ps > 0.5), roc_auc_score(Ys, Ps))
                    pass
                msg += ' lr=%.4f best=%.3f' % (lr, best)
                print_red(msg)
                logging.info(msg)
                #log.write('%d\t%s\t%.4f\n' % (epoch, '\t'.join(['%.4f' % x for x in avg]), best))
            # model saving
            if (epoch % FLAGS.ckpt_epochs == 0) and FLAGS.model:
                ckpt_path = '%s/%d' % (FLAGS.model, epoch)
                saver.save(sess, ckpt_path)
                print('step %d: saved to %s.' % (step, ckpt_path))
            pass
        pass
    pass
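
For reference, a minimal NumPy sketch (not part of the original script) of the batch-size-weighted running average that the training and validation loops above maintain; weighting each batch by its size keeps the epoch average exact even when the last batch is smaller:

import numpy as np

def running_average(batches):
    """batches: iterable of (metric_values, batch_size) pairs."""
    metrics_sum, cnt = None, 0
    for mm, n in batches:
        mm = np.asarray(mm, dtype=np.float32)
        metrics_sum = mm * n if metrics_sum is None else metrics_sum + mm * n
        cnt += n
    return metrics_sum / cnt

# Two batches of different sizes: the smaller batch contributes less weight.
print(running_average([([0.5, 0.9], 32), ([0.7, 0.8], 8)]))  # [0.54 0.88]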
Example #16
def main(FLAGS):
    style_features_t = losses.get_style_features(FLAGS)
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not (os.path.exists(training_path)):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            """创建Network"""
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model,
                num_classes=1,
                is_training=False)

            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)

            """训练图片预处理"""
            processed_images = reader.batch_image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                                  'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
            generated = model.transform_network(processed_images, training=True)
            processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                                   for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
                                   ]
            processed_generated = tf.stack(processed_generated)
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False)
            tf.logging.info('Loss network layers (you can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)

            """创建 Losses"""
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image

            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            """准备训练"""
            global_step = tf.Variable(0, name="global_step", trainable=False)
            variable_to_train = []
            for variable in tf.trainable_variables():
                # Only train and save the variables of the transform network, not the loss network.
                if not (variable.name.startswith(FLAGS.loss_model)):
                    variable_to_train.append(variable)

            """优化"""
            train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)

            variables_to_restore = []
            for v in tf.global_variables():
                if not (v.name.startswith(FLAGS.loss_model)):
                    variables_to_restore.append(v)
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            """开始训练"""
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    if step % 10 == 0:
                        tf.logging.info(
                            'step: %d, total loss: %f, secs/step: %f, %s' % (step, loss_t, elapsed_time, time.asctime()))
                    """checkpoint"""
                    if step % 50 == 0:
                        tf.logging.info('saving check point...')
                        saver.save(sess, os.path.join(training_path, FLAGS.naming + '.ckpt'), global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
                tf.logging.info('coordinator stop')
            coord.join(threads)
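
The style loss above compares Gram statistics of the loss-network activations against the precomputed style_features_t. A minimal NumPy sketch of a Gram-matrix computation as commonly used in style transfer (the exact normalization inside losses.gram may differ):

import numpy as np

def gram(feature_maps):
    """Gram matrix of feature maps shaped [batch, height, width, channels].

    Each [channels, channels] entry is the inner product of two channel
    activations over all spatial positions, normalized by the total
    number of activations.
    """
    b, h, w, c = feature_maps.shape
    flat = feature_maps.reshape(b, h * w, c)                  # [b, h*w, c]
    return np.matmul(flat.transpose(0, 2, 1), flat) / (h * w * c)

print(gram(np.random.rand(2, 4, 4, 3)).shape)  # (2, 3, 3)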
def main_fun(argv, ctx):
  import math
  import six
  import tensorflow as tf

  from datasets import dataset_factory
  from nets import nets_factory
  from preprocessing import preprocessing_factory

  sys.argv = argv

  slim = tf.contrib.slim

  tf.app.flags.DEFINE_integer(
      'batch_size', 100, 'The number of samples in each batch.')

  tf.app.flags.DEFINE_integer(
      'max_num_batches', None,
      'Max number of batches to evaluate; by default, use all.')

  tf.app.flags.DEFINE_string(
      'master', '', 'The address of the TensorFlow master to use.')

  tf.app.flags.DEFINE_string(
      'checkpoint_path', '/tmp/tfmodel/',
      'The directory where the model was written to or an absolute path to a '
      'checkpoint file.')

  tf.app.flags.DEFINE_string(
      'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.')

  tf.app.flags.DEFINE_integer(
      'num_preprocessing_threads', 4,
      'The number of threads used to create the batches.')

  tf.app.flags.DEFINE_string(
      'dataset_name', 'imagenet', 'The name of the dataset to load.')

  tf.app.flags.DEFINE_string(
      'dataset_split_name', 'test', 'The name of the train/test split.')

  tf.app.flags.DEFINE_string(
      'dataset_dir', None, 'The directory where the dataset files are stored.')

  tf.app.flags.DEFINE_integer(
      'labels_offset', 0,
      'An offset for the labels in the dataset. This flag is primarily used to '
      'evaluate the VGG and ResNet architectures which do not use a background '
      'class for the ImageNet dataset.')

  tf.app.flags.DEFINE_string(
      'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

  tf.app.flags.DEFINE_string(
      'preprocessing_name', None, 'The name of the preprocessing to use. If left '
      'as `None`, then the model_name flag is used.')

  tf.app.flags.DEFINE_float(
      'moving_average_decay', None,
      'The decay to use for the moving average. '
      'If left as None, then moving averages are not used.')

  tf.app.flags.DEFINE_integer(
      'eval_image_size', None, 'Eval image size')

  FLAGS = tf.app.flags.FLAGS

  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  cluster_spec, server = TFNode.start_cluster_server(ctx)

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #tf_global_step = slim.get_or_create_global_step()
    tf_global_step = tf.Variable(0, name="global_step")

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    images, labels = tf.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

    ####################
    # Define the model #
    ####################
    logits, _ = network_fn(images)

    if FLAGS.moving_average_decay:
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()

    predictions = tf.argmax(logits, 1)
    labels = tf.squeeze(labels)

    # Define the metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        'Recall_5': slim.metrics.streaming_recall_at_k(
            logits, labels, 5),
    })

    # Print the summaries to screen.
    for name, value in six.iteritems(names_to_values):
      summary_name = 'eval/%s' % name
      op = tf.summary.scalar(summary_name, value, collections=[])
      op = tf.Print(op, [value], summary_name)
      tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Evaluating %s' % checkpoint_path)

    slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        variables_to_restore=variables_to_restore)
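
The two streaming metrics above accumulate counts across batches; offline they reduce to the following. A minimal NumPy sketch of what Accuracy and Recall_5 measure (not the slim implementation):

import numpy as np

def offline_metrics(logits, labels, k=5):
    """Offline equivalents of streaming_accuracy / streaming_recall_at_k."""
    accuracy = np.mean(np.argmax(logits, axis=1) == labels)
    topk = np.argsort(logits, axis=1)[:, -k:]            # top-k class ids per row
    recall_at_k = np.mean([lab in row for lab, row in zip(labels, topk)])
    return accuracy, recall_at_k

logits = np.random.rand(8, 10)
labels = np.random.randint(0, 10, size=8)
print(offline_metrics(logits, labels))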
def main(_):
    # if not FLAGS.dataset_dir:
    #     raise ValueError('You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)

    graph = tf.Graph()
    with graph.as_default(), tf.device('/cpu:0'):


        # required from data
        num_samples_per_epoch = fileh.root.label_enrollment.shape[0]
        num_batches_per_epoch = int(num_samples_per_epoch / FLAGS.batch_size)

        num_samples_per_epoch_test = fileh.root.label_evaluation.shape[0]
        num_batches_per_epoch_test = int(num_samples_per_epoch_test / FLAGS.batch_size)

        # Create global_step
        global_step = tf.Variable(0, name='global_step', trainable=False)

        ######################
        # Select the network #
        ######################

        is_training = tf.placeholder(tf.bool)

        # num_subjects_development is the number of subjects in the development phase, not the enrollment phase.
        model_speech_fn = nets_factory.get_network_fn(
            FLAGS.model_speech,
            num_classes=num_subjects_development,
            is_training=is_training)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        # with tf.device(deploy_config.inputs_device()):
        """
        Define the placeholders and create the batch tensors.
        """
        speech = tf.placeholder(tf.float32, (20, 80, 40, 1))
        label = tf.placeholder(tf.int32, (1,))
        batch_dynamic = tf.placeholder(tf.int32, ())
        margin_imp_tensor = tf.placeholder(tf.float32, ())

        # Create the batch tensors
        batch_speech, batch_labels = tf.train.batch(
            [speech, label],
            batch_size=batch_dynamic,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)

        #############################
        # Specify the loss function #
        #############################
        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in xrange(FLAGS.num_clones):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % ('tower', i)) as scope:
                        """
                        Two distance metrics are defined:
                           1 - distance_weighted: a weighted average of the distance between the two structures.
                           2 - distance_l2: the regular l2-norm of the two networks' outputs.
                        """
                        ########################################
                        ######## Outputs of two networks #######
                        ########################################
                        features, logits, end_points_speech = model_speech_fn(batch_speech)


                        # One-hot labeling.
                        # num_subjects_development is the number of subjects in the development
                        # phase, not enrollment, because we reuse the network pretrained on the
                        # development set and take the features of the layer prior to the softmax.
                        label_onehot = tf.one_hot(tf.squeeze(batch_labels, [1]), depth=num_subjects_development, axis=-1)

                        # Define loss
                        with tf.name_scope('loss'):
                            loss = tf.reduce_mean(
                                tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=label_onehot))

                        # Accuracy
                        with tf.name_scope('accuracy'):
                            # Evaluate the model
                            correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(label_onehot, 1))

                            # Accuracy calculation
                            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

                            # # ##### call the optimizer ######
                            # # # TODO: call optimizer object outside of this gpu environment
                            # #
                            # # Reuse variables for the next tower.
                            # tf.get_variable_scope().reuse_variables()

        #################################################
        ########### Summary Section #####################
        #################################################

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for all end_points.
        for end_point in end_points_speech:
            x = end_points_speech[end_point]
            summaries.add(tf.summary.scalar('sparsity_speech/' + end_point,
                                            tf.nn.zero_fraction(x)))

    ###########################
    ####### Enrollment ########
    ###########################

    with tf.Session(graph=graph, config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        # Initialization of the network.
        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore, max_to_keep=20)
        coord = tf.train.Coordinator()
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        ################################################
        ############## ENROLLMENT Model ################
        ################################################

        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=FLAGS.checkpoint_dir)
        saver.restore(sess, latest_checkpoint)

        # The model predefinition.
        NumClasses = 2
        NumLogits = 128
        MODEL = np.zeros((NumClasses, NumLogits), dtype=np.float32)

        # Go through the speakers.
        for speaker_id, speaker_class in enumerate(range(1, 3)):

            # The contributing number of utterances.
            NumUtterance = 20
            # Get the indexes for each speaker in the enrollment data
            speaker_index = np.where(fileh.root.label_enrollment[:] == speaker_class)[0]

            # Check the minimum required utterances per speaker.
            assert len(speaker_index) >= NumUtterance, "At least %d utterances are needed for each speaker" % NumUtterance

            # Get the indexes.
            start_idx = speaker_index[0]
            end_idx = min(speaker_index[0] + NumUtterance, speaker_index[-1])

            # print(end_idx-start_idx)

            # Enrollment of the speaker with specific number of utterances.
            speaker_enrollment = fileh.root.utterance_enrollment[start_idx:end_idx, :, :, 0:1]
            label_enrollment = fileh.root.label_enrollment[start_idx:end_idx]

            # Add a singleton dimension for the 3D convolutional operations.
            speaker_enrollment = speaker_enrollment[None, :, :, :, :]

            # Evaluation
            feature = sess.run(
                [features, is_training],
                feed_dict={is_training: True, batch_dynamic: speaker_enrollment.shape[0],
                           batch_speech: speaker_enrollment,
                           batch_labels: label_enrollment.reshape([label_enrollment.shape[0], 1])})

            # Extracting the associated numpy array.
            feature_speaker = feature[0]

            # # # L2-norm along each utterance vector
            # feature_speaker = sklearn.preprocessing.normalize(feature_speaker,norm='l2', axis=1, copy=True, return_norm=False)

            # The extracted features serve as the speaker model.
            speaker_model = feature_speaker

            # Store the speaker model.
            MODEL[speaker_id, :] = speaker_model

        if not os.path.exists(FLAGS.enrollment_dir):
            os.makedirs(FLAGS.enrollment_dir)
        # Save the created model.
        np.save(os.path.join(FLAGS.enrollment_dir, 'MODEL.npy'), MODEL)
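
The script above only builds and saves the speaker models. A hypothetical scoring counterpart, sketched here with illustrative names only (identify_speaker is not part of the source), would load MODEL.npy and rank the enrolled speakers by cosine similarity:

import numpy as np

def identify_speaker(model_path, test_feature):
    """Hypothetical sketch: rank enrolled speakers by cosine similarity.

    model_path points at the MODEL.npy saved above ([NumClasses, NumLogits]);
    test_feature is one utterance feature vector ([NumLogits]).
    """
    model = np.load(model_path)
    model = model / np.linalg.norm(model, axis=1, keepdims=True)
    test = test_feature / np.linalg.norm(test_feature)
    scores = model.dot(test)              # cosine similarity per speaker
    return int(np.argmax(scores)), scores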
def main_fun(argv, ctx):
  import tensorflow as tf
  from tensorflow.python.ops import control_flow_ops
  from datasets import dataset_factory
  from deployment import model_deploy
  from nets import nets_factory
  from preprocessing import preprocessing_factory

  sys.argv = argv

  slim = tf.contrib.slim

  tf.app.flags.DEFINE_integer(
      'num_gpus', 1, 'The number of GPUs to use per node')

  tf.app.flags.DEFINE_boolean('rdma', False, 'Whether to use rdma.')

  tf.app.flags.DEFINE_string(
      'master', '', 'The address of the TensorFlow master to use.')

  tf.app.flags.DEFINE_string(
      'train_dir', '/tmp/tfmodel/',
      'Directory where checkpoints and event logs are written to.')

  tf.app.flags.DEFINE_integer('num_clones', 1,
                              'Number of model clones to deploy.')

  tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
                              'Use CPUs to deploy clones.')

  tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')

  tf.app.flags.DEFINE_integer(
      'num_ps_tasks', 0,
      'The number of parameter servers. If the value is 0, then the parameters '
      'are handled locally by the worker.')

  tf.app.flags.DEFINE_integer(
      'num_readers', 4,
      'The number of parallel readers that read data from the dataset.')

  tf.app.flags.DEFINE_integer(
      'num_preprocessing_threads', 4,
      'The number of threads used to create the batches.')

  tf.app.flags.DEFINE_integer(
      'log_every_n_steps', 10,
      'The frequency with which logs are printed.')

  tf.app.flags.DEFINE_integer(
      'save_summaries_secs', 600,
      'The frequency with which summaries are saved, in seconds.')

  tf.app.flags.DEFINE_integer(
      'save_interval_secs', 600,
      'The frequency with which the model is saved, in seconds.')

  tf.app.flags.DEFINE_integer(
      'task', 0, 'Task id of the replica running the training.')

  ######################
  # Optimization Flags #
  ######################

  tf.app.flags.DEFINE_float(
      'weight_decay', 0.00004, 'The weight decay on the model weights.')

  tf.app.flags.DEFINE_string(
      'optimizer', 'rmsprop',
      'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
      '"ftrl", "momentum", "sgd" or "rmsprop".')

  tf.app.flags.DEFINE_float(
      'adadelta_rho', 0.95,
      'The decay rate for adadelta.')

  tf.app.flags.DEFINE_float(
      'adagrad_initial_accumulator_value', 0.1,
      'Starting value for the AdaGrad accumulators.')

  tf.app.flags.DEFINE_float(
      'adam_beta1', 0.9,
      'The exponential decay rate for the 1st moment estimates.')

  tf.app.flags.DEFINE_float(
      'adam_beta2', 0.999,
      'The exponential decay rate for the 2nd moment estimates.')

  tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')

  tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
                            'The learning rate power.')

  tf.app.flags.DEFINE_float(
      'ftrl_initial_accumulator_value', 0.1,
      'Starting value for the FTRL accumulators.')

  tf.app.flags.DEFINE_float(
      'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')

  tf.app.flags.DEFINE_float(
      'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')

  tf.app.flags.DEFINE_float(
      'momentum', 0.9,
      'The momentum for the MomentumOptimizer and RMSPropOptimizer.')

  tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

  #######################
  # Learning Rate Flags #
  #######################

  tf.app.flags.DEFINE_string(
      'learning_rate_decay_type',
      'exponential',
      'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
      ' or "polynomial"')

  tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

  tf.app.flags.DEFINE_float(
      'end_learning_rate', 0.0001,
      'The minimal end learning rate used by a polynomial decay learning rate.')

  tf.app.flags.DEFINE_float(
      'label_smoothing', 0.0, 'The amount of label smoothing.')

  tf.app.flags.DEFINE_float(
      'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')

  tf.app.flags.DEFINE_float(
      'num_epochs_per_decay', 2.0,
      'Number of epochs after which learning rate decays.')

  tf.app.flags.DEFINE_bool(
      'sync_replicas', False,
      'Whether or not to synchronize the replicas during training.')

  tf.app.flags.DEFINE_integer(
      'replicas_to_aggregate', 1,
      'The number of gradients to collect before updating params.')

  tf.app.flags.DEFINE_float(
      'moving_average_decay', None,
      'The decay to use for the moving average. '
      'If left as None, then moving averages are not used.')

  #######################
  # Dataset Flags #
  #######################

  tf.app.flags.DEFINE_string(
      'dataset_name', 'imagenet', 'The name of the dataset to load.')

  tf.app.flags.DEFINE_string(
      'dataset_split_name', 'train', 'The name of the train/test split.')

  tf.app.flags.DEFINE_string(
      'dataset_dir', None, 'The directory where the dataset files are stored.')

  tf.app.flags.DEFINE_integer(
      'labels_offset', 0,
      'An offset for the labels in the dataset. This flag is primarily used to '
      'evaluate the VGG and ResNet architectures which do not use a background '
      'class for the ImageNet dataset.')

  tf.app.flags.DEFINE_string(
      'model_name', 'inception_v3', 'The name of the architecture to train.')

  tf.app.flags.DEFINE_string(
      'preprocessing_name', None, 'The name of the preprocessing to use. If left '
      'as `None`, then the model_name flag is used.')

  tf.app.flags.DEFINE_integer(
      'batch_size', 32, 'The number of samples in each batch.')

  tf.app.flags.DEFINE_integer(
      'train_image_size', None, 'Train image size')

  tf.app.flags.DEFINE_integer('max_number_of_steps', None,
                              'The maximum number of training steps.')

  #####################
  # Fine-Tuning Flags #
  #####################

  tf.app.flags.DEFINE_string(
      'checkpoint_path', None,
      'The path to a checkpoint from which to fine-tune.')

  tf.app.flags.DEFINE_string(
      'checkpoint_exclude_scopes', None,
      'Comma-separated list of scopes of variables to exclude when restoring '
      'from a checkpoint.')

  tf.app.flags.DEFINE_string(
      'trainable_scopes', None,
      'Comma-separated list of scopes to filter the set of variables to train. '
      'By default, None trains all the variables.')

  tf.app.flags.DEFINE_boolean(
      'ignore_missing_vars', False,
      'Whether to ignore missing variables when restoring from a checkpoint.')

  FLAGS = tf.app.flags.FLAGS
  FLAGS.job_name = ctx.job_name
  FLAGS.task = ctx.task_index
  FLAGS.num_clones = FLAGS.num_gpus
  FLAGS.worker_replicas = len(ctx.cluster_spec['worker'])
  assert(FLAGS.num_ps_tasks == (len(ctx.cluster_spec['ps']) if 'ps' in ctx.cluster_spec else 0))

  def _configure_learning_rate(num_samples_per_epoch, global_step):
    """Configures the learning rate.

    Args:
      num_samples_per_epoch: The number of samples in each epoch of training.
      global_step: The global_step tensor.

    Returns:
      A `Tensor` representing the learning rate.

    Raises:
      ValueError: if FLAGS.learning_rate_decay_type is not recognized.
    """
    decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                      FLAGS.num_epochs_per_decay)
    if FLAGS.sync_replicas:
      decay_steps /= FLAGS.replicas_to_aggregate

    if FLAGS.learning_rate_decay_type == 'exponential':
      return tf.train.exponential_decay(FLAGS.learning_rate,
                                        global_step,
                                        decay_steps,
                                        FLAGS.learning_rate_decay_factor,
                                        staircase=True,
                                        name='exponential_decay_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'fixed':
      return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'polynomial':
      return tf.train.polynomial_decay(FLAGS.learning_rate,
                                       global_step,
                                       decay_steps,
                                       FLAGS.end_learning_rate,
                                       power=1.0,
                                       cycle=False,
                                       name='polynomial_decay_learning_rate')
    else:
      raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                       FLAGS.learning_rate_decay_type)


  def _configure_optimizer(learning_rate):
    """Configures the optimizer used for training.

    Args:
      learning_rate: A scalar or `Tensor` learning rate.

    Returns:
      An instance of an optimizer.

    Raises:
      ValueError: if FLAGS.optimizer is not recognized.
    """
    if FLAGS.optimizer == 'adadelta':
      optimizer = tf.train.AdadeltaOptimizer(
          learning_rate,
          rho=FLAGS.adadelta_rho,
          epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'adagrad':
      optimizer = tf.train.AdagradOptimizer(
          learning_rate,
          initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
    elif FLAGS.optimizer == 'adam':
      optimizer = tf.train.AdamOptimizer(
          learning_rate,
          beta1=FLAGS.adam_beta1,
          beta2=FLAGS.adam_beta2,
          epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'ftrl':
      optimizer = tf.train.FtrlOptimizer(
          learning_rate,
          learning_rate_power=FLAGS.ftrl_learning_rate_power,
          initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
          l1_regularization_strength=FLAGS.ftrl_l1,
          l2_regularization_strength=FLAGS.ftrl_l2)
    elif FLAGS.optimizer == 'momentum':
      optimizer = tf.train.MomentumOptimizer(
          learning_rate,
          momentum=FLAGS.momentum,
          name='Momentum')
    elif FLAGS.optimizer == 'rmsprop':
      optimizer = tf.train.RMSPropOptimizer(
          learning_rate,
          decay=FLAGS.rmsprop_decay,
          momentum=FLAGS.momentum,
          epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'sgd':
      optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    else:
      raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
    return optimizer


  def _add_variables_summaries(learning_rate):
    summaries = []
    for variable in slim.get_model_variables():
      summaries.append(tf.summary.histogram(variable.op.name, variable))
    summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate))
    return summaries


  def _get_init_fn():
    """Returns a function run by the chief worker to warm-start the training.

    Note that the init_fn is only run when initializing the model during the very
    first global step.

    Returns:
      An init function run by the supervisor.
    """
    if FLAGS.checkpoint_path is None:
      return None

    # Warn the user if a checkpoint exists in the train_dir. Then we'll be
    # ignoring the checkpoint anyway.
    if tf.train.latest_checkpoint(FLAGS.train_dir):
      tf.logging.info(
          'Ignoring --checkpoint_path because a checkpoint already exists in %s'
          % FLAGS.train_dir)
      return None

    exclusions = []
    if FLAGS.checkpoint_exclude_scopes:
      exclusions = [scope.strip()
                    for scope in FLAGS.checkpoint_exclude_scopes.split(',')]

    # TODO(sguada) variables.filter_variables()
    variables_to_restore = []
    for var in slim.get_model_variables():
      excluded = False
      for exclusion in exclusions:
        if var.op.name.startswith(exclusion):
          excluded = True
          break
      if not excluded:
        variables_to_restore.append(var)

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Fine-tuning from %s' % checkpoint_path)

    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=FLAGS.ignore_missing_vars)


  def _get_variables_to_train():
    """Returns a list of variables to train.

    Returns:
      A list of variables to train by the optimizer.
    """
    if FLAGS.trainable_scopes is None:
      return tf.trainable_variables()
    else:
      scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]

    variables_to_train = []
    for scope in scopes:
      variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
      variables_to_train.extend(variables)
    return variables_to_train

  # main
  cluster_spec, server = TFNode.start_cluster_server(ctx=ctx, num_gpus=FLAGS.num_gpus, rdma=FLAGS.rdma)
  if ctx.job_name == 'ps':
    # `ps` jobs wait for incoming connections from the workers.
    server.join()
  else:
    # `worker` jobs will actually do the work.
    if not FLAGS.dataset_dir:
      raise ValueError('You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
      #######################
      # Config model_deploy #
      #######################
      deploy_config = model_deploy.DeploymentConfig(
          num_clones=FLAGS.num_clones,
          clone_on_cpu=FLAGS.clone_on_cpu,
          replica_id=FLAGS.task,
          num_replicas=FLAGS.worker_replicas,
          num_ps_tasks=FLAGS.num_ps_tasks)

      # Create global_step
      #with tf.device(deploy_config.variables_device()):
      #  global_step = slim.create_global_step()
      with tf.device("/job:ps/task:0"):
        global_step = tf.Variable(0, name="global_step")

      ######################
      # Select the dataset #
      ######################
      dataset = dataset_factory.get_dataset(
          FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

      ######################
      # Select the network #
      ######################
      network_fn = nets_factory.get_network_fn(
          FLAGS.model_name,
          num_classes=(dataset.num_classes - FLAGS.labels_offset),
          weight_decay=FLAGS.weight_decay,
          is_training=True)

      #####################################
      # Select the preprocessing function #
      #####################################
      preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
      image_preprocessing_fn = preprocessing_factory.get_preprocessing(
          preprocessing_name,
          is_training=True)

      ##############################################################
      # Create a dataset provider that loads data from the dataset #
      ##############################################################
      with tf.device(deploy_config.inputs_device()):
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=FLAGS.num_readers,
            common_queue_capacity=20 * FLAGS.batch_size,
            common_queue_min=10 * FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= FLAGS.labels_offset

        train_image_size = FLAGS.train_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, train_image_size, train_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)
        labels = slim.one_hot_encoding(
            labels, dataset.num_classes - FLAGS.labels_offset)
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2 * deploy_config.num_clones)

      ####################
      # Define the model #
      ####################
      def clone_fn(batch_queue):
        """Allows data parallelism by creating multiple clones of network_fn."""
        images, labels = batch_queue.dequeue()
        logits, end_points = network_fn(images)

        #############################
        # Specify the loss function #
        #############################
        if 'AuxLogits' in end_points:
          tf.losses.softmax_cross_entropy(
              logits=end_points['AuxLogits'], onehot_labels=labels,
              label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
        tf.losses.softmax_cross_entropy(
            logits=logits, onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=1.0)
        return end_points

      # Gather initial summaries.
      summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

      clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
      first_clone_scope = deploy_config.clone_scope(0)
      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by network_fn.
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

      # Add summaries for end_points.
      end_points = clones[0].outputs
      for end_point in end_points:
        x = end_points[end_point]
        summaries.add(tf.summary.histogram('activations/' + end_point, x))
        summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                        tf.nn.zero_fraction(x)))

      # Add summaries for losses.
      for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
        summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

      # Add summaries for variables.
      for variable in slim.get_model_variables():
        summaries.add(tf.summary.histogram(variable.op.name, variable))

      #################################
      # Configure the moving averages #
      #################################
      if FLAGS.moving_average_decay:
        moving_average_variables = slim.get_model_variables()
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)
      else:
        moving_average_variables, variable_averages = None, None

      #########################################
      # Configure the optimization procedure. #
      #########################################
      with tf.device(deploy_config.optimizer_device()):
        learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
        optimizer = _configure_optimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

      if FLAGS.sync_replicas:
        # If sync_replicas is enabled, the averaging will be done in the chief
        # queue runner.
        optimizer = tf.train.SyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_aggregate,
            variable_averages=variable_averages,
            variables_to_average=moving_average_variables,
            replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
            total_num_replicas=FLAGS.worker_replicas)
      elif FLAGS.moving_average_decay:
        # Update ops executed locally by trainer.
        update_ops.append(variable_averages.apply(moving_average_variables))

      # Variables to train.
      variables_to_train = _get_variables_to_train()

      # Optimize over the clones; returns the total loss and the gradients to apply.
      total_loss, clones_gradients = model_deploy.optimize_clones(
          clones,
          optimizer,
          var_list=variables_to_train)
      # Add total_loss to summary.
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Create gradient updates.
      grad_updates = optimizer.apply_gradients(clones_gradients,
                                               global_step=global_step)
      update_ops.append(grad_updates)

      update_op = tf.group(*update_ops)
      train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                        name='train_op')

      # Add the summaries from the first clone. These contain the summaries
      # created by model_fn and either optimize_clones() or _gather_clone_loss().
      summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                         first_clone_scope))

      # Merge all summaries together.
      summary_op = tf.summary.merge(list(summaries), name='summary_op')


      ###########################
      # Kicks off the training. #
      ###########################
      summary_writer = tf.summary.FileWriter("tensorboard_%d" %(ctx.worker_num), graph=tf.get_default_graph())
      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_dir,
          master=server.target,
          is_chief=(FLAGS.task == 0),
          init_fn=_get_init_fn(),
          summary_op=summary_op,
          number_of_steps=FLAGS.max_number_of_steps,
          log_every_n_steps=FLAGS.log_every_n_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs,
          summary_writer=summary_writer,
          sync_optimizer=optimizer if FLAGS.sync_replicas else None)
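
As a sanity check on _configure_learning_rate above, the staircase exponential schedule reduces to simple arithmetic. A minimal sketch mirroring the flag defaults (learning_rate=0.01, num_epochs_per_decay=2.0, learning_rate_decay_factor=0.94):

def exponential_lr(step, num_samples_per_epoch, batch_size=32,
                   learning_rate=0.01, num_epochs_per_decay=2.0,
                   decay_factor=0.94):
    """Staircase exponential decay, mirroring tf.train.exponential_decay."""
    decay_steps = int(num_samples_per_epoch / batch_size * num_epochs_per_decay)
    return learning_rate * decay_factor ** (step // decay_steps)

# With 64,000 samples and batch 32, the LR drops by 0.94x every 4,000 steps.
print(exponential_lr(0, 64000), exponential_lr(4000, 64000))  # 0.01 0.0094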
def async_extract():
    # Check training directory.
    train_dir = FLAGS.train_dir
    if not tf.gfile.IsDirectory(train_dir):
        tf.logging.fatal("Training directory %s not found.", train_dir)
        return

    # Build the TensorFlow graph.
    g = tf.Graph()
    with g.as_default():
        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(FLAGS.model_name,
                                                 num_classes=FLAGS.NUM_CLASSES,
                                                 is_training=False)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        test_size, test_data, test_label, test_names = async_loader.multi_sample_video_inputs(
            FLAGS.dataset_list,
            FLAGS.dataset_dir,
            FLAGS.batch_size,
            FLAGS.n_steps,
            FLAGS.modality,
            FLAGS.read_stride,
            FLAGS.resize_image_size,
            FLAGS.train_image_size,
            image_preprocessing_fn,
            label_from_one=(FLAGS.labels_offset > 0),
            sample_num=25,
            length1=FLAGS.length,
            merge_label=FLAGS.merge_label)
        print("Batch size %d" % test_data.get_shape()[0].value)

        batch_size_per_gpu = FLAGS.batch_size
        global_step_tensor = slim.create_global_step()

        # Calculate the gradients for each model tower.
        predicts, end_points = network_fn(test_data)
        if hasattr(network_fn, 'rnn_part'):
            predicts, end_points_rnn = network_fn.rnn_part(predicts)
            end_points.update(end_points_rnn)
        if not FLAGS.merge_label:
            predicts = tf.split(predicts, FLAGS.n_steps, 0)[-1]
            test_label = tf.split(test_label, FLAGS.n_steps, 0)[-1]
        top_k_op = tf.nn.in_top_k(predicts, test_label, FLAGS.top_k)

        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step_tensor)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_variables())
            variables_to_restore[
                global_step_tensor.op.name] = global_step_tensor
        else:
            variables_to_restore = slim.get_variables_to_restore()

        for var in variables_to_restore:
            print("Will restore %s" % (var.op.name))
        saver = tf.train.Saver(variables_to_restore)
        sv = tf.train.Supervisor(graph=g,
                                 logdir=FLAGS.eval_dir,
                                 summary_op=None,
                                 summary_writer=None,
                                 global_step=None,
                                 saver=None)
        g.finalize()

        tf.logging.info("Starting evaluation at " +
                        time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()))
        model_path = tf.train.latest_checkpoint(FLAGS.train_dir)
        if not model_path:
            tf.logging.info("Skipping evaluation. No checkpoint found in: %s",
                            FLAGS.train_dir)
        else:
            with sv.managed_session(FLAGS.master,
                                    start_standard_services=False,
                                    config=None) as sess:
                # Load model from checkpoint.
                tf.logging.info("Loading model from checkpoint: %s",
                                model_path)
                saver.restore(sess, model_path)
                global_step = tf.train.global_step(sess,
                                                   global_step_tensor.name)
                tf.logging.info("Successfully loaded %s at global step = %d.",
                                os.path.basename(model_path), global_step)

                # Start the queue runners.
                sv.start_queue_runners(sess)

                # Run evaluation on the latest checkpoint.
                print("Extracting......")
                num_eval_batches = int(
                    math.ceil(
                        float(test_size) / float(batch_size_per_gpu) *
                        float(FLAGS.n_steps)))
                assert (num_eval_batches * batch_size_per_gpu /
                        FLAGS.n_steps) == test_size
                correct = 0
                count = 0
                for i in xrange(num_eval_batches):
                    test_start_time = time.time()
                    ret, pre, name = sess.run([top_k_op, predicts, test_names])
                    correct += np.sum(ret)
                    for b in xrange(pre.shape[0]):
                        fp = open(
                            '%s/%s' %
                            (FLAGS.feature_dir, os.path.basename(name[b])),
                            'a')
                        for f in xrange(pre.shape[1]):
                            fp.write('%f ' % pre[b, f])
                        fp.write('\n')
                        fp.close()
                    test_duration = time.time() - test_start_time
                    count += len(ret)
                    cur_accuracy = float(correct) * 100 / count

                    test_examples_per_sec = float(
                        batch_size_per_gpu) / test_duration

                    if i % 100 == 0:
                        msg = '{:>6.2f}%, {:>6}/{:<6}'.format(
                            cur_accuracy, count, test_size)
                        format_str = (
                            '%s: batch %d, accuracy=%s, (%.1f examples/sec; %.3f '
                            'sec/batch)')
                        print(format_str %
                              (datetime.now(), i, msg, test_examples_per_sec,
                               test_duration))

                msg = '{:>6.2f}%, {:>6}/{:<6}'.format(cur_accuracy, count,
                                                      test_size)
                format_str = (
                    '%s: total batch %d, accuracy=%s, (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), num_eval_batches, msg,
                                    test_examples_per_sec, test_duration))
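
top_k_op above counts one hit per example whose true label ranks among the k largest logits. A minimal NumPy equivalent of tf.nn.in_top_k for reference (TF resolves ties more permissively):

import numpy as np

def in_top_k(predictions, targets, k):
    """NumPy sketch of tf.nn.in_top_k: True where targets[i] is among the
    k largest entries of predictions[i]."""
    topk = np.argsort(predictions, axis=1)[:, -k:]
    return np.array([t in row for t, row in zip(targets, topk)])

preds = np.array([[0.1, 0.6, 0.3], [0.8, 0.1, 0.1]])
print(in_top_k(preds, np.array([2, 0]), k=2))  # [ True  True]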
Example #21
File: train.py  Project: LeeMax117/LCZ
#############################################################
# count the step
global_step = tf.Variable(0, trainable=False)
increment_op = tf.assign_add(global_step, tf.constant(1))

# The raw formulation of cross-entropy,
#
#   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
#                                 reduction_indices=[1]))
#
# can be numerically unstable.
#

network_fn = nets_factory.get_network_fn(
    'M_inception_v4',
    num_classes=17,
    weight_decay=0.00004,
    is_training=is_training)


y, end_points = network_fn(x)

cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
# l2_loss = tf.add_n( [tf.nn.l2_loss(w) for w in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)] )
# total_loss = cross_entropy + 7e-5*l2_loss

#########################################################################################
# exponential decay learning rate
# learning_rate = tf.train.exponential_decay(0.01, global_step, decay_steps=1, decay_rate=0.9997, staircase=False)
learning_rate = tf.placeholder(tf.float32)
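
The comment block in the fragment above explains why the raw cross-entropy formulation is numerically unstable. For reference, a minimal NumPy sketch of the log-sum-exp shift that makes it stable (a sketch of the idea, not the tf.nn.softmax_cross_entropy_with_logits implementation):

import numpy as np

def stable_softmax_cross_entropy(onehot_labels, logits):
    """Cross-entropy from logits via a log-sum-exp shift.

    Subtracting the per-row maximum before exponentiating avoids the
    overflow that -sum(y_ * log(softmax(y))) hits for large logits.
    """
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(onehot_labels * log_softmax).sum(axis=1)

# Large logits would overflow np.exp in the raw formulation; this stays finite.
print(stable_softmax_cross_entropy(np.array([[0., 1.]]),
                                   np.array([[1000., 1002.]])))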
Example #22
def main(_):


    tf.logging.set_verbosity(tf.logging.INFO)

    graph = tf.Graph()
    with graph.as_default(), tf.device('/cpu:0'):
        ######################
        # Config model_deploy #
        ######################

        # required from data
        num_samples_per_epoch = train_data['mouth'].shape[0]
        num_batches_per_epoch = int(num_samples_per_epoch / FLAGS.batch_size)

        num_samples_per_epoch_test = test_data['mouth'].shape[0]
        num_batches_per_epoch_test = int(num_samples_per_epoch_test / FLAGS.batch_size)

        # Create global_step
        global_step = tf.Variable(0, name='global_step', trainable=False)

        #########################################
        # Configure the learning rate. #
        #########################################
        learning_rate = _configure_learning_rate(num_samples_per_epoch, global_step)
        opt = _configure_optimizer(learning_rate)

        ######################
        # Select the network #
        ######################
        is_training = tf.placeholder(tf.bool)

        network_speech_fn = nets_factory.get_network_fn(
            FLAGS.model_speech_name,
            num_classes=2,
            weight_decay=FLAGS.weight_decay,
            is_training=is_training)

        network_mouth_fn = nets_factory.get_network_fn(
            FLAGS.model_mouth_name,
            num_classes=2,
            weight_decay=FLAGS.weight_decay,
            is_training=is_training)

        #####################################
        # Select the preprocessing function #
        #####################################

        # TODO: Do some preprocessing if necessary.

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        # with tf.device(deploy_config.inputs_device()):
        """
        Define the placeholders and create the batch tensors.
        """

        # Mouth spatial set
        INPUT_SEQ_LENGTH = 9
        INPUT_HEIGHT = 60
        INPUT_WIDTH = 100
        INPUT_CHANNELS = 1
        batch_mouth = tf.placeholder(tf.float32, shape=(
            [None, INPUT_SEQ_LENGTH, INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNELS]))

        # Speech spatial set
        INPUT_SEQ_LENGTH_SPEECH = 15
        INPUT_HEIGHT_SPEECH = 40
        INPUT_WIDTH_SPEECH = 1
        INPUT_CHANNELS_SPEECH = 3
        batch_speech = tf.placeholder(tf.float32, shape=(
            [None, INPUT_SEQ_LENGTH_SPEECH, INPUT_HEIGHT_SPEECH, INPUT_WIDTH_SPEECH, INPUT_CHANNELS_SPEECH]))

        # Label
        batch_labels = tf.placeholder(tf.uint8, (None, 1))
        margin_imp_tensor = tf.placeholder(tf.float32, ())

        ################################
        ## Feed forwarding to network ##
        ################################
        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in xrange(FLAGS.num_clones):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % ('tower', i)) as scope:
                        """
                        Two distance metrics are defined:
                           1 - distance_weighted: a weighted average of the distance between the two structures.
                           2 - distance_l2: the regular l2-norm of the two networks' outputs.
                        """
                        ########################################
                        ######## Outputs of two networks #######
                        ########################################

                        logits_speech, end_points_speech = network_speech_fn(batch_speech)
                        logits_mouth, end_points_mouth = network_mouth_fn(batch_mouth)

                        # # Uncomment to constrain the output embeddings to |f(x)| = 1:
                        # logits_speech = tf.nn.l2_normalize(logits_speech, dim=1, epsilon=1e-12, name=None)
                        # logits_mouth = tf.nn.l2_normalize(logits_mouth, dim=1, epsilon=1e-12, name=None)

                        #################################################
                        ########### Loss Calculation ####################
                        #################################################

                        # ##### Weighted distance using a fully connected layer #####
                        # distance_vector = tf.subtract(logits_speech, logits_mouth,  name=None)
                        # distance_weighted = slim.fully_connected(distance_vector, 1, activation_fn=tf.nn.sigmoid,
                        #                                          normalizer_fn=None,
                        #                                          scope='fc_weighted')

                        ##### Euclidean distance ####
                        distance_l2 = tf.sqrt(
                            tf.reduce_sum(tf.pow(tf.subtract(logits_speech, logits_mouth), 2), 1, keep_dims=True))

                        ##### Contrastive loss ######
                        loss = losses.contrastive_loss(batch_labels, distance_l2, margin_imp=margin_imp_tensor,
                                                       scope=scope)
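                        # losses.contrastive_loss above comes from a custom
                        # module that is not part of this listing. A common
                        # formulation (a sketch, not necessarily the author's
                        # exact code) is
                        #   L = y * d^2 + (1 - y) * max(margin - d, 0)^2
                        # which in TF terms could read:
                        #   y = tf.cast(batch_labels, tf.float32)
                        #   loss = tf.reduce_mean(
                        #       y * tf.square(distance_l2) +
                        #       (1.0 - y) * tf.square(tf.maximum(
                        #           margin_imp_tensor - distance_l2, 0.0)))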

                        # ##### call the optimizer ######
                        # # TODO: call optimizer object outside of this gpu environment
                        #
                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()

                        # Calculate the gradients for the batch of data on this CIFAR tower.
                        grads = opt.compute_gradients(loss)

                        # Keep track of the gradients across all towers.
                        tower_grads.append(grads)


        # Calculate the mean of each gradient.
        grads = average_gradients(tower_grads)

        # Apply the gradients to adjust the shared variables.
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Track the moving averages of all trainable variables.
        MOVING_AVERAGE_DECAY = 0.9999
        variable_averages = tf.train.ExponentialMovingAverage(
            MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = variable_averages.apply(tf.trainable_variables())

        # Group all updates into a single train op.
        train_op = tf.group(apply_gradient_op, variables_averages_op)

        #################################################
        ########### Summary Section #####################
        #################################################

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for all end_points.
        for end_point in end_points_speech:
            x = end_points_speech[end_point]
            # summaries.add(tf.summary.histogram('activations_speech/' + end_point, x))
            summaries.add(tf.summary.scalar('sparsity_speech/' + end_point,
                                            tf.nn.zero_fraction(x)))

        for end_point in end_points_mouth:
            x = end_points_mouth[end_point]
            # summaries.add(tf.summary.histogram('activations_mouth/' + end_point, x))
            summaries.add(tf.summary.scalar('sparsity_mouth/' + end_point,
                                            tf.nn.zero_fraction(x)))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # Add parameters to summaries.
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))
        summaries.add(tf.summary.scalar('eval/Loss', loss))
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    ######## Training #########
    ###########################

    with tf.Session(graph=graph, config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        # Initialization of the network.
        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore, max_to_keep=20)
        coord = tf.train.Coordinator()
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # Restore the model
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=FLAGS.checkpoint_dir)
        saver.restore(sess, latest_checkpoint)

        # op to write logs to Tensorboard
        summary_writer = tf.summary.FileWriter(FLAGS.test_dir, graph=graph)

        ###################################################
        ############################ TEST  ################
        ###################################################
        score_dissimilarity_vector = np.zeros((FLAGS.batch_size * num_batches_per_epoch_test, 1))
        label_vector = np.zeros((FLAGS.batch_size * num_batches_per_epoch_test, 1))

        # Loop over all batches
        for i in range(num_batches_per_epoch_test):
            start_idx = i * FLAGS.batch_size
            end_idx = (i + 1) * FLAGS.batch_size
            speech_test = test_data['speech'][start_idx:end_idx]
            mouth_test = test_data['mouth'][start_idx:end_idx]
            label_test = test_label[start_idx:end_idx]

            # # Uncomment if standardization (mean subtraction and scaling) is needed
            # speech_test = (speech_test - mean_speech) / std_speech
            # mouth_test = (mouth_test - mean_mouth) / std_mouth

            # Evaluation phase
            # WARNING: margin_imp_tensor has no effect here, but its placeholder still requires a value to be fed.
            loss_value, score_dissimilarity, _ = sess.run([loss, distance_l2, is_training],
                                                          feed_dict={is_training: False,
                                                                     margin_imp_tensor: 50,
                                                                     batch_speech: speech_test,
                                                                     batch_mouth: mouth_test,
                                                                     batch_labels: label_test.reshape(
                                                                         [FLAGS.batch_size, 1])})
            if (i + 1) % FLAGS.log_every_n_steps == 0:
                print("TESTING: Minibatch %d of %d" % (i + 1, num_batches_per_epoch_test))
            score_dissimilarity_vector[start_idx:end_idx] = score_dissimilarity
            label_vector[start_idx:end_idx] = label_test

        ##############################
        ##### K-fold validation ######
        ##############################
        K = 10
        EER = np.zeros((K, 1))
        AUC = np.zeros((K, 1))
        AP = np.zeros((K, 1))
        batch_k_validation = int(label_vector.shape[0] / float(K))

        for i in range(K):
            EER[i, :], AUC[i, :], AP[i, :], fpr, tpr = calculate_roc.calculate_eer_auc_ap(
                label_vector[i * batch_k_validation:(i + 1) * batch_k_validation],
                score_dissimilarity_vector[i * batch_k_validation:(i + 1) * batch_k_validation])

        # Print the Equal Error Rate (EER), Area Under the Curve (AUC) and Average Precision (AP).
        print("TESTING: EER= " + str(np.mean(EER, axis=0)) + ", AUC= " + str(
            np.mean(AUC, axis=0)) + ", AP= " + str(np.mean(AP, axis=0)))
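
# The calculate_roc.calculate_eer_auc_ap helper used above is not part of this
# listing. A hedged sketch with scikit-learn, assuming labels are 1 for genuine
# pairs and 0 for impostors, and that smaller distances mean "more similar"
# (so the distance is negated into a similarity score):
import numpy as np
from sklearn.metrics import roc_curve, auc, average_precision_score

def calculate_eer_auc_ap(label, distance):
    score = -distance  # higher score should mean "same pair"
    fpr, tpr, _ = roc_curve(label, score, pos_label=1)
    AUC = auc(fpr, tpr)
    AP = average_precision_score(label, score)
    # EER: the point where the false-accept rate equals the false-reject rate.
    EER = fpr[np.nanargmin(np.abs(fpr - (1 - tpr)))]
    return EER, AUC, AP, fpr, tpr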
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label, coarse_label] = provider.get(
        ['image', 'label', 'coarse_label'])
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

#    image = tf.image.grayscale_to_rgb(image)

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    images, labels, coarse_labels = tf.train.batch(
        [image, label, coarse_label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)
    coarse_labels = tf.cast(coarse_labels, tf.int32)
    tf.image_summary('image', images, max_images=5)

    ####################
    # Define the model #
    ####################
    logits, _ = network_fn(images)

    if FLAGS.moving_average_decay:
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()

    one_hot_labels = slim.one_hot_encoding(labels, 2)
    loss = slim.losses.softmax_cross_entropy(logits, one_hot_labels)

    predictions = tf.argmax(logits, 1)
    labels = tf.squeeze(labels)

    # Define the metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Total_Loss': slim.metrics.streaming_mean(loss),
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
    })

    with tf.variable_scope('coarse_label_accuracy',
                           values=[predictions, labels, coarse_labels]):
      totals = tf.Variable(
          initial_value=tf.zeros([len(dataset.coarse_labels_to_names)]),
          trainable=False,
          collections=[tf.GraphKeys.LOCAL_VARIABLES],
          dtype=tf.float32,
          name='totals')

      counts = tf.Variable(
          initial_value=tf.zeros([len(dataset.coarse_labels_to_names)]),
          trainable=False,
          collections=[tf.GraphKeys.LOCAL_VARIABLES],
          dtype=tf.float32,
          name='counts')

      correct = tf.cast(tf.equal(predictions, labels), tf.int32)
      accuracy_ops = []
      for index, coarse_key in enumerate(dataset.coarse_labels_to_names):
        label_correct = tf.boolean_mask(correct, tf.equal(coarse_key, coarse_labels))
        sum_correct = tf.reduce_sum(label_correct)
        sum_correct = tf.cast(tf.expand_dims(sum_correct, 0), tf.float32)
        delta_totals = tf.SparseTensor([[index]], sum_correct, totals.get_shape())
        label_count = tf.cast(tf.shape(label_correct), tf.float32)
        delta_counts = tf.SparseTensor([[index]], label_count, counts.get_shape())

        totals_compute_op = tf.assign_add(
            totals,
            tf.sparse_tensor_to_dense(delta_totals),
            use_locking=True)
        counts_compute_op = tf.assign_add(
            counts,
            tf.sparse_tensor_to_dense(delta_counts),
            use_locking=True)

        accuracy_ops.append(totals_compute_op)
        accuracy_ops.append(counts_compute_op)
      with tf.control_dependencies(accuracy_ops):
        update_op = tf.select(tf.equal(counts, 0),
                              tf.zeros_like(counts, tf.float32),
                              tf.div(totals, counts))
        names_to_updates['Coarse_Label_Accuracy'] = update_op

    if FLAGS.recall:
      recall_value, recall_update = slim.metrics.streaming_recall_at_k(
          logits, labels, 5)
      names_to_values['Recall@5'] = recall_value
      names_to_updates['Recall@5'] = recall_update

    # Print the summaries to screen.
    # TODO(vonclites) list(d.items()) is for Python 3... check compatibility
    for name, value in list(names_to_values.items()):
      summary_name = 'eval/%s' % name
      op = tf.scalar_summary(summary_name, value, collections=[])
      op = tf.Print(op, [value], summary_name)
      tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

    for index, label_name in list(enumerate(dataset.coarse_labels_to_names.values())):
      summary_name = 'eval/%s' % label_name
      op = tf.scalar_summary(summary_name, update_op[index], collections=[])
      op = tf.Print(op, [update_op[index]], summary_name)
      tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

#    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
#      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
#    else:
#      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Evaluating %s' % FLAGS.checkpoint_path)

    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        eval_interval_secs=FLAGS.eval_interval_secs,
        variables_to_restore=slim.get_variables_to_restore())
Example #24
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients for each clone.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        ###########################
        # Kicks off the training. #
        ###########################
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)
        if FLAGS.checkpoint_path == FLAGS.train_dir:
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir))

        # load pretrained weights
        weight_ini_fn = _get_init_fn()
        weight_ini_fn(sess)

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        for step in xrange(FLAGS.max_number_of_steps):
            start_time = time.time()
            # _, loss_value = sess.run([train_tensor, loss])
            # _, loss_value = sess.run([train_tensor, total_loss])
            loss_value = sess.run(train_tensor)
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % FLAGS.log_every_n_steps == 0:
                # num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                # sec_per_batch = duration / FLAGS.num_gpus
                sec_per_batch = duration

                format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str %
                      (step, loss_value, examples_per_sec, sec_per_batch))

            if step % FLAGS.summary_snapshot_steps == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % FLAGS.model_snapshot_steps == 0 or (
                    step + 1) == FLAGS.max_number_of_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

        print('OK...')
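
# The _configure_learning_rate and _configure_optimizer helpers used above
# mirror the ones in slim's train_image_classifier.py. A trimmed sketch under
# that assumption (only exponential decay and two optimizers shown):
import tensorflow as tf

def _configure_learning_rate(num_samples_per_epoch, global_step):
    # Decay the rate every num_epochs_per_decay epochs.
    decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                      FLAGS.num_epochs_per_decay)
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')

def _configure_optimizer(learning_rate):
    if FLAGS.optimizer == 'adam':
        return tf.train.AdamOptimizer(learning_rate)
    return tf.train.GradientDescentOptimizer(learning_rate)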
def main(_):

    # Log
    tf.logging.set_verbosity(tf.logging.INFO)

    graph = tf.Graph()
    with graph.as_default(), tf.device('/cpu:0'):

        #########################################
        ########## required from data ###########
        #########################################
        num_samples_per_epoch = fileh.root.label_train.shape[0]
        num_batches_per_epoch = int(num_samples_per_epoch / FLAGS.batch_size)

        num_samples_per_epoch_test = fileh.root.label_test.shape[0]
        num_batches_per_epoch_test = int(num_samples_per_epoch_test / FLAGS.batch_size)
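        # `fileh` is assumed to be an already-open PyTables handle to the HDF5
        # dataset, e.g. (an assumption, not shown in this listing):
        #   import tables
        #   fileh = tables.open_file('data.h5', mode='r')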

        # Create global_step
        global_step = tf.Variable(0, name='global_step', trainable=False)

        #####################################
        #### Configure the learning rate. ####
        #####################################
        learning_rate = _configure_learning_rate(num_samples_per_epoch, global_step)
        opt = _configure_optimizer(learning_rate)

        ######################
        # Select the network #
        ######################

        # Training flag.
        is_training = tf.placeholder(tf.bool)

        # Get the network. The number of subjects is num_subjects.
        model_speech_fn = nets_factory.get_network_fn(
            FLAGS.model_speech,
            num_classes=num_subjects,
            weight_decay=FLAGS.weight_decay,
            is_training=is_training)


        #####################################
        # Select the preprocessing function #
        #####################################

        # TODO: Do some preprocessing if necessary.

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        # with tf.device(deploy_config.inputs_device()):
        """
        Define the place holders and creating the batch tensor.
        """
        speech = tf.placeholder(tf.float32, (20, 80, 40, 1))
        label = tf.placeholder(tf.int32, (1))
        batch_dynamic = tf.placeholder(tf.int32, ())
        margin_imp_tensor = tf.placeholder(tf.float32, ())

        # Create the batch tensors
        batch_speech, batch_labels = tf.train.batch(
            [speech, label],
            batch_size=batch_dynamic,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)

        #############################
        # Specify the loss function #
        #############################
        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(FLAGS.num_clones):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % ('tower', i)) as scope:
                        """
                        Two distance metric are defined:
                           1 - distance_weighted: which is a weighted average of the distance between two structures.
                           2 - distance_l2: which is the regular l2-norm of the two networks outputs.
                        Place holders
                        """

                        ########################################
                        ######## Outputs of two networks #######
                        ########################################

                        # Distribute data among all clones equally.
                        step = int(FLAGS.batch_size / float(FLAGS.num_clones))

                        # Network outputs.
                        logits, end_points_speech = model_speech_fn(batch_speech[i * step: (i + 1) * step])


                        ###################################
                        ########## Loss function ##########
                        ###################################
                        # one_hot labeling
                        label_onehot = tf.one_hot(tf.squeeze(batch_labels[i * step : (i + 1) * step], [1]), depth=num_subjects, axis=-1)

                        SOFTMAX = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=label_onehot)

                        # Define loss
                        with tf.name_scope('loss'):
                            loss = tf.reduce_mean(SOFTMAX)

                        # Accuracy
                        with tf.name_scope('accuracy'):
                            # Evaluate the model
                            correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(label_onehot, 1))

                            # Accuracy calculation
                            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

                        # ##### call the optimizer ######
                        # # TODO: call optimizer object outside of this gpu environment
                        #
                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()

                        # Calculate the gradients for the batch of data on this CIFAR tower.
                        grads = opt.compute_gradients(loss)

                        # Keep track of the gradients across all towers.
                        tower_grads.append(grads)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = average_gradients(tower_grads)

        # Apply the gradients to adjust the shared variables.
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Track the moving averages of all trainable variables.
        MOVING_AVERAGE_DECAY = 0.9999
        variable_averages = tf.train.ExponentialMovingAverage(
            MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = variable_averages.apply(tf.trainable_variables())

        # Group all updates into a single train op.
        train_op = tf.group(apply_gradient_op, variables_averages_op)

        #################################################
        ########### Summary Section #####################
        #################################################

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for all end_points.
        for end_point in end_points_speech:
            x = end_points_speech[end_point]
            summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # Add parameters to summaries.
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))
        summaries.add(tf.summary.scalar('global_step', global_step))
        summaries.add(tf.summary.scalar('eval/Loss', loss))
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    ######## Training #########
    ###########################

    with tf.Session(graph=graph, config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        # Initialization of the network.
        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore, max_to_keep=20)
        coord = tf.train.Coordinator()
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # op to write logs to Tensorboard
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, graph=graph)

        #####################################
        ############## TRAIN ################
        #####################################

        step = 1
        for epoch in range(FLAGS.num_epochs):

            # Loop over all batches
            for batch_num in range(num_batches_per_epoch):
                step += 1
                start_idx = batch_num * FLAGS.batch_size
                end_idx = (batch_num + 1) * FLAGS.batch_size
                speech_train, label_train = fileh.root.utterance_train[start_idx:end_idx, :, :,
                                            :], fileh.root.label_train[start_idx:end_idx]

                # This transpose is necessary for the 3D convolution that TensorFlow will perform.
                speech_train = np.transpose(speech_train[None, :, :, :, :], axes=(1, 4, 2, 3, 0))

                # shuffling
                index = random.sample(range(speech_train.shape[0]), speech_train.shape[0])
                speech_train = speech_train[index]
                label_train = label_train[index]


                _, loss_value, train_accuracy, summary, training_step, _ = sess.run(
                    [train_op, loss, accuracy, summary_op, global_step, is_training],
                    feed_dict={is_training: True, batch_dynamic: label_train.shape[0], margin_imp_tensor: 100,
                               batch_speech: speech_train,
                               batch_labels: label_train.reshape([label_train.shape[0], 1])})
                summary_writer.add_summary(summary, epoch * num_batches_per_epoch + batch_num)


                # Log progress.
                if (batch_num + 1) % FLAGS.log_every_n_steps == 0:
                    print("Epoch %d, Minibatch %d of %d, Minibatch Loss= %.4f, TRAIN ACCURACY= %.3f" %
                          (epoch + 1, batch_num + 1, num_batches_per_epoch,
                           loss_value, 100 * train_accuracy))

            # Save the model
            saver.save(sess, FLAGS.train_dir, global_step=training_step)

            # ###################################################
            # ############## TEST PER EACH EPOCH ################
            # ###################################################

            label_vector = np.zeros((FLAGS.batch_size * num_batches_per_epoch_test, 1))
            test_accuracy_vector = np.zeros((num_batches_per_epoch_test, 1))

            # Loop over all batches
            for i in range(num_batches_per_epoch_test):
                start_idx = i * FLAGS.batch_size
                end_idx = (i + 1) * FLAGS.batch_size
                speech_test, label_test = fileh.root.utterance_test[start_idx:end_idx, :, :,
                                                           :], fileh.root.label_test[
                                                                       start_idx:end_idx]

                # Get the test batch.
                speech_test = np.transpose(speech_test[None,:,:,:,:],axes=(1,4,2,3,0))

                # Evaluation
                loss_value, test_accuracy, _ = sess.run([loss, accuracy, is_training],
                                                                             feed_dict={is_training: False,
                                                                                        batch_dynamic: FLAGS.batch_size,
                                                                                        margin_imp_tensor: 50,
                                                                                        batch_speech: speech_test,
                                                                                        batch_labels: label_test.reshape(
                                                                                            [FLAGS.batch_size, 1])})
                label_test = label_test.reshape([FLAGS.batch_size, 1])
                label_vector[start_idx:end_idx] = label_test
                test_accuracy_vector[i, :] = test_accuracy


            ##############################
            ##### K-split validation #####
            ##############################
            print("TESTING after finishing the training on: epoch " + str(epoch + 1))
            # print("TESTING accuracy = ", 100 * np.mean(test_accuracy_vector, axis=0))

            K = 4
            Accuracy = np.zeros((K, 1))
            batch_k_validation = int(test_accuracy_vector.shape[0] / float(K))

            for i in range(K):
                Accuracy[i, :] = 100 * np.mean(test_accuracy_vector[i * batch_k_validation:(i + 1) * batch_k_validation], axis=0)

            # Report the K-fold validation results.
            print("Test Accuracy " + str(epoch + 1) +
                  ", Mean= " + "{:.4f}".format(np.mean(Accuracy, axis=0)[0]) +
                  ", std= " + "{:.3f}".format(np.std(Accuracy, axis=0)[0]))
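
# Both tower loops in this listing rely on an average_gradients helper that is
# never shown. A minimal sketch, following the classic CIFAR-10 multi-GPU
# pattern (an assumption about the author's helper):
import tensorflow as tf

def average_gradients(tower_grads):
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars is ((grad0, var), (grad1, var), ...) across towers.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)
        # Variables are shared across towers, so the first tower's var is used.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads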
Example #26
def main():
    parser = argparse.ArgumentParser(description='Slim model generator')
    parser.add_argument('--model_name', type=str, help='')
    parser.add_argument('--export_dir',
                        type=str,
                        default="/tmp/export_dir",
                        help='GCS or local path to save graph files')
    parser.add_argument('--saved_model_dir',
                        type=str,
                        help='GCS or local path to save the generated model')
    parser.add_argument('--batch_size',
                        type=str,
                        default=1,
                        help='batch size to be used in the exported model')
    parser.add_argument('--checkpoint_url',
                        type=str,
                        help='URL to the pretrained compressed checkpoint')
    parser.add_argument('--num_classes',
                        type=int,
                        default=1000,
                        help='number of model classes')
    args = parser.parse_args()

    MODEL = args.model_name
    URL = args.checkpoint_url
    if not validators.url(args.checkpoint_url):
        print('use a valid URL parameter')
        exit(1)
    TMP_DIR = "/tmp/slim_tmp"
    NUM_CLASSES = args.num_classes
    BATCH_SIZE = args.batch_size
    MODEL_FILE_NAME = URL.rsplit('/', 1)[-1]
    EXPORT_DIR = args.export_dir
    SAVED_MODEL_DIR = args.saved_model_dir

    tmp_graph_file = os.path.join(TMP_DIR, MODEL + '_graph.pb')
    export_graph_file = os.path.join(EXPORT_DIR, MODEL + '_graph.pb')
    frozen_file = os.path.join(EXPORT_DIR, 'frozen_graph_' + MODEL + '.pb')

    if not os.path.exists(TMP_DIR):
        os.makedirs(TMP_DIR)

    if not os.path.exists(TMP_DIR + '/' + MODEL_FILE_NAME):
        print("Downloading and decompressing the model checkpoint...")
        response = requests.get(URL, stream=True)
        with open(os.path.join(TMP_DIR, MODEL_FILE_NAME), 'wb') as output:
            output.write(response.content)
            tar = tarfile.open(os.path.join(TMP_DIR, MODEL_FILE_NAME))
            tar.extractall(path=TMP_DIR)
            tar.close()
            print("Model checkpoint downloaded and decompressed to:", TMP_DIR)
    else:
        print("Reusing existing model file ",
              os.path.join(TMP_DIR, MODEL_FILE_NAME))

    checkpoint = glob.glob(TMP_DIR + '/*.ckpt*')
    print("checkpoint", checkpoint)
    if len(checkpoint) > 0:
        m = re.match(r"([\S]*.ckpt)", checkpoint[-1])
        print("checkpoint match", m)
        checkpoint = m[0]
        print(checkpoint)
    else:
        print("checkpoint file not detected in " + URL)
        exit(1)

    print("Saving graph def file")
    with tf.Graph().as_default() as graph:

        network_fn = nets_factory.get_network_fn(MODEL,
                                                 num_classes=NUM_CLASSES,
                                                 is_training=False)
        image_size = network_fn.default_image_size
        if BATCH_SIZE == "None" or BATCH_SIZE == "-1":
            batchsize = None
        else:
            batchsize = BATCH_SIZE
        placeholder = tf.placeholder(
            name='input',
            dtype=tf.float32,
            shape=[batchsize, image_size, image_size, 3])
        network_fn(placeholder)
        graph_def = graph.as_graph_def()

        with gfile.GFile(tmp_graph_file, 'wb') as f:
            f.write(graph_def.SerializeToString())
    if urlparse(EXPORT_DIR).scheme == 'gs':
        upload_to_gcs(tmp_graph_file, export_graph_file)
    elif urlparse(EXPORT_DIR).scheme == '':
        if not os.path.exists(EXPORT_DIR):
            os.makedirs(EXPORT_DIR)
        copyfile(tmp_graph_file, export_graph_file)
    else:
        print("Invalid format of model export path")
    print("Graph file saved to ", os.path.join(EXPORT_DIR,
                                               MODEL + '_graph.pb'))

    print("Analysing graph")
    p = Popen("./summarize_graph --in_graph=" + tmp_graph_file +
              " --print_structure=false",
              shell=True,
              stdout=PIPE,
              stderr=PIPE)
    summary, err = p.communicate()
    inputs = []
    outputs = []
    for line in summary.split(b'\n'):
        line_str = line.decode()
        if re.match(r"Found [\d]* possible inputs", line_str) is not None:
            print("in", line)
            m = re.findall(r'name=[\S]*,', line.decode())
            for match in m:
                print("match", match)
                input = match[5:-1]
                inputs.append(input)
            print("inputs", inputs)

        if re.match(r"Found [\d]* possible outputs", line_str) is not None:
            print("out", line)
            m = re.findall(r'name=[\S]*,', line_str)
            for match in m:
                print("match", match)
                output = match[5:-1]
                outputs.append(output)
            print("outputs", outputs)

    output_node_names = ",".join(outputs)
    print("Creating freezed graph based on pretrained checkpoint")
    freeze_graph(input_graph=tmp_graph_file,
                 input_checkpoint=checkpoint,
                 input_binary=True,
                 clear_devices=True,
                 input_saver='',
                 output_node_names=output_node_names,
                 restore_op_name="save/restore_all",
                 filename_tensor_name="save/Const:0",
                 output_graph=frozen_file,
                 initializer_nodes="")
    if urlparse(SAVED_MODEL_DIR).scheme == '' and \
            os.path.exists(SAVED_MODEL_DIR):
        shutil.rmtree(SAVED_MODEL_DIR)

    builder = tf.saved_model.builder.SavedModelBuilder(SAVED_MODEL_DIR)

    with tf.gfile.GFile(frozen_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    sigs = {}

    with tf.Session(graph=tf.Graph()) as sess:
        tf.import_graph_def(graph_def, name="")
        g = tf.get_default_graph()
        inp_dic = {}
        for inp in inputs:
            inp_t = g.get_tensor_by_name(inp + ":0")
            inp_dic[inp] = inp_t
        out_dic = {}
        for out in outputs:
            out_t = g.get_tensor_by_name(out + ":0")
            out_dic[out] = out_t

        sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
            tf.saved_model.signature_def_utils.predict_signature_def(
                inp_dic, out_dic)

        builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING],
                                             signature_def_map=sigs)
    print("Exporting saved model to:", SAVED_MODEL_DIR + ' ...')
    builder.save()

    print("Saved model exported to:", SAVED_MODEL_DIR)
    _show_all(SAVED_MODEL_DIR)
    pb_visual_writer = tf.summary.FileWriter(EXPORT_DIR)
    pb_visual_writer.add_graph(sess.graph)
    print("Visualize the model by running: "
          "tensorboard --logdir={}".format(SAVED_MODEL_DIR))
    with open('/tmp/saved_model_dir.txt', 'w') as f:
        f.write(SAVED_MODEL_DIR)
    with open('/tmp/export_dir.txt', 'w') as f:
        f.write(EXPORT_DIR)
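
# Example invocation (script name, paths and checkpoint URL are illustrative):
#   python slim_model_generator.py --model_name resnet_v1_50 \
#       --checkpoint_url http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz \
#       --export_dir /tmp/export_dir --saved_model_dir /tmp/saved_model --num_classes 1000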
Example #27
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=False)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=5 * FLAGS.batch_size,
            common_queue_min=FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= FLAGS.labels_offset

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)

        ####################
        # Define the model #
        ####################

        logits_1, end_points_1 = network_fn(images)

        attention_maps = tf.reduce_mean(end_points_1['attention_maps'],
                                        axis=-1,
                                        keepdims=True)
        attention_maps = tf.image.resize_bilinear(
            attention_maps, [eval_image_size, eval_image_size])
        bboxes = tf.py_func(mask2bbox, [attention_maps], [tf.float32])
        bboxes = tf.reshape(bboxes, [FLAGS.batch_size, 4])
        box_ind = tf.range(FLAGS.batch_size, dtype=tf.int32)

        images = tf.image.crop_and_resize(
            images,
            bboxes,
            box_ind,
            crop_size=[eval_image_size, eval_image_size])
        logits_2, end_points_2 = network_fn(images, reuse=True)

        logits = tf.log(
            tf.nn.softmax(logits_1) * 0.5 + tf.nn.softmax(logits_2) * 0.5)

        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        logits_to_updates = add_eval_summary(logits, labels, scope='/bilinear')
        logits_1_to_updates = add_eval_summary(logits_1,
                                               labels,
                                               scope='/logits_1')
        logits_2_to_updates = add_eval_summary(logits_2,
                                               labels,
                                               scope='/logits_2')

        # TODO(sguada) use num_epochs=1
        if FLAGS.max_num_batches:
            num_batches = FLAGS.max_num_batches
        else:
            # This ensures that we make a single pass over all of the data.
            num_batches = math.ceil(dataset.num_samples /
                                    float(FLAGS.batch_size))

        if len(FLAGS.gpus) == 0:
            config = tf.ConfigProto(device_count={'GPU': 0})
        else:
            config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
            config.gpu_options.allow_growth = True
            # config.gpu_options.per_process_gpu_memory_fraction = 1.0
            config.gpu_options.visible_device_list = FLAGS.gpus

        while True:
            if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
                checkpoint_path = tf.train.latest_checkpoint(
                    FLAGS.checkpoint_path)
            else:
                checkpoint_path = FLAGS.checkpoint_path

            tf.logging.info('Evaluating %s' % checkpoint_path)

            eval_op = list(logits_to_updates.values())
            eval_op.extend(list(logits_1_to_updates.values()))
            eval_op.extend(list(logits_2_to_updates.values()))

            slim.evaluation.evaluate_once(
                master=FLAGS.master,
                checkpoint_path=checkpoint_path,
                logdir=FLAGS.eval_dir,
                num_evals=num_batches,
                eval_op=eval_op,
                variables_to_restore=variables_to_restore,
                session_config=config)

            time.sleep(30)
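
# mask2bbox is wrapped with tf.py_func above but never defined in this listing.
# A plausible sketch (an assumption, not the author's code): threshold each
# averaged attention map at its mean and emit normalized [y1, x1, y2, x2]
# boxes in the format tf.image.crop_and_resize expects.
import numpy as np

def mask2bbox(attention_maps):
    bboxes = []
    for amap in attention_maps:  # amap has shape [H, W, 1]
        mask = amap[:, :, 0] >= amap.mean()
        ys, xs = np.where(mask)
        h, w = float(mask.shape[0]), float(mask.shape[1])
        bboxes.append([ys.min() / h, xs.min() / w,
                       (ys.max() + 1) / h, (xs.max() + 1) / w])
    return np.asarray(bboxes, dtype=np.float32)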
Example #28
######################################
# Build tensor graph
######################################
image_string = tf.placeholder(
    tf.string
)  # Entry to the computational graph, e.g. image_string = tf.gfile.FastGFile(image_file).read()
image = tf.image.decode_jpeg(
    image_string,
    channels=3,
    try_recover_truncated=True,
    acceptable_fraction=0.3)  # Tolerate partially corrupted image files
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    preprocessing_name, is_training=False)
network_fn = nets_factory.get_network_fn(FLAGS.model_name,
                                         get_labels(num_only=True),
                                         is_training=False)
if FLAGS.eval_image_size is None:
    eval_image_size = network_fn.default_image_size
else:
    eval_image_size = FLAGS.eval_image_size
processed_image = image_preprocessing_fn(image, eval_image_size,
                                         eval_image_size)
processed_images = tf.expand_dims(processed_image, 0)
logits, _ = network_fn(processed_images)
probabilities = tf.nn.softmax(logits)
# Initialize graph
init_fn = slim.assign_from_checkpoint_fn(
    checkpoint_path, slim.get_model_variables(model_variables))
sess = tf.Session()
init_fn(sess)
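
# A hedged usage sketch for the graph above (the image path is illustrative):
#   image_data = tf.gfile.FastGFile('/path/to/image.jpg', 'rb').read()
#   probs = sess.run(probabilities, feed_dict={image_string: image_data})
#   top_5 = probs[0].argsort()[-5:][::-1]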

Example #29
# Fetch the image data and labels
image, image_raw, label0, label1, label2, label3 = read_and_decode(
    TFRECORD_FILE)

# Use shuffle_batch to randomly shuffle the examples
image_batch, image_raw_batch, label_batch0, label_batch1, label_batch2, label_batch3 = tf.train.shuffle_batch(
    [image, image_raw, label0, label1, label2, label3],
    batch_size=BATCH_SIZE,
    capacity=50000,
    min_after_dequeue=10000,
    num_threads=1)

# Define the network structure
train_network_fn = nets_factory.get_network_fn('alexnet_v2',
                                               num_classes=CHAR_SET_LEN,
                                               weight_decay=0.0005,
                                               is_training=False)

with tf.Session() as sess:
    X = tf.reshape(x, [BATCH_SIZE, 224, 224, 1])
    # Feed the data into the network to get its outputs
    logits0, logits1, logits2, logits3, end_points = train_network_fn(X)

    # Predictions
    predict0 = tf.reshape(logits0, [-1, CHAR_SET_LEN])
    predict0 = tf.argmax(predict0, 1)

    predict1 = tf.reshape(logits1, [-1, CHAR_SET_LEN])
    predict1 = tf.argmax(predict1, 1)

    predict2 = tf.reshape(logits2, [-1, CHAR_SET_LEN])
Example #30
                                                              common_queue_min=100 * batch_size)
    images, labels = provider.get(['image', 'label'])

images = tf.to_float(images)
images = tf.concat([(tf.slice(images, [0, 0, 0], [32, 32, 1]) - 112.4776) / 70.4587,
                    (tf.slice(images, [0, 0, 1], [32, 32, 1]) - 124.1058) / 65.4312,
                    (tf.slice(images, [0, 0, 2], [32, 32, 1]) - 129.3773) / 68.2094], 2)
batch_images, batch_labels = tf.train.batch([images, labels],
                                            batch_size=batch_size,
                                            num_threads=1,
                                            capacity=200 * batch_size)

batch_queue = slim.prefetch_queue.prefetch_queue([batch_images, batch_labels], capacity=50 * batch_size)
img, lb = batch_queue.dequeue()

## Load Model
network_fn = nets_factory.get_network_fn(model_name)
end_points = network_fn(img, is_training=False)
print(end_points)
task1 = tf.to_int32(tf.argmax(end_points['Logits'], 1))

training_accuracy1 = slim.metrics.accuracy(task1, tf.to_int32(lb))

variables_to_restore = slim.get_variables_to_restore()
checkpoint_path = latest_checkpoint(train_dir)
saver = Saver(variables_to_restore)
config = ConfigProto()
config.gpu_options.allow_growth = True
sess = Session(config=config)
sv = supervisor.Supervisor(logdir=checkpoint_path,
                           summary_op=None,
                           summary_writer=None,
Example #31
def main():
    args, cfg = parse_args()
    train_dir = get_output_dir(
        'default' if args.cfg_file is None else args.cfg_file)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    print('Using Config:')
    pprint.pprint(cfg)

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = tf.train.get_or_create_global_step()

        ######################
        # Select the dataset #
        ######################
        kwargs = {}
        if cfg.TEST.VIDEO_FRAMES_PER_VIDEO > 1:
            kwargs['num_samples'] = cfg.TEST.VIDEO_FRAMES_PER_VIDEO
            kwargs['modality'] = cfg.INPUT.VIDEO.MODALITY
            kwargs['split_id'] = cfg.INPUT.SPLIT_ID
        if args.dataset_list_dir is not None:
            kwargs['dataset_list_dir'] = args.dataset_list_dir
        elif cfg.DATASET_LIST_DIR != '':
            kwargs['dataset_list_dir'] = cfg.DATASET_LIST_DIR
        if cfg.INPUT_FILE_STYLE_LABEL != '':
            kwargs['input_file_style_label'] = cfg.INPUT_FILE_STYLE_LABEL
        dataset, num_pose_keypoints = dataset_factory.get_dataset(
            cfg.DATASET_NAME, cfg.TEST.DATASET_SPLIT_NAME, cfg.DATASET_DIR,
            **kwargs)

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            cfg.MODEL_NAME,
            num_classes=dataset.num_classes,
            num_pose_keypoints=num_pose_keypoints,
            is_training=False,
            cfg=cfg)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            num_epochs=1,
            common_queue_capacity=2 * cfg.TEST.BATCH_SIZE,
            common_queue_min=cfg.TEST.BATCH_SIZE)
        [image, action_label] = get_input(provider, cfg,
                                          ['image', 'action_label'])
        # label -= FLAGS.labels_offset

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = cfg.MODEL_NAME
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        eval_image_size = cfg.TRAIN.IMAGE_SIZE or network_fn.default_image_size

        image = image_preprocessing_fn(image,
                                       eval_image_size,
                                       eval_image_size,
                                       resize_side_min=cfg.TRAIN.RESIZE_SIDE,
                                       resize_side_max=cfg.TRAIN.RESIZE_SIDE)

        # additional preprocessing as required
        if 'flips' in args.preprocs:
            tf.logging.info('Flipping all images while testing!')
            image = tf.stack(
                [tf.image.flip_left_right(el) for el in tf.unstack(image)])

        images, action_labels = tf.train.batch(
            [image, action_label],
            batch_size=cfg.TEST.BATCH_SIZE,
            # With multiple threads the batch order can change due to differing
            # thread speeds; keep a single thread so the order is deterministic.
            # http://stackoverflow.com/questions/35001027/does-batching-queue-tf-train-batch-not-preserve-order#comment57731040_35001027
            # num_threads=1 if args.save else cfg.NUM_PREPROCESSING_THREADS,
            num_threads=1,  # Conditioning on --save was too unsafe (easy to
            # forget the flag and randomize the order), so a single thread is
            # the default. Better safe than sorry.
            # Allow a smaller final batch only in the single-frame case;
            # otherwise logits must be averaged over the frames, which needs a
            # fully defined first dimension.
            allow_smaller_final_batch=(cfg.TEST.VIDEO_FRAMES_PER_VIDEO == 1),
            capacity=5 * cfg.TEST.BATCH_SIZE)

        ####################
        # Define the model #
        ####################
        logits, end_points = network_fn(images)
        end_points['images'] = images

        if cfg.TEST.MOVING_AVERAGE_DECAY:
            variable_averages = tf.train.ExponentialMovingAverage(
                cfg.TEST.MOVING_AVERAGE_DECAY, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        predictions = tf.argmax(logits, 1)
        if cfg.TRAIN.LOSS_FN_ACTION.startswith('multi-label'):
            logits = tf.sigmoid(logits)
        else:
            logits = tf.nn.softmax(logits, -1)
        labels = tf.squeeze(action_labels)
        end_points['labels'] = labels

        # Define the metrics:
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions, labels),
            # 'Recall@5': slim.metrics.streaming_recall_at_k(
            #     logits, labels, 5),
        })

        # Print the summaries to screen.
        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

        # TODO(sguada) use num_epochs=1
        if cfg.TEST.MAX_NUM_BATCHES:
            num_batches = cfg.TEST.MAX_NUM_BATCHES
        else:
            # This ensures that we make a single pass over all of the data.
            num_batches = math.ceil(dataset.num_samples /
                                    float(cfg.TEST.BATCH_SIZE))

        # Just test the latest trained model.
        checkpoint_path = cfg.TEST.CHECKPOINT_PATH or train_dir
        if tf.gfile.IsDirectory(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
        # Checkpoints are named 'model.ckpt-<global_step>', so the step can be
        # recovered from the path suffix.
        checkpoint_step = int(checkpoint_path.split('-')[-1])

        tf.logging.info('Evaluating %s' % checkpoint_path)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        summary_writer = tf.summary.FileWriter(logdir=train_dir)

        if cfg.TEST.EVAL_METRIC == 'mAP' or args.save or args.ept:
            from tensorflow.python.training import supervisor
            from tensorflow.python.framework import ops
            import h5py
            saver = tf.train.Saver(variables_to_restore)
            sv = supervisor.Supervisor(graph=ops.get_default_graph(),
                                       logdir=None,
                                       summary_op=None,
                                       summary_writer=summary_writer,
                                       global_step=None,
                                       saver=None)
            all_labels = []
            end_points['logits'] = logits
            # args.ept may be empty when only --save was passed.
            end_points_to_save = list(set((args.ept or []) + ['logits']))
            all_feats = dict([(ename, []) for ename in end_points_to_save])
            start_time = time.time()
            with sv.managed_session('',
                                    start_standard_services=False,
                                    config=config) as sess:
                saver.restore(sess, checkpoint_path)
                sv.start_queue_runners(sess)
                for j in tqdm(range(int(math.ceil(num_batches)))):
                    feats = sess.run([
                        action_labels,
                        [end_points[ename] for ename in end_points_to_save]
                    ])
                    all_labels.append(feats[0])
                    for ept_id, ename in enumerate(end_points_to_save):
                        all_feats[ename].append(feats[1][ept_id])
            print('Feature extraction took {} seconds.'.format(
                time.time() - start_time))
            APs = []
            all_labels = np.concatenate(all_labels)
            if args.save or args.ept:
                res_outdir = os.path.join(train_dir, 'Features/')
                mkdir_p(res_outdir)
                outfpath = args.outfpath or os.path.join(
                    res_outdir, 'features_ckpt_{}_{}.h5'.format(
                        cfg.TEST.DATASET_SPLIT_NAME, checkpoint_step))
                print(
                    'Saving the features/logits/labels to {}'.format(outfpath))
                with h5py.File(outfpath, 'a') as fout:
                    for ename in end_points_to_save:
                        if ename in fout:
                            tf.logging.warning(
                                'Deleting {} from output HDF5 to write the '
                                'new features.'.format(ename))
                            del fout[ename]
                        if ename == 'labels':
                            feat_to_save = np.array(all_feats[ename])
                        else:
                            feat_to_save = np.concatenate(all_feats[ename])
                        try:
                            fout.create_dataset(ename,
                                                data=feat_to_save,
                                                compression='gzip',
                                                compression_opts=9)
                        except Exception:
                            # Drop into the debugger to handle the failure
                            # manually, then continue.
                            pdb.set_trace()
                    if 'labels' in fout:
                        del fout['labels']
                    fout.create_dataset('labels',
                                        data=all_labels,
                                        compression='gzip',
                                        compression_opts=9)

            if args.ept:
                tf.logging.info(
                    'Evaluation had --ept passed in. '
                    'This indicates script was used for feature '
                    'extraction. Hence, not performing any evaluation.')
                return
            # Evaluation code
            all_logits = np.concatenate(all_feats['logits'])
            acc = np.mean(all_logits.argmax(axis=1) == all_labels)
            mAP = compute_map(all_logits, all_labels)[0]
            print('Mean AP: {}'.format(mAP))
            print('Accuracy: {}'.format(acc))
            summary_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(
                        tag='mAP/{}'.format(cfg.TEST.DATASET_SPLIT_NAME),
                        simple_value=mAP)
                ]),
                global_step=checkpoint_step)
            summary_writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(
                        tag='Accuracy/{}'.format(cfg.TEST.DATASET_SPLIT_NAME),
                        simple_value=acc)
                ]),
                global_step=checkpoint_step)
        else:
            slim.evaluation.evaluate_once(
                master='',
                checkpoint_path=checkpoint_path,
                logdir=train_dir,
                num_evals=num_batches,
                eval_op=list(names_to_updates.values()),
                variables_to_restore=variables_to_restore,
                session_config=config)
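The feature-extraction branch above stores one gzip-compressed HDF5 dataset per requested end point plus a 'labels' dataset. A minimal sketch of reading such a dump back for offline evaluation, assuming a hypothetical file name and single-label data:

import h5py
import numpy as np

# Hypothetical path; the real file name is built from the dataset split and
# checkpoint step, e.g. features_ckpt_<split>_<step>.h5.
with h5py.File('features_ckpt_test_1234.h5', 'r') as fin:
    logits = fin['logits'][...]   # (num_examples, num_classes)
    labels = fin['labels'][...]   # (num_examples,) for single-label data

print('Accuracy: {}'.format(np.mean(logits.argmax(axis=1) == labels)))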
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        # with tf.device(deploy_config.variables_device()):
        global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        # with tf.device(deploy_config.inputs_device()):
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=FLAGS.num_readers,
            common_queue_capacity=20 * FLAGS.batch_size,
            common_queue_min=10 * FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= FLAGS.labels_offset

        train_image_size = FLAGS.train_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, train_image_size,
                                       train_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)
        labels = slim.one_hot_encoding(
            labels, dataset.num_classes - FLAGS.labels_offset)
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        if FLAGS.quantize_delay >= 0:
            tf.contrib.quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        # with tf.device(deploy_config.optimizer_device()):
        learning_rate = _configure_learning_rate(dataset.num_samples,
                                                 global_step)
        optimizer = _configure_optimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        #  and returns a train_tensor and summary_op
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################

        ## TODOs (ubaid): tensor graph, metagraph, timing code.

        print(len(tf.get_default_graph().get_operations()))
        all_ops = tf.get_default_graph().get_operations()
        adj_list_graph_notensors = {}
        for op in all_ops:
            adj_list_graph_notensors[op.name] = set(
                [inp.name.split(":")[0] for inp in op.inputs])
        adj_list_graph_notensors = {
            op_name: list(op_deps)
            for op_name, op_deps in adj_list_graph_notensors.items()
        }
        with open('%s/org_train_graph_notensors.json' % (FLAGS.log_fn_pref),
                  'w') as outfile:
            json.dump(adj_list_graph_notensors, outfile)

        adj_list_graph = {}
        for op in all_ops:
            adj_list_graph[op.name] = set([inp.name for inp in op.inputs])
        adj_list_graph = {
            op_name: list(op_deps)
            for op_name, op_deps in adj_list_graph.items()
        }
        with open('%s/org_train_graph.json' % (FLAGS.log_fn_pref),
                  'w') as outfile:
            json.dump(adj_list_graph, outfile)

        print(':::::: done creating model ::::::')
        print(len(all_ops))
        print(FLAGS.log_fn_pref)
        print(FLAGS.doLog)

        metagraph = tf.train.export_meta_graph()
        temp_meta = MessageToJson(metagraph.graph_def)
        # MessageToJson already returns a JSON string, so write it out
        # directly instead of JSON-encoding it a second time.
        with open('%s/metagraph.json' % (FLAGS.log_fn_pref), 'w') as outfile:
            outfile.write(temp_meta)

        config_proto = tf.ConfigProto(
            log_device_placement=False,
            allow_soft_placement=False,
            graph_options=tf.GraphOptions(build_cost_model=1))

        config_proto.gpu_options.allow_growth = True
        config_proto.intra_op_parallelism_threads = 1
        config_proto.inter_op_parallelism_threads = 1
        print("*********!!!!!!!No opts!!!!!!!*********")
        config_proto.graph_options.optimizer_options.opt_level = -1
        config_proto.graph_options.rewrite_options.constant_folding = (
            rewriter_config_pb2.RewriterConfig.OFF)
        config_proto.graph_options.rewrite_options.arithmetic_optimization = (
            rewriter_config_pb2.RewriterConfig.OFF)
        config_proto.graph_options.rewrite_options.dependency_optimization = (
            rewriter_config_pb2.RewriterConfig.OFF)
        config_proto.graph_options.rewrite_options.layout_optimizer = (
            rewriter_config_pb2.RewriterConfig.OFF)

        print(FLAGS.train_dir)

        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            session_config=config_proto,
            master=FLAGS.master,
            #is_chief=(FLAGS.task == 0),
            is_chief=1,
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            #number_of_steps=FLAGS.max_number_of_steps,
            number_of_steps=14,
            log_every_n_steps=FLAGS.log_every_n_steps,
            trace_every_n_steps=1,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            #        sync_optimizer=optimizer if FLAGS.sync_replicas else None)
            sync_optimizer=optimizer if FLAGS.sync_replicas else None,
            log_fn_pref=FLAGS.log_fn_pref,
            doLogs=FLAGS.doLog)
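The two JSON files written above encode the op graph as an adjacency list mapping each op name to the names of its input ops. A small sketch of loading one back and summarizing it (the 'logs' prefix is a placeholder for whatever FLAGS.log_fn_pref was):

import json

# Hypothetical prefix; the dump above uses FLAGS.log_fn_pref.
with open('logs/org_train_graph_notensors.json') as f:
    deps = json.load(f)

# Ops with no inputs are typically placeholders, constants and variable reads.
sources = [name for name, inputs in deps.items() if not inputs]
print('{} ops, {} source ops, max fan-in {}'.format(
    len(deps), len(sources), max(len(inputs) for inputs in deps.values())))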
def main(_):
	if not FLAGS.dataset_dir:
		raise ValueError('You must supply the dataset directory with --dataset_dir')
	FLAGS.vis_dir = FLAGS.vis_dir + FLAGS.model_name + '/'
	FLAGS.checkpoint_path = FLAGS.checkpoint_path + FLAGS.model_name + '/training/'
	compute_saliency = False
	max_saliency = 200
	###################################
	# Dictionary for label inference  #
	###################################       
	keys_9 = []
	vals_9 = []
	with open(FLAGS.dataset_dir + "labels_to_labels_9.txt") as f:
		for line in f:
			(key, val) = line.split(":")
			keys_9.append(int(key))
			vals_9.append(int(val))
                
	if FLAGS.model_name == "inception_v4":
		embedding_name = 'PreLogitsFlatten'
		embedding_dim = 1536
	elif FLAGS.model_name == "inception_v3":
		embedding_name = 'PreLogits'
		embedding_dim = 2048
        
	tf.logging.set_verbosity(tf.logging.INFO)

	tf.reset_default_graph()
	graph = tf.Graph()
	with graph.as_default():
		tf_global_step = slim.get_or_create_global_step()

		label_dict_9 = tf.contrib.lookup.HashTable(
			tf.contrib.lookup.KeyValueTensorInitializer(
				keys_9, vals_9, key_dtype=tf.int64, value_dtype=tf.int64), -1)

		step = tf.Variable(0, dtype=tf.int32,
			collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=False)

		######################
		# Select the dataset #
		######################
		dataset = dataset_split.get_split(
			FLAGS.dataset_split_name, FLAGS.dataset_dir)

		#########################################
		# Define different variable of interest #
		#########################################
		labels_total = tf.Variable([0] * dataset.num_samples, dtype=tf.int64,
			collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=False)
		labels_total_9 = tf.Variable([0] * dataset.num_samples, dtype=tf.int64,
			collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=False)
        
		embeddings_total = tf.Variable([[0] * embedding_dim] * dataset.num_samples, dtype=tf.float32,
			collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=False)
        
		if compute_saliency:
			print("Wait...")
			images_total = tf.Variable([np.zeros((299, 299, 3)).tolist()] * max_saliency, dtype=tf.float32,
				collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=False)
			print("Ok.")
			print("Wait...")
			saliency_total = tf.Variable([np.zeros((299, 299)).tolist()] * max_saliency, dtype=tf.float32,
				collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=False)
			print("Ok.")
        
		####################
		# Select the model #
		####################
		network_fn = nets_factory.get_network_fn(
			FLAGS.model_name,
			num_classes=(dataset.num_classes - FLAGS.labels_offset),
			is_training=False)
	
		##############################################################
		# Create a dataset provider that loads data from the dataset #
		##############################################################
		provider = slim.dataset_data_provider.DatasetDataProvider(
			dataset,
			shuffle=False,
			common_queue_capacity=2 * FLAGS.batch_size,
			common_queue_min=FLAGS.batch_size)
		[image, label] = provider.get(['image', 'label'])
		label -= FLAGS.labels_offset
	
		#####################################
		# Select the preprocessing function #
		#####################################
		preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
		image_preprocessing_fn = preprocessing_factory.get_preprocessing(
			preprocessing_name,
			is_training=False)

		eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size
	
		image = image_preprocessing_fn(image, eval_image_size, eval_image_size)
		
		images, labels = tf.train.batch(
			[image, label],
			batch_size=FLAGS.batch_size,
			num_threads=FLAGS.num_preprocessing_threads,
			capacity=5 * FLAGS.batch_size)
	
		####################
		# Define the model #
		####################
		logits, end_points = network_fn(images)
		probs = tf.nn.softmax(logits, axis=1)
        
		if FLAGS.moving_average_decay:
			variable_averages = tf.train.ExponentialMovingAverage(
				FLAGS.moving_average_decay, tf_global_step)
			variables_to_restore = variable_averages.variables_to_restore(
				slim.get_model_variables())
			variables_to_restore[tf_global_step.op.name] = tf_global_step
		else:
			variables_to_restore = slim.get_variables_to_restore()

		labels = tf.squeeze(labels)
        
		embeddings = tf.squeeze(end_points[embedding_name])

		labels_9 = label_dict_9.lookup(labels)

		# Saliency map
		if compute_saliency:
			correct_scores = tf.gather_nd(logits,
                                  tf.stack((tf.range(images.shape[0]),
                                            tf.cast(labels, tf.int32)), axis=1))
        
			loss_ = tf.reduce_sum(correct_scores, axis=0)
			grads_ = tf.gradients(loss_, images)[0]
			saliency = tf.reduce_max(tf.abs(grads_), axis=-1)
        
		op = []
		with tf.control_dependencies([labels, labels_9, embeddings]):
			labels_assign_op = tf.assign(
				labels_total[step * FLAGS.batch_size:(step + 1) * FLAGS.batch_size],
				tf.identity(labels))
			op.append(labels_assign_op)
			labels_9_assign_op = tf.assign(
				labels_total_9[step * FLAGS.batch_size:(step + 1) * FLAGS.batch_size],
				tf.identity(labels_9))
			op.append(labels_9_assign_op)
			embeddings_assign_op = tf.assign(
				embeddings_total[step * FLAGS.batch_size:(step + 1) * FLAGS.batch_size],
				tf.identity(embeddings))
			op.append(embeddings_assign_op)
            
		if compute_saliency:
			with tf.control_dependencies([saliency]):
				#step_ = tf.cond(step <= max_saliency, lambda: tf.identity(step), lambda: tf.identity(step_))
				images_assign_op = tf.assign(
					images_total[step * FLAGS.batch_size:(step + 1) * FLAGS.batch_size],
					tf.identity(images))
				op.append(images_assign_op)
				saliency_assign_op = tf.assign(
					saliency_total[step * FLAGS.batch_size:(step + 1) * FLAGS.batch_size],
					tf.identity(saliency))
				op.append(saliency_assign_op)

		with tf.control_dependencies(op):
			step_update_op = tf.assign(step, step + 1)

		# TODO(sguada) use num_epochs=1
		if FLAGS.max_num_batches:
			num_batches = FLAGS.max_num_batches
		else:
			# This ensures that we make a single pass over all of the data.
			num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))
	
		if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
			checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
		else:
			checkpoint_path = FLAGS.checkpoint_path
	
		tf.logging.info('Evaluating %s' % checkpoint_path)
		if compute_saliency:
			final_op = [embeddings_total, labels_total_9, saliency_total, images_total]
		else:
			final_op = [embeddings_total, labels_total_9]
            
		output = slim.evaluation.evaluate_once(
			master=FLAGS.master,
			checkpoint_path=checkpoint_path,
			logdir=FLAGS.vis_dir,
			num_evals=num_batches,
			eval_op=[step_update_op],
			final_op=final_op,
			variables_to_restore=variables_to_restore)
        
	embed, lab = output[0], output[1]
	log_dir = os.path.join(FLAGS.vis_dir, 'tsne_%s/' % FLAGS.dataset_split_name)
	if not os.path.exists(log_dir):
		os.makedirs(log_dir)
	metadata = os.path.join(log_dir, 'metadata.tsv')
	open(metadata, 'a').close()
	images = tf.Variable(embed, name='images')
	with open(metadata, 'w') as metadata_file:
		for row in lab:
			metadata_file.write('%d\n' % row)
	with tf.Session() as sess:
		saver = tf.train.Saver([images])

		sess.run(images.initializer)
		saver.save(sess, os.path.join(log_dir, 'images.ckpt'))

		config = projector.ProjectorConfig()
		# One can add multiple embeddings.
		embedding = config.embeddings.add()
		embedding.tensor_name = images.name
		# Link this tensor to its metadata file (e.g. labels).
		embedding.metadata_path = metadata
		# Saves a config file that TensorBoard will read during startup.
		projector.visualize_embeddings(tf.summary.FileWriter(log_dir), config)

	if compute_saliency:
		sal, im = output[2], output[3]
		saliency_dir = os.path.join(FLAGS.vis_dir, 'saliency_%s/' % FLAGS.dataset_split_name)
		if not os.path.exists(saliency_dir):
			os.makedirs(saliency_dir)
            
		# Undo the [-1, 1] preprocessing scaling back to [0, 1] for saving.
		im = im / 2 + 0.5
        
		for i, img in enumerate(im[:max_saliency]):
			scipy.misc.imsave(saliency_dir + 'img_%i_label_%i.jpg'% (i, lab[i]), img)
			scipy.misc.imsave(saliency_dir + 'sal_%i_label_%i.jpg'% (i, lab[i]), sal[i])
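The saliency branch above differentiates the sum of each example's ground-truth logit with respect to the input images, then takes the channel-wise max of the absolute gradient. The tf.gather_nd construction that picks out those logits is equivalent to NumPy fancy indexing; a toy sketch with made-up shapes:

import numpy as np

logits = np.random.randn(4, 10)   # (batch_size, num_classes)
labels = np.array([3, 0, 7, 7])   # ground-truth class per example
# Same selection as tf.gather_nd(logits, tf.stack((range, labels), axis=1)):
correct_scores = logits[np.arange(logits.shape[0]), labels]
assert correct_scores.shape == (4,)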
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    images, labels = tf.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

    ####################
    # Define the model #
    ####################
    logits, _ = network_fn(images)

    if FLAGS.quantize:
      tf.contrib.quantize.create_eval_graph()

    if FLAGS.moving_average_decay:
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      if FLAGS.append_scope_string:
        # If a scope-name string was specified for the checkpoint file,
        # prepend it here so the layer names match up.
        variables_to_restore_orig = slim.get_variables_to_restore()
        variables_to_restore = {}
        for var in variables_to_restore_orig:
          curr_name = var.op.name
          if 'global_step' not in curr_name:
            new_name = FLAGS.append_scope_string + '/' + curr_name
          else:
            new_name = curr_name
          variables_to_restore[new_name] = var
      else:
        variables_to_restore = slim.get_variables_to_restore()

    predictions = tf.argmax(logits, 1)
    labels = tf.squeeze(labels)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Define the validation set metrics.
    # resettable.
    with tf.name_scope('eval_metrics'):
      eval_acc_value, eval_acc_op = tf.metrics.accuracy(
          predictions=predictions, labels=labels)
      eval_recall_5_value, eval_recall_5_op = slim.metrics.streaming_recall_at_k(
          predictions=logits, labels=labels, k=5)
      # Add these values as summaries for TensorBoard.
      summaries.add(tf.summary.scalar('eval_recall_5', eval_recall_5_value))
      summaries.add(tf.summary.scalar('eval_acc', eval_acc_value))
      
    # The variables used to compute the eval metrics live under the
    # 'eval_metrics' scope, so they could be gathered and re-initialized to
    # reset the metrics between runs:
    # stream_vars = [i for i in tf.local_variables() if i.name.split('/')[0] == 'eval_metrics']
    # reset_op = tf.variables_initializer(stream_vars)

    # The op that runs evaluation (all metrics).
    eval_op = [eval_acc_op, eval_recall_5_op]

    # Gather validation summaries.
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    
    # Merge all summaries together (this includes training summaries too).
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

#    # Define the metrics:
#    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
#        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
#        'Recall_5': slim.metrics.streaming_recall_at_k(
#            logits, labels, 5),
#    })
#    eval_op=list(names_to_updates.values())
#    # Print the summaries to screen.
#    for name, value in names_to_values.items():
#      summary_name = 'eval_SEPARATE/%s' % name
#      op = tf.summary.scalar(summary_name, value, collections=[])
#      op = tf.Print(op, [value], summary_name)
#      tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Evaluating %s' % checkpoint_path)

    slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=eval_op,
        variables_to_restore=variables_to_restore,
        summary_op=summary_op)
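The commented-out reset logic above works because tf.metrics creates its accumulators as local variables under the enclosing name scope. A minimal sketch of that resettable-metric pattern, using tf.variables_initializer (the non-deprecated spelling of tf.initialize_variables); 'my_metrics' is just an illustrative scope name:

import tensorflow as tf

labels = tf.placeholder(tf.int64, [None])
predictions = tf.placeholder(tf.int64, [None])
with tf.name_scope('my_metrics'):
  acc_value, acc_update = tf.metrics.accuracy(labels=labels,
                                              predictions=predictions)
# The streaming totals live in local variables under the enclosing scope, so
# re-running this initializer resets the metric between evaluation passes.
stream_vars = [v for v in tf.local_variables()
               if v.name.split('/')[0] == 'my_metrics']
reset_op = tf.variables_initializer(stream_vars)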
def test():
    # Check training directory.
    train_dir = FLAGS.train_dir
    if not tf.gfile.IsDirectory(train_dir):
        tf.logging.info("Training directory %s not found.", train_dir)
        return

    # Build the TensorFlow graph.
    g = tf.Graph()
    with g.as_default():
        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(FLAGS.model_name,
                                                 num_classes=FLAGS.NUM_CLASSES,
                                                 is_training=False)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        test_size, test_data, test_label, test_names = async_loader.video_inputs(
            FLAGS.dataset_list,
            FLAGS.dataset_dir,
            FLAGS.resize_image_size,
            FLAGS.train_image_size,
            FLAGS.batch_size,
            FLAGS.n_steps,
            FLAGS.modality,
            FLAGS.read_stride,
            image_preprocessing_fn,
            shuffle=False,
            label_from_one=(FLAGS.labels_offset > 0),
            length1=FLAGS.length,
            crop=0,
            merge_label=FLAGS.merge_label)
        print("Batch size %d" % test_data.get_shape()[0].value)

        batch_size_per_gpu = FLAGS.batch_size
        global_step_tensor = slim.create_global_step()

        # Calculate the gradients for each model tower.
        logits, end_points = network_fn(test_data)
        if hasattr(network_fn, 'rnn_part'):
            logits, end_points_rnn = network_fn.rnn_part(logits)
            end_points.update(end_points_rnn)
        if not FLAGS.merge_label:
            logits = tf.split(logits, FLAGS.n_steps, 0)[-1]
            test_label = tf.split(test_label, FLAGS.n_steps, 0)[-1]
        top_k_op = tf.nn.in_top_k(logits, test_label, FLAGS.top_k)

        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)

        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step_tensor)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_variables())
            variables_to_restore[
                global_step_tensor.op.name] = global_step_tensor
        else:
            variables_to_restore = slim.get_variables_to_restore()

        for var in variables_to_restore:
            print("Will restore %s" % (var.op.name))
        saver = tf.train.Saver(variables_to_restore)
        sv = tf.train.Supervisor(graph=g,
                                 logdir=FLAGS.eval_dir,
                                 summary_op=None,
                                 summary_writer=None,
                                 global_step=None,
                                 saver=None)
        g.finalize()

        with sv.managed_session(FLAGS.master,
                                start_standard_services=False,
                                config=None) as sess:
            while True:
                start = time.time()
                tf.logging.info(
                    "Starting evaluation at " +
                    time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()))
                model_path = tf.train.latest_checkpoint(FLAGS.train_dir)
                if not model_path:
                    tf.logging.info(
                        "Skipping evaluation. No checkpoint found in: %s",
                        FLAGS.train_dir)
                else:
                    # Load model from checkpoint.
                    tf.logging.info("Loading model from checkpoint: %s",
                                    model_path)
                    saver.restore(sess, model_path)
                    global_step = tf.train.global_step(sess,
                                                       global_step_tensor.name)
                    tf.logging.info(
                        "Successfully loaded %s at global step = %d.",
                        os.path.basename(model_path), global_step)

                    if global_step > 0:
                        # Start the queue runners.
                        sv.start_queue_runners(sess)

                        # Run evaluation on the latest checkpoint.
                        try:
                            test_once(test_size,
                                      top_k_op,
                                      sess,
                                      test_names,
                                      batch_size_per_gpu,
                                      summary_op,
                                      summary_writer,
                                      show_log=True)
                        except Exception:  # pylint: disable=broad-except
                            tf.logging.error("Evaluation failed.")
                time_to_next_eval = (start + FLAGS.eval_interval_secs -
                                     time.time())
                if time_to_next_eval > 0:
                    time.sleep(time_to_next_eval)
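tf.nn.in_top_k above yields one boolean per example, so test accuracy at k is just the mean of those booleans over the split; a toy sketch:

import tensorflow as tf

logits = tf.constant([[0.1, 0.9, 0.0],
                      [0.4, 0.3, 0.3]])
labels = tf.constant([1, 2])
hits = tf.nn.in_top_k(logits, labels, k=1)            # [True, False]
accuracy = tf.reduce_mean(tf.cast(hits, tf.float32))
with tf.Session() as sess:
    print(sess.run(accuracy))                         # 0.5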
def run_testing(degrad_ckpt_file, ckpt_dir, model_name, is_training):
    batch_size = 128
    # Create model directory
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    use_pretrained_model = True

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True
    network_fn = nets_factory.get_network_fn(
        model_name,
        num_classes=FLAGS.num_classes_budget,
        weight_decay=FLAGS.weight_decay,
        is_training=True)
    with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            images_placeholder, labels_placeholder, isTraining_placeholder = placeholder_inputs(
                batch_size * FLAGS.gpu_num)
            logits_lst = []
            losses_lst = []
            with tf.variable_scope(tf.get_variable_scope()) as scope:
                for gpu_index in range(0, FLAGS.gpu_num):
                    with tf.device('/gpu:%d' % gpu_index):
                        print('/gpu:%d' % gpu_index)
                        with tf.name_scope('%s_%d' %
                                           ('gpu', gpu_index)) as scope:
                            degrad_images = residualNet(
                                images_placeholder[gpu_index *
                                                   FLAGS.image_batch_size:
                                                   (gpu_index + 1) *
                                                   FLAGS.image_batch_size],
                                training=False)
                            logits, _ = network_fn(degrad_images)
                            loss = tf.reduce_mean(
                                tf.nn.sigmoid_cross_entropy_with_logits(
                                    logits=logits,
                                    labels=labels_placeholder[gpu_index *
                                                              batch_size:
                                                              (gpu_index + 1) *
                                                              batch_size, :]))
                            logits_lst.append(logits)
                            losses_lst.append(loss)
                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()
            loss_op = tf.reduce_mean(losses_lst)
            logits_op = tf.concat(logits_lst, 0)

            train_image_files = [
                os.path.join(FLAGS.train_images_files_dir, f)
                for f in os.listdir(FLAGS.train_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            test_image_files = [
                os.path.join(FLAGS.test_images_files_dir, f)
                for f in os.listdir(FLAGS.test_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            print(
                '#############################Reading from files###############################'
            )
            print(test_image_files)

            if is_training:
                images_op, labels_op = inputs_images(
                    filenames=train_image_files,
                    batch_size=FLAGS.image_batch_size * FLAGS.gpu_num,
                    num_epochs=1,
                    num_threads=FLAGS.num_threads,
                    num_examples_per_epoch=FLAGS.num_examples_per_epoch,
                    shuffle=False)
            else:
                images_op, labels_op = inputs_images(
                    filenames=test_image_files,
                    batch_size=FLAGS.image_batch_size * FLAGS.gpu_num,
                    num_epochs=1,
                    num_threads=FLAGS.num_threads,
                    num_examples_per_epoch=FLAGS.num_examples_per_epoch,
                    shuffle=False)

            init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            varlist_budget = [
                v for v in tf.trainable_variables()
                if any(x in v.name for x in [
                    "InceptionV1", "InceptionV2", "resnet_v1_50",
                    "resnet_v1_101", "resnet_v2_50", "resnet_v2_101",
                    "MobilenetV1_1.0", "MobilenetV1_0.75", "MobilenetV1_0.5",
                    'MobilenetV1_0.25'
                ])
            ]

            varlist_degrad = [
                v for v in tf.trainable_variables() if v not in varlist_budget
            ]

            saver = tf.train.Saver(varlist_degrad)
            saver.restore(sess, degrad_ckpt_file)

            gvar_list = tf.global_variables()
            bn_moving_vars = [g for g in gvar_list if 'moving_mean' in g.name]
            bn_moving_vars += [
                g for g in gvar_list if 'moving_variance' in g.name
            ]
            saver = tf.train.Saver(varlist_budget + bn_moving_vars)
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir=ckpt_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Session restored from pretrained budget model at {}!'.
                      format(ckpt.model_checkpoint_path))
            else:
                raise FileNotFoundError(errno.ENOENT,
                                        os.strerror(errno.ENOENT), ckpt_dir)

            loss_budget_lst = []
            pred_probs_lst = []
            gt_lst = []
            try:
                while not coord.should_stop():
                    images, labels = sess.run([images_op, labels_op])
                    gt_lst.append(labels)
                    feed = {
                        images_placeholder: images,
                        labels_placeholder: labels,
                        isTraining_placeholder: False
                    }
                    logits, loss = sess.run([logits_op, loss_op],
                                            feed_dict=feed)
                    loss_budget_lst.append(loss)
                    pred_probs_lst.append(logits)
            except tf.errors.OutOfRangeError:
                print('Done testing on all the examples')
            finally:
                coord.request_stop()

            pred_probs_mat = np.concatenate(pred_probs_lst, axis=0)
            gt_mat = np.concatenate(gt_lst, axis=0)
            n_examples, n_labels = gt_mat.shape
            split_name = 'training' if is_training else 'validation'
            with open(
                    os.path.join(
                        ckpt_dir, '{}_{}_class_scores.txt'.format(
                            model_name, split_name)), 'w') as wf:
                wf.write('# Examples = {}\n'.format(n_examples))
                wf.write('# Labels = {}\n'.format(n_labels))
                wf.write('Macro MAP = {:.2f}\n'.format(
                    100 * average_precision_score(
                        gt_mat, pred_probs_mat, average='macro')))
                cmap_stats = average_precision_score(gt_mat,
                                                     pred_probs_mat,
                                                     average=None)
                attr_id_to_name, attr_id_to_idx = load_attributes()
                idx_to_attr_id = {v: k for k, v in attr_id_to_idx.items()}
                wf.write('\t'.join([
                    'attribute_id', 'attribute_name', 'num_occurrences', 'ap'
                ]) + '\n')
                for idx in range(n_labels):
                    attr_id = idx_to_attr_id[idx]
                    attr_name = attr_id_to_name[attr_id]
                    attr_occurrences = np.sum(gt_mat, axis=0)[idx]
                    ap = cmap_stats[idx]
                    wf.write('{}\t{}\t{}\t{}\n'.format(attr_id, attr_name,
                                                       attr_occurrences,
                                                       ap * 100.0))

            coord.join(threads)
            sess.close()

    print("done")
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(num_clones=1,
                                                      clone_on_cpu=False,
                                                      replica_id=0,
                                                      num_replicas=1,
                                                      num_ps_tasks=0)

        tf_global_step = slim.get_or_create_global_step()

        ######################
        # Select the dataset #
        ######################
        with tf.device(deploy_config.inputs_device()):
            dataset = cloudgermam.get_split1(FLAGS.dataset_dir,
                                             FLAGS.dataset_split_name,
                                             FLAGS.batch_size,
                                             FLAGS.num_epochs,
                                             FLAGS.num_readers)

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=False)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            sen1, sen2, labels = dataset.get_next()
            sen1.set_shape([FLAGS.batch_size, 32, 32, 8])
            sen2.set_shape([FLAGS.batch_size, 32, 32, 10])
            images = sen2[:, :, :, :3]
            # images = tf.concat((sen1, sen2), axis=3)
            labels.set_shape([FLAGS.batch_size])
        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

        # images = image_preprocessing_fn(images, eval_image_size, eval_image_size)

        ####################
        # Define the model #
        ####################
        logits, _ = network_fn(images)

        if FLAGS.quantize:
            tf.contrib.quantize.create_eval_graph()

        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        predictions = tf.argmax(logits, 1)
        labels = tf.squeeze(labels)

        # Define the metrics:
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions, labels),
            'Recall_5':
            slim.metrics.streaming_recall_at_k(logits, labels, 5),
        })

        # Print the summaries to screen.
        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

        # TODO(sguada) use num_epochs=1
        if FLAGS.max_num_batches:
            num_batches = FLAGS.max_num_batches
        else:
            # This ensures that we make a single pass over all of the data.
            num_batches = math.ceil(dataset.num_samples /
                                    float(FLAGS.batch_size))

        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        else:
            checkpoint_path = FLAGS.checkpoint_path
        # checkpoint_path = './Log/model.ckpt-204693'
        tf.logging.info('Evaluating %s' % checkpoint_path)

        slim.evaluation.evaluate_once(
            master=FLAGS.master,
            checkpoint_path=checkpoint_path,
            logdir=FLAGS.eval_dir,
            num_evals=num_batches,
            eval_op=list(names_to_updates.values()),
            variables_to_restore=variables_to_restore)
def run_training(degrad_ckpt_file, ckpt_dir, model_name, max_steps,
                 train_from_scratch, ckpt_path):
    batch_size = 128
    # Create model directory
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    continue_from_trained_model = False

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True
    network_fn = nets_factory.get_network_fn(
        model_name,
        num_classes=FLAGS.num_classes_budget,
        weight_decay=FLAGS.weight_decay,
        is_training=True)
    with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False)
            images_placeholder, labels_placeholder, isTraining_placeholder = placeholder_inputs(
                batch_size * FLAGS.gpu_num)
            tower_grads = []
            logits_lst = []
            losses_lst = []
            opt = tf.train.AdamOptimizer(1e-4)
            with tf.variable_scope(tf.get_variable_scope()) as scope:
                for gpu_index in range(0, FLAGS.gpu_num):
                    with tf.device('/gpu:%d' % gpu_index):
                        print('/gpu:%d' % gpu_index)
                        with tf.name_scope('%s_%d' %
                                           ('gpu', gpu_index)) as scope:
                            degrad_images = residualNet(
                                images_placeholder[gpu_index *
                                                   FLAGS.image_batch_size:
                                                   (gpu_index + 1) *
                                                   FLAGS.image_batch_size],
                                training=False)

                            logits, _ = network_fn(degrad_images)
                            loss = tf.reduce_mean(
                                tf.nn.sigmoid_cross_entropy_with_logits(
                                    logits=logits,
                                    labels=labels_placeholder[gpu_index *
                                                              batch_size:
                                                              (gpu_index + 1) *
                                                              batch_size, :]))
                            logits_lst.append(logits)
                            losses_lst.append(loss)
                            print([v.name for v in tf.trainable_variables()])
                            varlist_budget = [
                                v for v in tf.trainable_variables()
                                if any(x in v.name for x in [
                                    "InceptionV1", "InceptionV2",
                                    "resnet_v1_50", "resnet_v1_101",
                                    "resnet_v2_50", "resnet_v2_101",
                                    "MobilenetV1_1.0", "MobilenetV1_0.75",
                                    "MobilenetV1_0.5", 'MobilenetV1_0.25'
                                ])
                            ]

                            varlist_degrad = [
                                v for v in tf.trainable_variables()
                                if v not in varlist_budget
                            ]
                            tower_grads.append(
                                opt.compute_gradients(loss, varlist_budget))
                            tf.get_variable_scope().reuse_variables()
            loss_op = tf.reduce_mean(losses_lst)
            logits_op = tf.concat(logits_lst, 0)

            grads = average_gradients(tower_grads)

            with tf.device('/cpu:%d' % 0):
                tvs = varlist_budget
                accum_vars = [
                    tf.Variable(tf.zeros_like(tv.initialized_value()),
                                trainable=False) for tv in tvs
                ]
                zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_vars]

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print(update_ops)
            with tf.control_dependencies([tf.group(*update_ops)]):
                accum_ops = [
                    accum_vars[i].assign_add(gv[0] / FLAGS.n_minibatches)
                    for i, gv in enumerate(grads)
                ]

            apply_gradient_op = opt.apply_gradients(
                [(accum_vars[i].value(), gv[1]) for i, gv in enumerate(grads)],
                global_step=global_step)
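            # Intended accumulation cycle (sketch): run zero_ops to clear
            # accum_vars, run accum_ops once per minibatch for
            # FLAGS.n_minibatches steps (each adds grad / n_minibatches), then
            # run apply_gradient_op to apply the accumulated gradients.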

            train_image_files = [
                os.path.join(FLAGS.train_images_files_dir, f)
                for f in os.listdir(FLAGS.train_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            val_image_files = [
                os.path.join(FLAGS.val_images_files_dir, f)
                for f in os.listdir(FLAGS.val_images_files_dir)
                if f.endswith('.tfrecords')
            ]
            print(
                '#############################Reading from files###############################'
            )
            print(train_image_files)
            print(val_image_files)

            train_images_op, train_labels_op = inputs_images(
                filenames=train_image_files,
                batch_size=FLAGS.image_batch_size * FLAGS.gpu_num,
                num_epochs=None,
                num_threads=FLAGS.num_threads,
                num_examples_per_epoch=FLAGS.num_examples_per_epoch,
                shuffle=False)
            val_images_op, val_labels_op = inputs_images(
                filenames=val_image_files,
                batch_size=FLAGS.image_batch_size * FLAGS.gpu_num,
                num_epochs=None,
                num_threads=FLAGS.num_threads,
                num_examples_per_epoch=FLAGS.num_examples_per_epoch,
                shuffle=False)

            init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            gvar_list = tf.global_variables()
            bn_moving_vars = [g for g in gvar_list if 'moving_mean' in g.name]
            bn_moving_vars += [
                g for g in gvar_list if 'moving_variance' in g.name
            ]
            print([var.name for var in bn_moving_vars])

            def restore_model(dir, varlist, modulename):
                import re
                regex = re.compile(r'(MobilenetV1_?)(\d*\.?\d*)',
                                   re.IGNORECASE)
                if 'mobilenet' in modulename:
                    varlist = {
                        regex.sub('MobilenetV1', v.name[:-2]): v
                        for v in varlist
                    }
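                    # e.g. this remaps a graph variable name like
                    # 'MobilenetV1_0.75/Conv2d_0/weights' to the checkpoint name
                    # 'MobilenetV1/Conv2d_0/weights' (illustrative names).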
                if os.path.isfile(dir):
                    print(varlist)
                    saver = tf.train.Saver(varlist)
                    saver.restore(sess, dir)
                    print(
                        '#############################Session restored from pretrained model at {}!#############################'
                        .format(dir))
                else:
                    ckpt = tf.train.get_checkpoint_state(checkpoint_dir=dir)
                    if ckpt and ckpt.model_checkpoint_path:
                        saver = tf.train.Saver(varlist)
                        saver.restore(sess, ckpt.model_checkpoint_path)
                        print(
                            '#############################Session restored from pretrained model at {}!#############################'
                            .format(ckpt.model_checkpoint_path))

            if continue_from_trained_model:
                # Build a new list instead of aliasing varlist_budget, so the
                # += below cannot mutate varlist_budget in place.
                varlist = varlist_budget + bn_moving_vars
                saver = tf.train.Saver(varlist)
                ckpt = tf.train.get_checkpoint_state(checkpoint_dir=ckpt_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    print(
                        '#############################Session restored from trained model at {}!###############################'
                        .format(ckpt.model_checkpoint_path))
                else:
                    raise FileNotFoundError(errno.ENOENT,
                                            os.strerror(errno.ENOENT),
                                            ckpt_dir)
            else:
                if not train_from_scratch:
                    saver = tf.train.Saver(varlist_degrad)
                    print(degrad_ckpt_file)
                    saver.restore(sess, degrad_ckpt_file)

                    varlist = [
                        v for v in varlist_budget + bn_moving_vars
                        if not any(x in v.name for x in ["logits"])
                    ]
                    restore_model(ckpt_path, varlist, model_name)

            saver = tf.train.Saver(tf.trainable_variables() + bn_moving_vars,
                                   max_to_keep=1)
            for step in xrange(max_steps):
                start_time = time.time()
                loss_value_lst = []
                sess.run(zero_ops)
                for _ in itertools.repeat(None, FLAGS.n_minibatches):
                    train_images, train_labels = sess.run(
                        [train_images_op, train_labels_op])
                    _, loss_value = sess.run(
                        [accum_ops, loss_op],
                        feed_dict={
                            images_placeholder: train_images,
                            labels_placeholder: train_labels,
                            isTraining_placeholder: True
                        })
                    loss_value_lst.append(loss_value)
                sess.run(apply_gradient_op)
                assert not np.isnan(
                    np.mean(loss_value_lst)), 'Model diverged with loss = NaN'
                duration = time.time() - start_time
                print('Step: {:4d} time: {:.4f} loss: {:.8f}'.format(
                    step, duration, np.mean(loss_value_lst)))
                if step % FLAGS.val_step == 0:
                    loss_budget_lst = []
                    pred_probs_lst = []
                    gt_lst = []
                    for _ in itertools.repeat(None, 30):
                        val_images, val_labels = sess.run(
                            [val_images_op, val_labels_op])
                        gt_lst.append(val_labels)
                        logits_budget, loss_budget = sess.run(
                            [logits_op, loss_op],
                            feed_dict={
                                images_placeholder: val_images,
                                labels_placeholder: val_labels,
                                isTraining_placeholder: False
                            })
                        loss_budget_lst.append(loss_budget)
                        pred_probs_lst.append(logits_budget)

                    pred_probs_mat = np.concatenate(pred_probs_lst, axis=0)
                    gt_mat = np.concatenate(gt_lst, axis=0)
                    n_examples, n_labels = gt_mat.shape
                    print('# Examples = ', n_examples)
                    print('# Labels = ', n_labels)
                    print('Macro MAP = {:.2f}'.format(
                        100 * average_precision_score(
                            gt_mat, pred_probs_mat, average='macro')))

                # Save a checkpoint and evaluate the model periodically.
                if step % FLAGS.save_step == 0 or (step + 1) == max_steps:
                    checkpoint_path = os.path.join(ckpt_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

            coord.request_stop()
            coord.join(threads)

    print("done")

tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')

tf.app.flags.DEFINE_integer('eval_image_size', None, 'Eval image size')

FLAGS = tf.app.flags.FLAGS

is_training = False
preprocessing_name = FLAGS.model_name

# Note: `tf.Graph().as_default()` only switches graphs inside a `with` block;
# as written it is a no-op, so the ops below land in the process-default graph.
graph = tf.get_default_graph()

image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    preprocessing_name, is_training)

network_fn = nets_factory.get_network_fn(FLAGS.model_name,
                                         num_classes=100,
                                         is_training=is_training)

placeholder = tf.placeholder(name='input', dtype=tf.string)
image = tf.image.decode_jpeg(placeholder, channels=3)
image = image_preprocessing_fn(image, 320, 320)
image = tf.expand_dims(image, 0)
logit, _ = network_fn(image)

saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, FLAGS.checkpoint_path)
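
# A minimal usage sketch for the graph above (the file name is hypothetical):
#   with open('test.jpg', 'rb') as f:
#       logit_value = sess.run(logit, feed_dict={placeholder: f.read()})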

filenames = []
recall5word = []
def main(_):
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    # init
    print('HG: Train pruned blocks for all valid layers concurrently')
    block_names = valid_block_names
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    kept_percentages = sorted(
        [float(x) for x in FLAGS.kept_percentages.split(',')])
    print_list('kept_percentages', kept_percentages)

    # prepare file system
    results_dir = os.path.join(FLAGS.train_dir, 'kp' + FLAGS.kept_percentages)
    train_dir = os.path.join(results_dir, 'train')
    if (not FLAGS.continue_training) or (
            not tf.train.latest_checkpoint(train_dir)):
        print('Start a new training')
        prepare_file_system(train_dir)
    else:
        print('Continue training')

    def write_log_info(info):
        with open(os.path.join(FLAGS.train_dir, 'log.txt'), 'a') as f:
            f.write(info + '\n')

    def write_detailed_info(info):
        with open(os.path.join(train_dir, 'train_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir
    log_info = info + '\n'
    write_detailed_info(info)

    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.train_dataset_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)

        batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        images, labels = batch_queue.dequeue()

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)
        _, end_points = network_fn(images, is_training=False)
        # for item in end_points.iteritems():
        #     print(item)
        # return

        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step, FLAGS)
            optimizer = configure_optimizer(learning_rate, FLAGS)
            tf.summary.scalar('learning_rate', learning_rate)

        ####################
        # Define the model #
        ####################
        # each kept_percentage corresponds to a pruned network.
        train_tensors = []
        total_losses = []
        pruned_net_name_scopes = []
        correct_predictions = []
        prune_infos = []
        for kept_percentage in kept_percentages:

            prune_info = kept_percentage_sequence_to_prune_info(
                kept_percentage, block_names)
            set_prune_info_inputs(prune_info,
                                  end_points,
                                  block_size=FLAGS.block_size)
            pprint(prune_info)
            # return
            prune_infos.append(prune_info)

            #  the pruned network scope
            net_name_scope_pruned = FLAGS.net_name_scope_pruned + '_p' + str(
                kept_percentage)
            pruned_net_name_scopes.append(net_name_scope_pruned)

            # generate the pruned network for training
            _, end_points_pruned = network_fn_pruned(
                images,
                prune_info=prune_info,
                is_training=True,
                is_local_train=True,
                reuse_variables=False,
                scope=net_name_scope_pruned)
            # generate the pruned network for testing
            logits, _ = network_fn_pruned(test_images,
                                          prune_info=prune_info,
                                          is_training=False,
                                          is_local_train=False,
                                          reuse_variables=True,
                                          scope=net_name_scope_pruned)
            # add correct prediction to the testing network
            correct_prediction = add_correct_prediction(logits, test_labels)
            correct_predictions.append(correct_prediction)

            #############################
            # Specify the loss functions #
            #############################
            print('HG: block_size', FLAGS.block_size)
            is_first_block = True
            for i, block_name in enumerate(block_names):
                if (i + 1) % FLAGS.block_size != 0 and i != len(block_names) - 1:
                    continue
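                # e.g. with block_size=2 and len(block_names)=5, losses are
                # attached at i = 1, 3 and 4: every block_size-th block plus
                # the final remainder block (illustrative values).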
                print('HG: i=%d, block_name=%s' % (i, block_name))
                # add l2 losses
                appendix = '_p' + str(kept_percentage) + '_' + str(i)
                collection_name = 'subgraph_losses' + appendix
                # print("HG: collection_name=", collection_name)

                outputs = end_points[block_name]
                outputs_pruned = end_points_pruned[block_name]
                l2_loss = add_l2_loss(outputs,
                                      outputs_pruned,
                                      add_to_collection=True,
                                      collection_name=collection_name)

                # get regularization loss
                if i == len(block_names) - 1 and FLAGS.block_size < len(
                        block_names):
                    tmp_block_names = block_names[
                        int(len(block_names) / FLAGS.block_size) *
                        FLAGS.block_size:]
                    print('HG: last block size:', len(tmp_block_names))
                else:
                    print('HG: this block start from id=',
                          i - FLAGS.block_size + 1, ', end before id=', i + 1)
                    tmp_block_names = block_names[i - FLAGS.block_size + 1:i + 1]
                print('HG: this block contains names:', tmp_block_names)

                regularization_losses = get_regularization_losses_with_block_names(
                    net_name_scope_pruned, tmp_block_names,
                    add_to_collection=True, collection_name=collection_name)
                print_list('regularization_losses', regularization_losses)

                # total loss and its summary
                total_loss = tf.add_n(tf.get_collection(collection_name),
                                      name='total_loss')
                for l in tf.get_collection(collection_name) + [total_loss]:
                    tf.summary.scalar(l.op.name + appendix + '/summary', l)
                total_losses.append(total_loss)

                #############################
                # Add train operation       #
                #############################
                variables_to_train = get_trainable_variables_with_block_names(
                    net_name_scope_pruned, tmp_block_names)
                print_list("variables_to_train", variables_to_train)

                # add train_op
                if is_first_block and kept_percentage == kept_percentages[0]:
                    global_step_tmp = global_step
                else:
                    global_step_tmp = tf.Variable(0,
                                                  trainable=False,
                                                  name='global_step' +
                                                  appendix)
                train_op = add_train_op(optimizer,
                                        total_loss,
                                        global_step_tmp,
                                        var_list=variables_to_train)
                is_first_block = False

                # Gather update_ops: the updates for the batch_norm variables created by network_fn_pruned.
                update_ops = get_update_ops_with_block_names(
                    net_name_scope_pruned, tmp_block_names)
                print_list("update_ops", update_ops)

                update_ops.append(train_op)
                update_op = tf.group(*update_ops)
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss,
                                               name='train_op' + appendix)
                    train_tensors.append(train_tensor)
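
                # Evaluating this train_tensor first runs the grouped BN
                # update ops and the train_op (via the control dependency),
                # then returns the current total_loss value.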

        # add summary op
        summary_op = tf.summary.merge_all()
        # return

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))
        print_list(
            'model_variables but not trainable variables',
            list(
                set(tf.model_variables()).difference(
                    tf.trainable_variables())))
        print_list(
            'global_variables but not model variables',
            list(set(tf.global_variables()).difference(tf.model_variables())))
        print(
            "HG: trainable_variables from " + net_name_scope_checkpoint + "=",
            len(
                get_trainable_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])))
        print(
            "HG: trainable_variables from " + net_name_scope_pruned + "=",
            len(
                get_trainable_variables_within_scopes(
                    [net_name_scope_pruned + '/'])))
        print(
            "HG: model_variables from " + net_name_scope_checkpoint + "=",
            len(
                get_model_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])))
        print(
            "HG: model_variables from " + net_name_scope_pruned + "=",
            len(
                get_model_variables_within_scopes(
                    [net_name_scope_pruned + '/'])))
        print(
            "HG: global_variables from " + net_name_scope_checkpoint + "=",
            len(
                get_global_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])))
        print(
            "HG: global_variables from " + net_name_scope_pruned + "=",
            len(
                get_global_variables_within_scopes(
                    [net_name_scope_pruned + '/'])))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)

        with tf.Session(config=sess_config) as sess:
            ###########################
            # prepare for filewritter #
            ###########################
            train_writer = tf.summary.FileWriter(train_dir, sess.graph)

            if (not FLAGS.continue_training) or (
                    not tf.train.latest_checkpoint(train_dir)):
                ###########################################
                # Restore original model variable values. #
                ###########################################
                variables_to_restore = get_model_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])
                print_list("restore model variables for original",
                           variables_to_restore)
                load_checkpoint(sess,
                                FLAGS.checkpoint_path,
                                var_list=variables_to_restore)

                #################################################
                # Init pruned networks with the well-trained model #
                #################################################

                for i in range(len(pruned_net_name_scopes)):
                    net_name_scope_pruned = pruned_net_name_scopes[i]
                    print('net_name_scope_pruned=', net_name_scope_pruned)

                    # Init pruned variables.
                    kept_percentage = kept_percentages[i]
                    prune_info = prune_infos[i]
                    variables_init_value = get_pruned_kernel_matrix(
                        sess, prune_info, net_name_scope_checkpoint)
                    reinit_scopes = [
                        re.sub(net_name_scope_checkpoint,
                               net_name_scope_pruned, name)
                        for name in variables_init_value.keys()
                    ]
                    variables_to_reinit = get_model_variables_within_scopes(
                        reinit_scopes)
                    print_list("Initialize pruned variables",
                               variables_to_reinit)

                    assign_ops = []
                    for v in variables_to_reinit:
                        key = re.sub(net_name_scope_pruned,
                                     net_name_scope_checkpoint, v.op.name)
                        if key in variables_init_value:
                            value = variables_init_value.get(key)
                            # print(key, value)
                            assign_ops.append(
                                tf.assign(v,
                                          tf.convert_to_tensor(value),
                                          validate_shape=True))
                            # v.set_shape(value.shape)
                        else:
                            raise ValueError(
                                "Key not in variables_init_value, key=", key)
                    assign_op = tf.group(*assign_ops)
                    sess.run(assign_op)

                    #################################################
                    # Restore unchanged model variables. #
                    #################################################
                    variables_to_restore = {
                        re.sub(net_name_scope_pruned,
                               net_name_scope_checkpoint, v.op.name): v
                        for v in get_model_variables_within_scopes(
                            [net_name_scope_pruned + '/'])
                        if v not in variables_to_reinit
                    }
                    print_list(
                        "restore model variables for " + net_name_scope_pruned,
                        variables_to_restore.values())
                    load_checkpoint(sess,
                                    FLAGS.checkpoint_path,
                                    var_list=variables_to_restore)
            else:
                # restore all variables from checkpoint
                variables_to_restore = get_global_variables_within_scopes()
                load_checkpoint(sess, train_dir, var_list=variables_to_restore)

            #################################################
            # Initialize uninitialized global variables. #
            #################################################
            # uninitialized_variables =[x.decode('utf-8') for x in sess.run(tf.report_uninitialized_variables())]
            # print_list('uninitialized variables', uninitialized_variables)
            # variables_to_init = [v for v in tf.global_variables() if v.name.split(':')[0] in set(uninitialized_variables)]
            # #get_global_variables_within_scopes(uninitialized_variables)
            # print_list("variables_to_init", variables_to_init)
            # sess.run( tf.variables_initializer(variables_to_init) )
            variables_to_init = get_global_variables_within_scopes(
                sess.run(tf.report_uninitialized_variables()))
            print_list("init unitialized variables", variables_to_init)
            sess.run(tf.variables_initializer(variables_to_init))

            init_global_step_value = sess.run(global_step)
            print('HG: initial global step: ', init_global_step_value)
            if init_global_step_value >= FLAGS.max_number_of_steps:
                print('Exit: init_global_step_value (%d) >= FLAGS.max_number_of_steps (%d)' \
                    %(init_global_step_value, FLAGS.max_number_of_steps))
                return

            ###########################
            # Record CPU usage  #
            ###########################
            mpstat_output_filename = os.path.join(train_dir, "cpu-usage.log")
            os.system("mpstat -P ALL 1 > " + mpstat_output_filename +
                      " 2>&1 &")

            ###########################
            # Kicks off the training. #
            ###########################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('HG: # of threads=', len(threads))

            # saver for models
            if FLAGS.max_to_keep <= 0:
                max_to_keep = int(2 * FLAGS.max_number_of_steps /
                                  FLAGS.evaluate_every_n_steps)
            else:
                max_to_keep = FLAGS.max_to_keep
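            # With FLAGS.max_to_keep <= 0, e.g. max_number_of_steps=10000 and
            # evaluate_every_n_steps=500 give max_to_keep = int(2 * 10000 / 500)
            # = 40 retained checkpoints (illustrative values).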
            saver = tf.train.Saver(max_to_keep=max_to_keep)

            train_time = 0  # time spent on SGD training only
            duration = 0  # used to estimate the training speed
            train_only_cnt = 0  # used to compute the true training time
            duration_cnt = 0

            print("start to train at:", datetime.now())
            for i in range(init_global_step_value,
                           FLAGS.max_number_of_steps + 1):

                # run optional meta data, or summary, while run train tensor
                if i > init_global_step_value:

                    # run metadata
                    if i % FLAGS.runmeta_every_n_steps == FLAGS.runmeta_every_n_steps - 1:
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()

                        loss_values = sess.run(train_tensors,
                                               options=run_options,
                                               run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata,
                                                      'step%d-train' % i)

                        # Create the Timeline object, and write it to a json file
                        fetched_timeline = timeline.Timeline(
                            run_metadata.step_stats)
                        chrome_trace = fetched_timeline.generate_chrome_trace_format()
                        with open(
                                os.path.join(train_dir,
                                             'timeline_' + str(i) + '.json'),
                                'w') as f:
                            f.write(chrome_trace)

                    # record summary
                    elif i % FLAGS.summary_every_n_steps == 0:
                        results = sess.run([summary_op] + train_tensors)
                        train_summary, loss_values = results[0], results[1:]
                        train_writer.add_summary(train_summary, i)
                        # print('HG: train with summary')
                    else:
                        # Only run the train op and time it.
                        start_time = time.time()
                        loss_values = sess.run(train_tensors)
                        train_only_cnt += 1
                        duration_cnt += 1
                        train_time += time.time() - start_time
                        duration += time.time() - start_time

                    if i % FLAGS.log_every_n_steps == 0 and duration_cnt > 0:
                        # record speed
                        log_frequency = duration_cnt
                        examples_per_sec = log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / log_frequency)
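                        # e.g. batch_size=32 and 10 timed steps in 4.0s give
                        # examples_per_sec = 10 * 32 / 4.0 = 80 and
                        # sec_per_batch = 0.4 (illustrative numbers).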
                        summary = tf.Summary()
                        summary.value.add(tag='examples_per_sec',
                                          simple_value=examples_per_sec)
                        summary.value.add(tag='sec_per_batch',
                                          simple_value=sec_per_batch)
                        train_writer.add_summary(summary, i)
                        info = (
                            '%s: step %d, loss = %s (%.1f examples/sec; %.3f sec/batch)'
                        ) % (datetime.now(), i, str(loss_values),
                             examples_per_sec, sec_per_batch)
                        print(info)
                        duration = 0
                        duration_cnt = 0

                        write_detailed_info(info)
                else:
                    # At the initial step, only evaluate the summary op and
                    # the total losses; no train op is run.
                    results = sess.run([summary_op] + total_losses)
                    train_summary, loss_values = results[0], results[1:]
                    train_writer.add_summary(train_summary, i)
                    info = '%s: step %d, loss = %s' % (datetime.now(), i,
                                                       str(loss_values))
                    print(info)
                    write_detailed_info(info)

                # record the evaluation accuracy
                is_last_step = (i == FLAGS.max_number_of_steps)
                if i % FLAGS.evaluate_every_n_steps == 0 or is_last_step:

                    # test accuracy; each kept_percentage corresponds to a pruned network, and thus an accuracy.
                    test_accuracies = []
                    for p in range(len(kept_percentages)):
                        kept_percentage = kept_percentages[p]
                        appendix = '_p' + str(kept_percentage)
                        correct_prediction = correct_predictions[p]
                        # run_meta = (i==FLAGS.evaluate_every_n_steps)&&(p==0)
                        test_accuracy, run_metadata = evaluate_accuracy(
                            sess,
                            coord,
                            test_dataset.num_samples,
                            test_images,
                            test_labels,
                            test_images,
                            test_labels,
                            correct_prediction,
                            FLAGS.test_batch_size,
                            run_meta=False)
                        summary = tf.Summary()
                        summary.value.add(tag='accuracy' + appendix,
                                          simple_value=test_accuracy)
                        train_writer.add_summary(summary, i)
                        test_accuracies.append(
                            (kept_percentage, test_accuracy))
                    # if run_meta:
                    # eval_writer.add_run_metadata(run_metadata, 'step%d-eval' % i)
                    acc_str = '[' + ', '.join([
                        '(%s, %.6f)' % (str(kp), acc)
                        for kp, acc in test_accuracies
                    ]) + ']'
                    info = ('%s: step %d, test_accuracy = %s') % (
                        datetime.now(), i, str(acc_str))
                    print(info)
                    if i == 0 or is_last_step:
                        # write_log_info(info)
                        log_info += info + '\n'
                    write_detailed_info(info)

                    ###########################
                    # Save model parameters . #
                    ###########################
                    # saver = tf.train.Saver()
                    save_path = saver.save(
                        sess, os.path.join(train_dir, 'model.ckpt-' + str(i)))
                    print("HG: Model saved in file: %s" % save_path)

            coord.request_stop()
            coord.join(threads)
            total_time = time.time() - tic

            train_speed = train_time / train_only_cnt
            train_time = (FLAGS.max_number_of_steps) * train_speed
            info = "HG: training speed(sec/batch): %.6f\n" % (train_speed)
            info += "HG: training time(min): %.1f, total time(min): %.1f \n" % (
                train_time / 60.0, total_time / 60.0)
            print(info)
            log_info += info
            write_log_info(log_info)
            write_detailed_info(info)
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=False)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=2 * FLAGS.batch_size,
            common_queue_min=FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= FLAGS.labels_offset

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)

        ####################
        # Define the model #
        ####################
        logits, _ = network_fn(images)

        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()
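
        # When moving averages are enabled, `variables_to_restore` maps shadow
        # names such as 'InceptionV1/.../weights/ExponentialMovingAverage'
        # (illustrative) onto the live variables, so evaluation runs with the
        # averaged weights.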

        predictions = tf.argmax(logits, 1)
        labels = tf.squeeze(labels)

        # Define the metrics:
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions, labels),
            'Recall_5':
            slim.metrics.streaming_recall_at_k(logits, labels, 5),
        })

        # Print the summaries to screen.
        summary_ops = []
        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)
            summary_ops.append(op)

        # TODO(sguada) use num_epochs=1
        if FLAGS.max_num_batches:
            num_batches = FLAGS.max_num_batches
        else:
            # This ensures that we make a single pass over all of the data.
            num_batches = math.ceil(dataset.num_samples /
                                    float(FLAGS.batch_size))
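            # e.g. 50,000 samples with batch_size=100 give
            # num_batches = ceil(50000 / 100.) = 500, i.e. one full pass.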
        """
    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path
    """

        tf.logging.info('Evaluating %s' % FLAGS.checkpoint_path)

        # GPU memory dynamic allocation
        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True

        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.checkpoint_path,
            logdir=FLAGS.eval_dir,
            num_evals=num_batches,
            eval_op=list(names_to_updates.values()),
            summary_op=tf.summary.merge(summary_ops),
            eval_interval_secs=FLAGS.eval_interval_secs,
            variables_to_restore=variables_to_restore,
            session_config=session_config)
        """

# Read the image data and labels.
image, label0, label1, label2, label3 = read_and_decode(TFRECORD_FILE)

# Use shuffle_batch to randomly shuffle the examples.
image_batch, label_batch0, label_batch1, label_batch2, label_batch3 = tf.train.shuffle_batch(
    [image, label0, label1, label2, label3], batch_size=BATCH_SIZE,
    capacity=50000, min_after_dequeue=10000, num_threads=1)

# Define the network architecture.
train_network_fn = nets_factory.get_network_fn(
    'alexnet_v2',
    num_classes=CHAR_SET_LEN,
    weight_decay=0.0005,
    is_training=True)
 
    
with tf.Session() as sess:
    # inputs: a tensor of size [batch_size, height, width, channels]
    X = tf.reshape(x, [BATCH_SIZE, 224, 224, 1])
    # Feed the input through the network to get its outputs.
    logits0, logits1, logits2, logits3, end_points = train_network_fn(X)
    
    # Convert the labels to one-hot form.
    one_hot_labels0 = tf.one_hot(indices=tf.cast(y0, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels1 = tf.one_hot(indices=tf.cast(y1, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels2 = tf.one_hot(indices=tf.cast(y2, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels3 = tf.one_hot(indices=tf.cast(y3, tf.int32), depth=CHAR_SET_LEN)
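
    # e.g. with CHAR_SET_LEN=10, a label y0 == 3 becomes the one-hot row
    # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0] (illustrative depth).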
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.worker_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
      FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
      FLAGS.model_name,
      num_classes=(dataset.num_classes - FLAGS.labels_offset),
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      width_multiplier=FLAGS.width_multiplier)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
      preprocessing_name,
      is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=FLAGS.num_readers,
        common_queue_capacity=20 * FLAGS.batch_size,
        common_queue_min=10 * FLAGS.batch_size)

      # gt_bboxes format [ymin, xmin, ymax, xmax]
      [image, img_shape, gt_labels, gt_bboxes] = provider.get(['image', 'shape',
                                                               'object/label',
                                                               'object/bbox'])

      # Preprocesing
      # gt_bboxes = scale_bboxes(gt_bboxes, img_shape)  # bboxes format [0,1) for tf draw

      image, gt_labels, gt_bboxes = image_preprocessing_fn(image,
                                                           config.IMG_HEIGHT,
                                                           config.IMG_WIDTH,
                                                           labels=gt_labels,
                                                           bboxes=gt_bboxes,
                                                           )

      #############################################
      # Encode annotations for losses computation #
      #############################################

      # anchors format [cx, cy, w, h]
      anchors = tf.convert_to_tensor(config.ANCHOR_SHAPE, dtype=tf.float32)

      # encode annos, box_input format [cx, cy, w, h]
      input_mask, labels_input, box_delta_input, box_input = encode_annos(gt_labels,
                                                                          gt_bboxes,
                                                                          anchors,
                                                                          config.NUM_CLASSES)

      images, b_input_mask, b_labels_input, b_box_delta_input, b_box_input = tf.train.batch(
        [image, input_mask, labels_input, box_delta_input, box_input],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

      batch_queue = slim.prefetch_queue.prefetch_queue(
        [images, b_input_mask, b_labels_input, b_box_delta_input, b_box_input], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, b_input_mask, b_labels_input, b_box_delta_input, b_box_input = batch_queue.dequeue()
      anchors = tf.convert_to_tensor(config.ANCHOR_SHAPE, dtype=tf.float32)
      end_points = network_fn(images)
      end_points["viz_images"] = images
      conv_ds_14 = end_points['MobileNet/conv_ds_14/depthwise_conv']
      dropout = slim.dropout(conv_ds_14, keep_prob=0.5, is_training=True)
      num_output = config.NUM_ANCHORS * (config.NUM_CLASSES + 1 + 4)
      predict = slim.conv2d(dropout, num_output, kernel_size=(3, 3), stride=1, padding='SAME',
                            activation_fn=None,
                            weights_initializer=tf.truncated_normal_initializer(stddev=0.0001),
                            scope="MobileNet/conv_predict")

      with tf.name_scope("Interpre_prediction") as scope:
        pred_box_delta, pred_class_probs, pred_conf, ious, det_probs, det_boxes, det_class = \
          interpre_prediction(predict, b_input_mask, anchors, b_box_input)
        end_points["viz_det_probs"] = det_probs
        end_points["viz_det_boxes"] = det_boxes
        end_points["viz_det_class"] = det_class

      with tf.name_scope("Losses") as scope:
        losses(b_input_mask, b_labels_input, ious, b_box_delta_input, pred_class_probs, pred_conf, pred_box_delta)

      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      if end_point not in ["viz_images", "viz_det_probs", "viz_det_boxes", "viz_det_class"]:
        x = end_points[end_point]
        summaries.add(tf.summary.histogram('activations/' + end_point, x))
        summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                        tf.nn.zero_fraction(x)))

    # Add summaries for det result. TODO(shizehao): visualize prediction.


    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
        opt=optimizer,
        replicas_to_aggregate=FLAGS.replicas_to_aggregate,
        variable_averages=variable_averages,
        variables_to_average=moving_average_variables,
        replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
        total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Optimize the clones; this returns the total loss and the per-variable
    # gradients aggregated across clones.
    total_loss, clones_gradients = model_deploy.optimize_clones(
      clones,
      optimizer,
      var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                      name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
      train_tensor,
      logdir=FLAGS.train_dir,
      master=FLAGS.master,
      is_chief=(FLAGS.task == 0),
      init_fn=_get_init_fn(),
      summary_op=summary_op,
      number_of_steps=FLAGS.max_number_of_steps,
      log_every_n_steps=FLAGS.log_every_n_steps,
      save_summaries_secs=FLAGS.save_summaries_secs,
      save_interval_secs=FLAGS.save_interval_secs,
      sync_optimizer=optimizer if FLAGS.sync_replicas else None)
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpus
  if FLAGS.num_clones == -1:
    FLAGS.num_clones = len(FLAGS.gpus.split(','))

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    # tf.set_random_seed(42)
    tf.set_random_seed(0)
    ######################
    # Config model_deploy#
    ######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name,
        FLAGS.dataset_dir.split(','),
        dataset_list_dir=FLAGS.dataset_list_dir,
        num_samples=FLAGS.frames_per_video,
        modality=FLAGS.modality,
        split_id=FLAGS.split_id)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        batch_size=FLAGS.batch_size,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        dropout_keep_prob=(1.0-FLAGS.dropout),
        pooled_dropout_keep_prob=(1.0-FLAGS.pooled_dropout),
        batch_norm=FLAGS.netvlad_batch_norm)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)  # in case of pooling images,
                           # now preprocessing is done video-level

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=FLAGS.num_readers,
        common_queue_capacity=20 * FLAGS.batch_size,
        common_queue_min=10 * FLAGS.batch_size,
        bgr_flips=FLAGS.bgr_flip)
      [image, label] = provider.get(['image', 'label'])
      # Note that the above image might be a 23-channel image if you have both
      # RGB and flow streams. It will need to be split later, but all the
      # preprocessing is done consistently for all frames over all streams.
      label = tf.string_to_number(label, tf.int32)
      label.set_shape(())
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      # No trailing comma here: it would turn scale_ratios into a 1-tuple.
      scale_ratios = [float(el) for el in FLAGS.scale_ratios.split(',')]
      image = image_preprocessing_fn(image, train_image_size,
                                     train_image_size,
                                     scale_ratios=scale_ratios,
                                     out_dim_scale=FLAGS.out_dim_scale,
                                     model_name=FLAGS.model_name)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      if FLAGS.debug:
        images = tf.Print(images, [labels], 'Read batch')
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)
      summarize_images(images, provider.num_channels_stream)

    ####################
    # Define the model #
    ####################
    kwargs = {}
    if FLAGS.conv_endpoint is not None:
      kwargs['conv_endpoint'] = FLAGS.conv_endpoint
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(
          images, pool_type=FLAGS.pooling,
          classifier_type=FLAGS.classifier_type,
          num_channels_stream=provider.num_channels_stream,
          netvlad_centers=FLAGS.netvlad_initCenters.split(','),
          stream_pool_type=FLAGS.stream_pool_type,
          **kwargs)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        slim.losses.softmax_cross_entropy(
            end_points['AuxLogits'], labels,
            label_smoothing=FLAGS.label_smoothing, weight=0.4, scope='aux_loss')
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weight=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    global end_points_debug
    end_points = clones[0].outputs
    end_points_debug = dict(end_points)
    end_points_debug['images'] = images
    end_points_debug['labels'] = labels
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.histogram_summary('activations/' + end_point, x))
      summaries.add(tf.scalar_summary('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.scalar_summary('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.histogram_summary(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.scalar_summary('learning_rate', learning_rate,
                                      name='learning_rate'))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()
    logging.info('Training the following variables: %s' % (
      ' '.join([el.name for el in variables_to_train])))

    # optimize_clones sums the clone losses and returns total_loss and gradients.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)

    # clip the gradients if needed
    if FLAGS.clip_gradients > 0:
      logging.info('Clipping gradients by %f' % FLAGS.clip_gradients)
      with tf.name_scope('clip_gradients'):
        clones_gradients = slim.learning.clip_gradient_norms(
            clones_gradients,
            FLAGS.clip_gradients)
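      # slim.learning.clip_gradient_norms clips each gradient's norm
      # individually (per variable), not the global norm across all of them.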

    # Add total_loss to summary.
    summaries.add(tf.scalar_summary('total_loss', total_loss,
                                    name='total_loss'))

    # Create gradient updates.
    train_ops = {}
    if FLAGS.iter_size == 1:
      grad_updates = optimizer.apply_gradients(clones_gradients,
                                               global_step=global_step)
      update_ops.append(grad_updates)

      update_op = tf.group(*update_ops)
      train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                        name='train_op')
      train_ops = train_tensor  # single op: pass the tensor, not a dict
    else:
      gvs = [(grad, var) for grad, var in clones_gradients]
      varnames = [var.name for grad, var in gvs]
      varname_to_var = {var.name: var for grad, var in gvs}
      varname_to_grad = {var.name: grad for grad, var in gvs}
      varname_to_ref_grad = {}
      for vn in varnames:
        grad = varname_to_grad[vn]
        print("accumulating ... ", (vn, grad.get_shape()))
        with tf.variable_scope("ref_grad"):
          with tf.device(deploy_config.variables_device()):
            ref_var = slim.local_variable(
                np.zeros(grad.get_shape(), dtype=np.float32),
                name=vn[:-2])
            varname_to_ref_grad[vn] = ref_var

      all_assign_ref_op = [ref.assign(varname_to_grad[vn]) for vn, ref in varname_to_ref_grad.items()]
      all_assign_add_ref_op = [ref.assign_add(varname_to_grad[vn]) for vn, ref in varname_to_ref_grad.items()]
      assign_gradients_ref_op = tf.group(*all_assign_ref_op)
      accumulate_gradients_op = tf.group(*all_assign_add_ref_op)
      with tf.control_dependencies([accumulate_gradients_op]):
        final_gvs = [(varname_to_ref_grad[var.name] / float(FLAGS.iter_size), var) for grad, var in gvs]
        apply_gradients_op = optimizer.apply_gradients(final_gvs, global_step=global_step)
        update_ops.append(apply_gradients_op)
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
            total_loss, name='train_op')
      for i in range(FLAGS.iter_size):
        if i == 0:
          train_ops[i] = assign_gradients_ref_op
        elif i < FLAGS.iter_size - 1:
          # The final apply step also runs the accumulate op through its
          # control_dependencies, so no extra accumulate iteration is needed.
          train_ops[i] = accumulate_gradients_op
        else:
          train_ops[i] = train_tensor


    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.merge_summary(list(summaries), name='summary_op')

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.intra_op_parallelism_threads = FLAGS.cpu_threads
    # config.allow_soft_placement = True
    # config.gpu_options.per_process_gpu_memory_fraction=0.7

    ###########################
    # Kicks off the training. #
    ###########################
    logging.info('RUNNING ON SPLIT %d' % FLAGS.split_id)
    slim.learning.train(
        train_ops,
        train_step_fn=train_step,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None,
        session_config=config)
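
# The custom train_step_fn passed above is not included in this snippet.
# What follows is a minimal sketch, under the assumption that it cycles
# through the train_ops dict when FLAGS.iter_size > 1: run the assign op
# first, then the accumulate ops, and apply the averaged gradients on the
# final sub-iteration. The counter and all names here are illustrative,
# not the original implementation.
_sub_step = [0]  # Python-side counter; global_step advances only on apply.

def train_step(sess, train_op, global_step, train_step_kwargs):
  if not isinstance(train_op, dict):
    # Plain case (iter_size == 1): defer to slim's default step function.
    return slim.learning.train_step(sess, train_op, global_step,
                                    train_step_kwargs)
  i = _sub_step[0] % len(train_op)
  _sub_step[0] += 1
  if i == len(train_op) - 1:
    # Final sub-iteration: train_op[i] is train_tensor, which applies the
    # averaged gradients and yields total_loss.
    return slim.learning.train_step(sess, train_op[i], global_step,
                                    train_step_kwargs)
  sess.run(train_op[i])  # assign (i == 0) or accumulate gradients only
  return 0.0, False  # loss is only reported on the applying step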
Example #45
def main(_):

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
            is_training=False)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

        # NOTE: the placeholder and preprocessing below are effectively dead:
        # image is immediately rebound to the batched placeholder that
        # follows, so the network is fed raw resized pixels at inference time.
        image = tf.placeholder(dtype=tf.float32,
                               shape=(eval_image_size, eval_image_size, 3))

        image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

        image = tf.placeholder(dtype=tf.float32,
                               shape=(1, eval_image_size, eval_image_size, 3))

        logits, end_points = network_fn(image)

        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        predictions = tf.argmax(logits, 1)

        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        else:
            checkpoint_path = FLAGS.checkpoint_path

        tf.logging.info('Evaluating %s' % checkpoint_path)

        with tf.Session() as sess:

            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)

            # Overridden by the directory listing below.
            sample_images = [
                "/home/soumyadeep_morphle_in/tmp/data/wbc_morphle/wbc_images/neutrophils/x0y11_0.jpg"
            ]

            from os import listdir
            from os.path import isfile, join
            import os
            in_dir = "/home/soumyadeep_morphle_in/tmp/data/wbc_morphle/wbc_images/monocytes"
            sample_images = [
                os.path.join(in_dir, f) for f in listdir(in_dir)
                if isfile(join(in_dir, f))
            ]

            for img in sample_images:
                im = Image.open(img).resize((eval_image_size, eval_image_size))
                im = np.array(im)
                im = im.reshape(1, eval_image_size, eval_image_size, 3)

                end_points_values, logit_values, prediction_values = sess.run(
                    [end_points, logits, predictions], feed_dict={image: im})
                print(prediction_values)
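
# Hedged follow-up (not part of the original script): the loop above prints
# only argmax indices. To inspect per-class probabilities from the fetched
# logit_values, a plain numpy softmax can be applied; softmax_np below is an
# illustrative helper, not an API from this codebase.
def softmax_np(x):
    # Subtract the row-wise max for numerical stability, then normalize.
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

# Example usage inside the loop above:
#   probs = softmax_np(logit_values)       # shape (1, num_classes)
#   print(probs.max(), prediction_values)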
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=1,
        clone_on_cpu=False,
        replica_id=0,
        num_replicas=1,
        num_ps_tasks=0)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        'flowers', 'train', FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        'mobilenet_v1',
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        'mobilenet_v1',
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=4,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=4,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):

      num_epochs_per_decay = 2.5
      decay_steps = int(dataset.num_samples / FLAGS.batch_size *
                        num_epochs_per_decay)
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          global_step,
          decay_steps,
          _LEARNING_RATE_DECAY_FACTOR,
          staircase=True,
          name='exponential_decay_learning_rate')
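      # For reference (assuming the slim 'flowers' train split of 3320 images
      # and, say, batch_size=32): decay_steps = int(3320 / 32 * 2.5) = 259,
      # i.e. the rate decays by _LEARNING_RATE_DECAY_FACTOR about every 2.5
      # epochs.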

      optimizer = tf.train.RMSPropOptimizer(
                           learning_rate,
                           decay=FLAGS.rmsprop_decay,
                           momentum=FLAGS.rmsprop_momentum,
                           epsilon=FLAGS.opt_epsilon)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # optimize_clones sums the clone losses and returns total_loss and gradients.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')
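    # Fetching train_tensor therefore also runs update_op first: the gradient
    # application plus the batch-norm updates gathered in update_ops above.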

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=True,
        session_config=session_config,
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=10,
        save_summaries_secs=300,
        save_interval_secs=300,
        sync_optimizer=None)  # sync replicas are disabled in this example
Example #47
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    num_batches = 2

    for bb in range(num_batches):

        batch_name = 'batch' + str(bb)

        #      tf.app.flags.DEFINE_string(
        #              'dataset_split_name',batch_name, 'The name of the train/test split.')

        tf.logging.set_verbosity(tf.logging.INFO)

        with tf.Graph().as_default():
            tf_global_step = slim.get_or_create_global_step()

            ######################
            # Select the dataset #
            ######################
            dataset = dataset_factory_MMH.get_dataset(FLAGS.dataset_name,
                                                      batch_name,
                                                      FLAGS.dataset_dir)

            ####################
            # Select the model #
            ####################
            network_fn = nets_factory.get_network_fn(
                FLAGS.model_name,
                num_classes=(dataset.num_classes - FLAGS.labels_offset),
                is_training=False)

            ##############################################################
            # Create a dataset provider that loads data from the dataset #
            ##############################################################
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=1,
                shuffle=False,
                common_queue_capacity=2 * FLAGS.batch_size,
                common_queue_min=FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            #####################################
            # Select the preprocessing function #
            #####################################
            preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
            image_preprocessing_fn = preprocessing_factory_MMH.get_preprocessing(
                preprocessing_name, is_training=False, flipLR=False)

            eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, eval_image_size,
                                           eval_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)

            #    ims_orig = tf.identity(images);
            #    labels_orig = tf.identity(labels);

            ####################
            # Define the model #
            ####################
            logits, end_pts = network_fn(images)

            if FLAGS.moving_average_decay:
                variable_averages = tf.train.ExponentialMovingAverage(
                    FLAGS.moving_average_decay, tf_global_step)
                variables_to_restore = variable_averages.variables_to_restore(
                    slim.get_model_variables())
                variables_to_restore[tf_global_step.op.name] = tf_global_step
            else:
                variables_to_restore = slim.get_variables_to_restore()

            predictions = tf.argmax(logits, 1)
            labels = tf.squeeze(labels)

            # Define the metrics:
            names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
                {
                    'Accuracy':
                    slim.metrics.streaming_accuracy(predictions, labels),
                    'Recall_5':
                    slim.metrics.streaming_recall_at_k(logits, labels, 5),
                })

            # Print the summaries to screen.
            for name, value in names_to_values.items():
                summary_name = 'eval/%s' % name
                op = tf.summary.scalar(summary_name, value, collections=[])
                op = tf.Print(op, [value], summary_name)
                tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

            # TODO(sguada) use num_epochs=1
            if FLAGS.max_num_batches:
                num_batches = FLAGS.max_num_batches
            else:
                # This ensures that we make a single pass over all of the data.
                num_batches = math.ceil(dataset.num_samples /
                                        float(FLAGS.batch_size))

            if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
                checkpoint_path = tf.train.latest_checkpoint(
                    FLAGS.checkpoint_path)
            else:
                checkpoint_path = FLAGS.checkpoint_path

            tf.logging.info('Evaluating %s' % checkpoint_path)

            out = slim.evaluation.evaluate_once(
                master=FLAGS.master,
                checkpoint_path=checkpoint_path,
                logdir=FLAGS.eval_dir,
                num_evals=num_batches,
                eval_op=list(names_to_updates.values()),
                final_op={
                    'logits': logits,
                    'end_pts': end_pts,
                    'images': images,
                    'labels': labels,
                    'predictions': predictions
                },
                variables_to_restore=variables_to_restore)
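            # evaluate_once returns the evaluated final_op, so out is a dict
            # of numpy arrays taken from the final evaluation step.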

            end_pts = out['end_pts']

            logits = out['logits']

            prelogits = end_pts['PreLogits']

            images = out['images']

            labels = out['labels']
            #        print(np.max(labels))
            predictions = out['predictions']
            # this is the very last layer before logit conversion
            #        lastlayer_name = 'PreLogits'

            #        lastlayer_weights =end_pts[lastlayer_name]
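            # NOTE: save_weights_dir is not defined in this snippet; it is
            # assumed to be a path configured elsewhere in the original
            # script.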

            fn2save = save_weights_dir + '/' + batch_name + '_logits.npy'
            np.save(fn2save, logits)

            fn2save = save_weights_dir + '/' + batch_name + '_prelogits.npy'
            np.save(fn2save, prelogits)

            fn2save = save_weights_dir + '/' + batch_name + '_ims_orig.npy'
            np.save(fn2save, images)

            fn2save = save_weights_dir + '/' + batch_name + '_labels_orig.npy'
            np.save(fn2save, labels)

            fn2save = save_weights_dir + '/' + batch_name + '_labels_predicted.npy'
            np.save(fn2save, predictions)
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_classification.get_dataset(
            FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes,
            FLAGS.labels_to_names_path)
        """
    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)
    """
        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=None,
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset
            print('label')
            print(label)

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=2 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)
            with tf.variable_scope("InceptionV3", reuse=True) as scope:
                # Auxiliary classifier branch
                with slim.arg_scope(
                    [slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                        stride=1,
                        padding="SAME"):
                    # Fetch Mixed_6e from end_points
                    aux_logits = end_points["Mixed_6e"]
                    with tf.variable_scope("AuxLogits", reuse=tf.AUTO_REUSE):
                        aux_logits = slim.avg_pool2d(aux_logits,
                                                     kernel_size=[5, 5],
                                                     stride=3,
                                                     padding="VALID",
                                                     scope="Avgpool_1a_5x5")
                        aux_logits = slim.conv2d(aux_logits,
                                                 num_outputs=128,
                                                 kernel_size=[1, 1],
                                                 scope="Conv2d_1b_1x1")
                        aux_logits = slim.conv2d(
                            aux_logits,
                            num_outputs=768,
                            kernel_size=[5, 5],
                            weights_initializer=trunc_normal(0.01),
                            padding="VALID",
                            scope="Conv2d_2a_5x5")
                        print('aux_logits')
                        print(aux_logits)
                        aux_logits = slim.conv2d(
                            aux_logits,
                            num_outputs=2,
                            kernel_size=[1, 1],
                            activation_fn=None,
                            normalizer_fn=None,
                            weights_initializer=trunc_normal(0.001),
                            scope="Conv2d_10b_1x1")
                        # Remove the two size-1 dimensions (axes 1 and 2).
                        aux_logits = tf.squeeze(aux_logits,
                                                axis=[1, 2],
                                                name="SpatialSqueeze")
                        # Store the auxiliary classifier output in end_points.
                        end_points["AuxLogits"] = aux_logits
            net = slim.dropout(logits, keep_prob=0.8, scope='Dropout_lb')
            net = tf.squeeze(net, axis=[1, 2])
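            # With num_classes=None, logits here is presumably the pooled
            # pre-logit feature map (batch x 1 x 1 x 2048); the dropout and
            # squeeze above turn it into a (batch, 2048) input for the custom
            # 2-way classifier below.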
            print('logits')
            print(logits)
            print('labels')
            print(labels)
            with tf.name_scope('output'):
                weights = tf.Variable(
                    tf.truncated_normal([2048, 2], stddev=0.001))
                biases = tf.Variable(tf.zeros([2]))
                logits2 = tf.matmul(net, weights) + biases
                final_tensor = tf.nn.softmax(logits2, name='prob')
                end_points["final_tensor"] = final_tensor
                end_points["labels"] = labels
                print('weights')
                print(weights)
                cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits2, labels=labels)
                cross_entropy_mean = tf.reduce_mean(cross_entropy)
                slim.losses.add_loss(cross_entropy_mean)
            #weights = tf.get_collection('weights')
            #print('weights')
            #print(weights)
            #biase = tf.add_n(tf.get_collection('biases'), 'loss2')
            """
      # Loss
      regularization_loss = tf.reduce_mean(tf.square(weights))
      hinge_loss = tf.reduce_mean(
          tf.square(
              tf.maximum(
                  tf.zeros([16, 2]), 1 - labels * logits
              )
          )
      )
      # with tf.name_scope("loss"):
      loss = regularization_loss + 1 * hinge_loss
      """
            """
      bottleneck_input = tf.squeeze(logits)
          # Fully connected layer
      with tf.name_scope('output'):
          weights = tf.Variable(tf.truncated_normal([2048, 2], stddev=0.001))
          biases = tf.Variable(tf.zeros([2]))
          logits = tf.matmul(bottleneck_input, weights) + biases
          final_tensor = tf.nn.softmax(logits, name='prob')
          # Loss
          regularization_loss = tf.reduce_mean(tf.square(weights))
          hinge_loss = tf.reduce_mean(
              tf.square(
                  tf.maximum(
                      tf.zeros([16, 2]), 1 - labels * logits
                  )
              )
          )
          # with tf.name_scope("loss"):
          my_loss = regularization_loss + 1.0 * hinge_loss
          slim.losses.add_loss(my_loss)
      """
            """
      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        slim.losses.softmax_cross_entropy(
            end_points['AuxLogits'], labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4,
            scope='aux_loss')
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
      """
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        losses_sum = 0
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
            losses_sum += loss
        print('tf.GraphKeys.LOSSES')
        print(tf.GraphKeys.LOSSES)

        with tf.name_scope('evaluation'):
            correct_prediction = tf.equal(
                tf.argmax(end_points["final_tensor"], 1),
                tf.argmax(end_points["labels"], 1))
            evaluation_step = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32))
            summaries.add(tf.summary.scalar('accuracy', evaluation_step))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones sums the clone losses and returns total_loss and gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))
        print('total_loss')
        print(total_loss)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Example #49
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    images, labels = tf.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

    ####################
    # Define the model #
    ####################
    logits, _ = network_fn(images)

    if FLAGS.moving_average_decay:
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()

    predictions = tf.argmax(logits, 1)
    labels = tf.squeeze(labels)

    # Define the metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        'Recall_5': slim.metrics.streaming_recall_at_k(
            logits, labels, 5),
    })
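    # Each streaming metric yields a (value_op, update_op) pair;
    # aggregate_metric_map splits them into the two dicts used below.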

    # Print the summaries to screen.
    for name, value in names_to_values.items():
      summary_name = 'eval/%s' % name
      op = tf.summary.scalar(summary_name, value, collections=[])
      op = tf.Print(op, [value], summary_name)
      tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Evaluating %s' % checkpoint_path)

    slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        variables_to_restore=variables_to_restore)
Example #50
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        ######################
        # Config model_deploy#
        ######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################

        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size
            new_height = FLAGS.New_Height_Of_Image or network_fn.default_image_size
            new_width = FLAGS.New_Width_Of_Image or network_fn.default_image_size

            #         image = image_preprocessing_fn(image, train_image_size, train_image_size)
            image = image_preprocessing_fn(image, new_height, new_width)

            #  io.imshow(image)
            #  io.show()
            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            #      tf.image_summary('images', images)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weight=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weight=1.0)

            # Adding the accuracy metric
            with tf.name_scope('accuracy'):
                predictions = tf.argmax(logits, 1)
                labels = tf.argmax(labels, 1)
                accuracy = tf.reduce_mean(
                    tf.to_float(tf.equal(predictions, labels)))
                tf.add_to_collection('accuracy', accuracy)
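                # This tensor is fetched again after clone creation via
                # tf.get_collection('accuracy', first_clone_scope) to drive
                # the eval/Accuracy summary below.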
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.histogram_summary('activations/' + end_point, x))
            summaries.add(
                tf.scalar_summary('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.scalar_summary('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.histogram_summary(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(
                tf.scalar_summary('learning_rate',
                                  learning_rate,
                                  name='learning_rate'))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones sums the clone losses and returns total_loss and gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # # Add total_loss to summary.
        # summaries.add(tf.scalar_summary('total_loss', total_loss,
        #                                 name='total_loss'))

        # Add total_loss and accuracy to summary.
        summaries.add(
            tf.scalar_summary('eval/Total_Loss', total_loss,
                              name='total_loss'))
        accuracy = tf.get_collection('accuracy', first_clone_scope)[0]
        summaries.add(
            tf.scalar_summary('eval/Accuracy', accuracy, name='accuracy'))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.merge_summary(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      with tf.device(deploy_config.inputs_device()):
        images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
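      # Note: this variant uses the tf.losses API (onehot_labels=, weights=)
      # rather than the older slim.losses calls used in the examples above.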
      if 'AuxLogits' in end_points:
        tf.losses.softmax_cross_entropy(
            logits=end_points['AuxLogits'], onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
      tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels,
          label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # optimize_clones sums the clone losses and returns total_loss and gradients.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None)
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  if not os.path.isfile(FLAGS.checkpoint_path):
    FLAGS.eval_dir = os.path.join(FLAGS.checkpoint_path, 'eval')
  else:
    FLAGS.eval_dir = os.path.join(
        os.path.dirname(FLAGS.checkpoint_path), 'eval')

  try:
    os.makedirs(FLAGS.eval_dir)
  except OSError:
    pass

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name,
        FLAGS.dataset_dir.split(','),
        FLAGS.dataset_list_dir,
        num_samples=FLAGS.frames_per_video,
        modality=FLAGS.modality,
        split_id=FLAGS.split_id)

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        batch_size=FLAGS.batch_size,
        is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    provider = dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=FLAGS.force_random_shuffle,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size,
        bgr_flips=FLAGS.bgr_flip)
    [image, label] = provider.get(['image', 'label'])
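    # The provider yields labels as strings; parse to an int64 scalar before
    # applying the label offset.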
    label = tf.cast(tf.string_to_number(label, tf.int32), tf.int64)
    label.set_shape(())
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size,
                                   model_name=FLAGS.model_name,
                                   ncrops=FLAGS.ncrops,
                                   out_dim_scale=FLAGS.out_dim_scale)

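    # When dumping features, batch with a single thread, presumably so that
    # examples come out in a deterministic order in the stored output.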
    images, labels = tf.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=1 if FLAGS.store_feat is not None else FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

    ####################
    # Define the model #
    ####################
    kwargs = {}
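    # Forward conv_endpoint only when it is explicitly set, so network_fn
    # implementations without that argument keep their default signature.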
    if FLAGS.conv_endpoint is not None:
      kwargs['conv_endpoint'] = FLAGS.conv_endpoint
    logits, end_points = network_fn(
        images, pool_type=FLAGS.pooling,
        classifier_type=FLAGS.classifier_type,
        num_channels_stream=provider.num_channels_stream,
        netvlad_centers=FLAGS.netvlad_initCenters.split(','),
        stream_pool_type=FLAGS.stream_pool_type,
        **kwargs)
    end_points['images'] = images
    end_points['labels'] = labels

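    # If training used exponential moving averages, evaluate with the averaged
    # weights: map each model variable to its shadow copy so the Saver
    # restores the EMA values in its place.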
    if FLAGS.moving_average_decay:
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()

    predictions = tf.argmax(logits, 1)
    # Note (rgirdhar): tf.squeeze without an axis would also drop the batch
    # dimension when batch_size=1, so only squeeze for larger batches.
    if FLAGS.batch_size > 1:
      labels = tf.squeeze(labels)

    # Define the metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        'Recall@5': slim.metrics.streaming_recall_at_k(
            logits, labels, 5),
    })
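    # Each streaming metric above returns a (value, update_op) pair;
    # aggregate_metric_map splits them into the two dicts consumed below.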

    # Print the summaries to screen.
    for name, value in names_to_values.items():
      summary_name = 'eval/%s' % name
      op = tf.summary.scalar(summary_name, value, collections=[])
      op = tf.Print(op, [value], summary_name)
      tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      num_batches = int(math.ceil(dataset.num_samples /
                                  float(FLAGS.batch_size)))

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Evaluating %s' % checkpoint_path)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

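    # Feature-extraction mode: skip the streaming metrics, restore the
    # checkpoint manually, and dump the requested endpoints to an HDF5 file.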
    if FLAGS.store_feat is not None:
      assert FLAGS.store_feat_path is not None, (
          '--store_feat_path must be set when --store_feat is given')
      from tensorflow.python.training import supervisor
      from tensorflow.python.framework import ops
      import h5py
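      # A bare Supervisor (no saver, summaries, or global step) is used only
      # to manage the session and start the input queue runners; the
      # checkpoint is restored explicitly below.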
      saver = tf.train.Saver(variables_to_restore)
      sv = supervisor.Supervisor(graph=ops.get_default_graph(),
                                 logdir=None,
                                 summary_op=None,
                                 summary_writer=None,
                                 global_step=None,
                                 saver=None)
      ept_names_to_store = FLAGS.store_feat.split(',')
      try:
        ept_to_store = [end_points[el] for el in ept_names_to_store]
      except KeyError:
        logging.error('Endpoint not found')
        logging.error('Choose from %s' % ','.join(end_points.keys()))
        raise
      res = dict([(epname, []) for epname in ept_names_to_store])
      with sv.managed_session(
          FLAGS.master, start_standard_services=False,
          config=config) as sess:
        saver.restore(sess, checkpoint_path)
        sv.start_queue_runners(sess)
        for j in range(num_batches):
          if j % 10 == 0:
            logging.info('Doing batch %d/%d' % (j, num_batches))
          feats = sess.run(ept_to_store)
          for eid, epname in enumerate(ept_names_to_store):
            res[epname].append(feats[eid])
      logging.info('Writing out features to %s' % FLAGS.store_feat_path)
      with h5py.File(FLAGS.store_feat_path, 'w') as fout:
        for epname in res.keys():
          fout.create_dataset(epname,
              data=np.concatenate(res[epname], axis=0),
              compression='gzip',
              compression_opts=FLAGS.feat_store_compression_opt)
    else:
      slim.evaluation.evaluate_once(
          master=FLAGS.master,
          checkpoint_path=checkpoint_path,
          logdir=FLAGS.eval_dir,
          num_evals=num_batches,
          eval_op=list(names_to_updates.values()),
          variables_to_restore=variables_to_restore,
          session_config=config)
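
# A minimal sketch (not part of the original script) of reading back the
# features written by the --store_feat branch above; it assumes only the HDF5
# layout produced there (one gzip-compressed dataset per endpoint name).
def load_stored_feats(path):
  import h5py
  with h5py.File(path, 'r') as fin:
    # Each top-level dataset holds the per-batch features, concatenated
    # along the batch axis at write time.
    return {name: fin[name][...] for name in fin.keys()}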