def main(_):
    """Fine-tune a ResNet-v1-50 classifier on the distracted-driver dataset.

    Builds the TFRecord input pipeline, restores ImageNet-pretrained
    weights, trains with Adam for a fixed number of steps, and writes
    TensorBoard summaries plus checkpoints under ``out_dir``.

    Args:
        _: unused positional argument supplied by ``tf.app.run``.
    """
    begin_time = time.time()
    print("begin time:", begin_time)
    logging.debug("begin begin_time:{}".format(begin_time))

    # Dataset / output locations (hard-coded for the training container).
    base_dir = "/data/oHongMenYan/distracted-driver-detection-dataset"
    out_dir = "/output"

    model_image_size = (240, 320)  # (height, width) fed to the network
    num_classes = 10
    batch_size = 32  # was previously tuned down from 128/64 to fit memory
    train_examples_num = 20787
    epochs_num_per_optimizer = 50
    # Total optimizer steps needed to cover the requested number of epochs.
    num_steps = int(train_examples_num * epochs_num_per_optimizer / batch_size)

    # exist_ok=True already tolerates pre-existing directories, so no
    # separate os.path.exists() guard is needed.
    imgs_dir = os.path.join(out_dir, "img")
    os.makedirs(imgs_dir, exist_ok=True)

    logs_dir = os.path.join(out_dir, "logs")
    os.makedirs(logs_dir, exist_ok=True)

    # Load the datasets from TFRecord files (data was preprocessed when the
    # records were written, so no extra preprocessing happens here).
    dataset_train = os.path.join(base_dir, 'train.record')
    dataset_val = os.path.join(base_dir, 'val.record')

    image_train, label_train = utils.read_TFRecord(
        dataset_train,
        image_shape=model_image_size,
        batch_size=batch_size,
        num_epochs=1e4)
    image_valid, label_valid = utils.read_TFRecord(
        dataset_val,
        image_shape=model_image_size,
        batch_size=batch_size,
        num_epochs=1e4)

    # ImageNet-pretrained ResNet-v1-50 checkpoint used as the warm start.
    checkpoint_path = os.path.join(base_dir, 'resnet_v1_50.ckpt')

    resnet_model = model.Model(num_classes=num_classes,
                               is_training=True,
                               fixed_resize_side=model_image_size[0],
                               default_image_size=model_image_size[0])
    prediction_dict = resnet_model.predict(image_train)
    loss_dict = resnet_model.loss(prediction_dict, label_train)
    loss = loss_dict['loss']
    postprocess_dict = resnet_model.postprocess(prediction_dict)
    accuracy = resnet_model.accuracy(postprocess_dict, label_train)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)

    global_step = slim.create_global_step()
    if not global_step:
        # Fallback: create the global_step variable ourselves.
        print("global_step is none")
        global_step = tf.Variable(0,
                                  trainable=False,
                                  name='global_step',
                                  dtype=tf.int64)
        print('global_step:', global_step)
    init_fn = utils.get_init_fn(checkpoint_path=checkpoint_path)

    # Adam optimizer. The learning rate is a tf.Variable so it can be
    # overridden per-step through feed_dict (see the training loop below).
    with tf.variable_scope("adam_vars"):
        learning_rate = tf.Variable(initial_value=1e-5,
                                    dtype=tf.float32,
                                    name='learning_rate')
        adam_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        adam_gradients = adam_optimizer.compute_gradients(loss=loss)
        adam_train_step = adam_optimizer.apply_gradients(
            grads_and_vars=adam_gradients, global_step=global_step)
    lr_op = tf.summary.scalar('learning_rate', learning_rate)

    merged_summary_op = tf.summary.merge_all()
    summary_string_writer = tf.summary.FileWriter(logs_dir)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True  # grow GPU memory on demand
    sess = tf.Session(config=config)
    init_op = tf.global_variables_initializer()
    init_local_op = tf.local_variables_initializer()

    with sess:
        sess.run(init_op)
        sess.run(init_local_op)
        saver = tf.train.Saver(max_to_keep=5)
        init_fn(sess)  # restore pretrained weights

        logging.debug('checkpoint restored from [{0}]'.format(checkpoint_path))

        # Queue runners feed the TFRecord input pipeline.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        start = time.time()
        print('adam go-----------------')
        gs = 0  # keeps the final saver.save() valid even when num_steps == 0
        for i in range(num_steps):
            gs, _ = sess.run([global_step, adam_train_step],
                             feed_dict={learning_rate: 1e-5})
            logging.debug("Current adam step: {0} _:{1} index:{2} ".format(
                gs, _, i))
            # NOTE(review): this second run() evaluates loss/accuracy on a
            # *fresh* batch from the queue, not the batch just trained on —
            # confirm that is intended.
            lr, adam_loss, summary_string, acc_score = sess.run(
                [learning_rate, loss, merged_summary_op, accuracy])
            logging.debug(
                "adam step {0} Current Loss: {1} acc_score:{2} index:{3}, learning_rate:{4}"
                .format(gs, adam_loss, acc_score, i, lr))
            end = time.time()
            logging.debug("adam [{0:.2f}] imgs/s".format(batch_size /
                                                         (end - start)))
            start = end

            summary_string_writer.add_summary(summary_string, i)
            if i == num_steps - 1:
                # Snapshot at the final Adam step.
                save_path = saver.save(sess,
                                       os.path.join(logs_dir,
                                                    "model_adam.ckpt"),
                                       global_step=gs)
                logging.debug("Model saved in file: %s" % save_path)

        coord.request_stop()
        coord.join(threads)
        save_path = saver.save(sess,
                               os.path.join(logs_dir, "model.ckpt"),
                               global_step=gs)
        logging.debug("Model finally saved in file: %s" % save_path)

    cost_time = int(time.time() - begin_time)
    print("All done cost_time: %d " % (cost_time))

    summary_string_writer.close()

    logging.debug("All done cost_time:{}".format(cost_time))
# ---- Esempio n. 2 (example separator from the source aggregation) ----
def main(_):
    """Run TF-Slim multi-clone training for an image classifier.

    Standard slim training flow: configure model_deploy for
    multi-GPU / multi-replica training, build the dataset provider and
    batching queue, create network clones, attach activation/loss/variable
    summaries, configure the optimizer (optionally sync-replica or
    moving-average), then hand the resulting train tensor to
    ``slim.learning.train``.

    Args:
        _: unused positional argument supplied by ``tf.app.run``.

    Raises:
        ValueError: if ``--dataset_dir`` was not supplied.
    """
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
                num_clones=FLAGS.num_clones,
                clone_on_cpu=FLAGS.clone_on_cpu,
                replica_id=FLAGS.task,
                num_replicas=FLAGS.worker_replicas,
                num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(
                FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
                FLAGS.model_name,
                num_classes=(dataset.num_classes - FLAGS.labels_offset),
                weight_decay=FLAGS.weight_decay,
                is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
                preprocessing_name,
                is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                    dataset,
                    num_readers=FLAGS.num_readers,
                    common_queue_capacity=20 * FLAGS.batch_size,
                    common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            # Shift labels when the network omits a background class.
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size, train_image_size)

            images, labels = tf.train.batch(
                    [image, label],
                    batch_size=FLAGS.batch_size,
                    num_threads=FLAGS.num_preprocessing_threads,
                    capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                    labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                    [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            # Auxiliary head (e.g. Inception) gets a down-weighted loss.
            if 'AuxLogits' in end_points:
                tf.losses.softmax_cross_entropy(
                        logits=end_points['AuxLogits'], onehot_labels=labels,
                        label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
            tf.losses.softmax_cross_entropy(
                    logits=logits, onehot_labels=labels,
                    label_smoothing=FLAGS.label_smoothing, weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                                                            tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                    FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                    opt=optimizer,
                    replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                    variable_averages=variable_averages,
                    variables_to_average=moving_average_variables,
                    replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                    total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones computes per-clone losses and gradients
        #    and returns a train_tensor and summary_op
        total_loss, clones_gradients = model_deploy.optimize_clones(
                clones,
                optimizer,
                var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                                                         global_step=global_step)
        update_ops.append(grad_updates)

        # Running train_tensor evaluates total_loss only after every
        # update op (gradients + batch-norm updates) has run.
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                                                                            name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                                             first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')


        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
                train_tensor,
                logdir=FLAGS.train_dir,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                init_fn=_get_init_fn(),
                summary_op=summary_op,
                number_of_steps=FLAGS.max_number_of_steps,
                log_every_n_steps=FLAGS.log_every_n_steps,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                sync_optimizer=optimizer if FLAGS.sync_replicas else None)
# ---- Esempio n. 3 (example separator from the source aggregation) ----
def _train():
  """Training loop.

  Builds the scene-block input pipeline on CPU, a single-GPU model tower
  with geometry/semantics losses and reconstruction-error summaries, and
  runs ``slim.learning.train`` with periodic checkpointing.
  """
  params = dict()
  params['batch_size'] = FLAGS.batch_size
  params['learning_rate'] = FLAGS.learning_rate

  with tf.device('/gpu:0'):
    global_step = slim.create_global_step()

  with tf.device('/gpu:0'):
    # Exponentially decayed learning rate, floored at min_learning_rate.
    lrn_rate = tf.maximum(
        FLAGS.min_learning_rate,  # min_lr_rate.
        tf.train.exponential_decay(
            params['learning_rate'], global_step, 5000, 0.92, staircase=True))
    tf.summary.scalar('learning_rate', lrn_rate)
    optimizer = tf.train.AdamOptimizer(lrn_rate)

  with tf.device('/cpu:0'):
    # Targets are quantized unless a p-norm regression loss is used.
    (input_blocks, target_blocks, target_sem_blocks, target_lo_blocks,
     target_sem_lo_blocks) = reader.ReadSceneBlocksLevel(
         FLAGS.data_filepattern,
         FLAGS.train_samples,
         FLAGS.dim_block,
         FLAGS.height_block,
         FLAGS.stored_dim_block_hi,
         FLAGS.stored_height_block_hi,
         FLAGS.is_base_level,
         FLAGS.hierarchy_level,
         FLAGS.num_quant_levels,
         quantize=not FLAGS.p_norm,
         params=params,
         shuffle=True)

    # The base hierarchy level has no low-resolution targets to prefetch.
    if FLAGS.is_base_level:
      inputs_queue = slim.python.slim.data.prefetch_queue.prefetch_queue(
          (input_blocks, target_blocks, target_sem_blocks))
    else:
      inputs_queue = slim.python.slim.data.prefetch_queue.prefetch_queue(
          (input_blocks, target_blocks, target_sem_blocks, target_lo_blocks,
           target_sem_lo_blocks))

  def tower_fn(inputs_queue):
    """The tower function."""
    target_lo_blocks = None
    target_sem_blocks = None
    target_sem_lo_blocks = None
    if FLAGS.is_base_level:
      input_blocks, target_blocks, target_sem_blocks = inputs_queue.dequeue()
    else:
      (input_blocks, target_blocks, target_sem_blocks, target_lo_blocks,
       target_sem_lo_blocks) = inputs_queue.dequeue()

    ops = model.model(
        input_scan=input_blocks,
        target_scan_low_resolution=target_lo_blocks,
        target_scan=target_blocks,
        target_semantics_low_resolution=target_sem_lo_blocks,
        target_semantics=target_sem_blocks,
        predict_semantics=FLAGS.predict_semantics,
        use_p_norm=FLAGS.p_norm > 0,
        num_quant_levels=FLAGS.num_quant_levels)
    logits = ops['logits_geometry']
    logits_sem = ops['logits_semantics']

    # TODO(angeladai) change p-norm to l1
    # Regression (p-norm) loss vs. probabilistic loss over quantized levels.
    if FLAGS.p_norm > 0:
      loss = losses.get_l1_loss_allgroups(
          logit_groups=logits,
          labels=target_blocks,
          logit_groups_sem=logits_sem,
          labels_sem=target_sem_blocks,
          weight_semantic=FLAGS.weight_semantic)
    else:
      loss = losses.get_probabilistic_loss_allgroups(
          logit_groups=logits,
          labels=target_blocks,
          logit_groups_sem=logits_sem,
          labels_sem=target_sem_blocks,
          num_quant_levels=FLAGS.num_quant_levels,
          weight_semantic=FLAGS.weight_semantic)
    if FLAGS.predict_semantics:
      tf.summary.scalar('Loss_Geo', loss['loss_geo'])
      tf.summary.scalar('Loss_Sem', loss['loss_sem'])

    # Reconstruct
    predictions_list = []
    # Temperature multiplier applied to logits before multinomial sampling;
    # a large value sharpens the sampled distribution toward the argmax.
    temp = 100.0
    for l in logits:
      if FLAGS.p_norm > 0:
        # Regression output: first channel is the predicted value.
        predictions_list.append(l[:, :, :, :, 0])
      else:
        # Sample a quantization level per voxel from the logits.
        sz = l.shape_as_list()
        l = tf.reshape(l, [-1, sz[-1]])
        s = tf.multinomial(temp * l, 1)
        predictions_list.append(tf.reshape(s, sz[:-1]))
    if FLAGS.predict_semantics:
      # Split the semantic target volume into the eight interleaved 2x2x2
      # subgroups matching the per-group semantic logits.
      target_sem_groups = [
          target_sem_blocks[:, ::2, ::2, ::2],
          target_sem_blocks[:, ::2, ::2, 1::2],
          target_sem_blocks[:, ::2, 1::2, ::2],
          target_sem_blocks[:, ::2, 1::2, 1::2],
          target_sem_blocks[:, 1::2, ::2, ::2],
          target_sem_blocks[:, 1::2, ::2, 1::2],
          target_sem_blocks[:, 1::2, 1::2, ::2],
          target_sem_blocks[:, 1::2, 1::2, 1::2]
      ]
      error_count = error_count_1 = 0
      error_norm = 0
      for n in range(len(logits_sem)):
        pred_sem = tf.argmax(logits_sem[n], 4)
        # Only voxels with a nonzero semantic label contribute to accuracy.
        mask = tf.greater(target_sem_groups[n], 0)
        error_count += tf.count_nonzero(
            tf.cast(tf.boolean_mask(tensor=pred_sem, mask=mask), tf.int32) -
            tf.cast(
                tf.boolean_mask(tensor=target_sem_groups[n], mask=mask),
                tf.int32))
        error_norm += tf.count_nonzero(mask)
        if n == 0:
          # NOTE: despite the name, error_count_1 stores the error *ratio*
          # for the first group only.
          error_count_1 = tf.cast(error_count, tf.float32) / tf.cast(
              error_norm, tf.float32)
      tf.summary.scalar(
          'Sem_Accuracy', 1.0 -
          tf.cast(error_count, tf.float32) / tf.cast(error_norm, tf.float32))
      tf.summary.scalar('Sem_Accuracy_Group_1', 1.0 - error_count_1)

    # Same eight interleaved 2x2x2 subgroups for the geometry targets
    # (channel 0 holds the distance value).
    target_groups = [
        target_blocks[:, ::2, ::2, ::2, 0], target_blocks[:, ::2, ::2, 1::2, 0],
        target_blocks[:, ::2, 1::2, ::2, 0],
        target_blocks[:, ::2, 1::2, 1::2, 0],
        target_blocks[:, 1::2, ::2, ::2, 0],
        target_blocks[:, 1::2, ::2, 1::2, 0],
        target_blocks[:, 1::2, 1::2, ::2, 0],
        target_blocks[:, 1::2, 1::2, 1::2, 0]
    ]
    l1_recon_loss = 0.0
    recon_using_pred_occ = 0.0
    recon_using_target_occ = 0.0
    for k in range(len(predictions_list)):
      # Map predictions/targets back to metric distance values before
      # computing reconstruction errors.
      if FLAGS.p_norm > 0:
        p = (predictions_list[k] + 1) * 0.5 * constants.TRUNCATION
        t = (target_groups[k] + 1) * 0.5 * constants.TRUNCATION
      else:
        p = preprocessor.dequantize(predictions_list[k], FLAGS.num_quant_levels,
                                    constants.TRUNCATION)
        t = preprocessor.dequantize(
            (target_groups[k] + 1) * 0.5 * FLAGS.num_quant_levels,
            FLAGS.num_quant_levels, constants.TRUNCATION)
      l1_recon_loss += losses.get_recon_loss(pred=p, target=t)
      # for occupied space error use 1.5 to include transition from occ to empty
      recon_using_pred_occ += losses.get_recon_loss_for_occupied_space(
          t, p, 1.5)
      recon_using_target_occ += losses.get_recon_loss_for_occupied_space(
          p, t, 1.5)
    # Average the per-group errors.
    l1_recon_loss /= len(predictions_list)
    recon_using_pred_occ /= len(predictions_list)
    recon_using_target_occ /= len(predictions_list)
    tf.summary.scalar('Recon_Loss', l1_recon_loss)
    tf.summary.scalar('Recon_Loss_From_Pred_Occ', recon_using_pred_occ)
    tf.summary.scalar('Recon_Loss_From_Target_Occ', recon_using_target_occ)

    return {'logits': logits, 'loss': loss['loss']}

  with tf.device('/gpu:0'):
    total_loss = tower_fn(inputs_queue)['loss']
    tf.summary.scalar('Total_Loss', total_loss)

  session_config = tf.ConfigProto(
      allow_soft_placement=True, log_device_placement=False)
  train_op = slim.learning.create_train_op(total_loss, optimizer)

  # Run training.
  tf.logging.info('Running training')
  slim.learning.train(
      train_op,
      FLAGS.train_dir,
      session_config=session_config,
      save_summaries_secs=60,
      save_interval_secs=180,
      number_of_steps=FLAGS.number_of_steps,
      saver=tf.train.Saver(keep_checkpoint_every_n_hours=1., max_to_keep=50))
# ---- Esempio n. 4 (example separator from the source aggregation) ----
from libs.nets.train_utils import _configure_learning_rate, _configure_optimizer, \
  _get_variables_to_train, _get_init_fn, get_var_list_to_restore

# Alias for the TF-Slim ResNet-v1-50 network builder.
resnet50 = resnet_v1.resnet_v1_50
FLAGS = tf.app.flags.FLAGS

DEBUG = False

# Module-level smoke run: build a graph, read one COCO training shard,
# run the ResNet-50 backbone, and print every endpoint's name and shape.
# NOTE(review): this executes at import time as a side effect — consider
# moving it under a main() guard.
with tf.Graph().as_default():
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=0.8,
        allow_growth=True,
    )
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                          allow_soft_placement=True)) as sess:
        global_step = slim.create_global_step()

        ## data
        image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
          coco.read('./data/coco/records/coco_train2014_00000-of-00040.tfrecord')
        image, gt_boxes, gt_masks = coco_preprocess.preprocess_image(
            image, gt_boxes, gt_masks, is_training=True)

        ##  network
        # Backbone is frozen here (is_training=False) since this only
        # inspects endpoint shapes; 1000 = ImageNet class count.
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            logits, end_points = resnet50(image, 1000, is_training=False)
        end_points['inputs'] = image

        for x in sorted(end_points.keys()):
            print(x, end_points[x].name, end_points[x].shape)
def main(_):
  """Train a multi-path spatial-transformer classification network.

  Builds the input pipeline, clones the model across devices via slim's
  model_deploy, configures separate optimizers for the localization and
  classification sub-networks, and runs slim training.
  """
  # if not FLAGS.dataset_dir:
  #   raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    # Describes how clones are replicated across GPUs / workers / ps tasks.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    # dataset = dataset_factory.get_dataset(
    #     FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    # network_fn = nets_factory.get_network_fn(
    #     FLAGS.model_name,
    #     num_classes=(dataset.num_classes - FLAGS.labels_offset),
    #     weight_decay=FLAGS.weight_decay,
    #     is_training=True)
    def localization_net_alpha(inputs, num_transformer, num_theta_params):
        """Predict affine transformer parameters (theta) from an image batch.

        Runs an Inception-v2 trunk, then a small fully-convolutional head
        regressing `num_transformer * num_theta_params` values bounded to
        (-1, 1) by tanh. Biases start at [1.1, 0, 1.1, 0] per transformer so
        each spatial transformer begins near an identity-like crop.
        """
        # Inception-v2 trunk up to its default 'Mixed_5c' endpoint (7x7x1024).
        with tf.variable_scope('inception_net'):
            features, _ = inception_v2.inception_v2_base(inputs)

        with tf.variable_scope('logits'):
            # 1x1 bottleneck, then a VALID conv that collapses the spatial grid.
            features = slim.conv2d(features, 128, [1, 1], scope='conv2d_a_1x1')
            reduced = inception_v2._reduced_kernel_size_for_small_input(features, [7, 7])
            features = slim.conv2d(
                features, 128, reduced, padding='VALID',
                scope='conv2d_b_{}x{}'.format(*reduced))
            # Regression head: no normalizer, tanh-bounded outputs.
            theta_bias = tf.constant_initializer([1.1, .0, 1.1, .0] * num_transformer)
            theta = slim.conv2d(
                features, num_transformer * num_theta_params, [1, 1],
                weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
                biases_initializer=theta_bias,
                normalizer_fn=None, activation_fn=tf.nn.tanh,
                scope='conv2d_c_1x1')

            # Drop the 1x1 spatial dims -> [batch, num_transformer * num_theta_params].
            return tf.squeeze(theta, [1, 2])

    def _inception_logits(inputs, num_outputs, dropout_keep_prob, activ_fn=None):
        """Average-pool + dropout + 1x1-conv classifier head.

        Returns logits of shape [batch, num_outputs].
        """
        with tf.variable_scope('logits'):
            pool_kernel = inception_v2._reduced_kernel_size_for_small_input(inputs, [7, 7])
            # Pool the whole feature map down to 1x1 spatially.
            pooled = slim.avg_pool2d(inputs, pool_kernel, padding='VALID')
            # Drop out activations before the final fully-connected conv.
            pooled = slim.dropout(pooled, keep_prob=dropout_keep_prob, scope='dropout')
            # Fully-connected layer expressed as a 1x1 convolution.
            logits = slim.conv2d(pooled, num_outputs, [1, 1], normalizer_fn=None,
                                 activation_fn=activ_fn, scope='conv2_a_1x1')

        return tf.squeeze(logits, [1, 2])

    def network_fn(inputs):
        """Fine-grained classifier with multiple spatial-transformer channels.

        The localization net predicts one theta per transformer; each theta
        warps the input, each warped image runs through its own Inception-v2
        path, and the concatenated features feed a shared logits head.

        Returns:
            (classification_logits, end_points) where end_points exposes the
            predicted thetas and the final logits.
        """
        end_points = {}
        arg_scope = inception_v2.inception_v2_arg_scope(weight_decay=FLAGS.weight_decay)
        with slim.arg_scope(arg_scope):
            with tf.variable_scope('stn'):
                with tf.variable_scope('localization'):
                    # One flat theta vector per example, split per transformer.
                    transformer_theta = localization_net_alpha(inputs, NUM_TRANSFORMER, NUM_THETA_PARAMS)
                    transformer_theta_split = tf.split(transformer_theta, NUM_TRANSFORMER, axis=1)
                    end_points['stn/localization/transformer_theta'] = transformer_theta

                # Warp the input once per predicted theta.
                # NOTE(review): `transformer` and `transformer_output_size` are
                # defined elsewhere in this file — confirm sampling semantics.
                transformer_outputs = []
                for theta in transformer_theta_split:
                    transformer_outputs.append(
                        transformer(inputs, theta, transformer_output_size, sampling_kernel='bilinear'))

                inception_outputs = []
                # Static shape must be re-asserted after the sampler.
                transformer_outputs_shape = [FLAGS.batch_size, transformer_output_size[0],
                                             transformer_output_size[1], 3]
                with tf.variable_scope('classification'):
                    # Each transformer channel gets its own Inception tower.
                    for path_idx, inception_inputs in enumerate(transformer_outputs):
                        with tf.variable_scope('path_{}'.format(path_idx)):
                            inception_inputs.set_shape(transformer_outputs_shape)
                            net, _ = inception_v2.inception_v2_base(inception_inputs)
                            inception_outputs.append(net)
                    # concatenate the endpoints: num_batch*7*7*(num_transformer*1024)
                    multipath_outputs = tf.concat(inception_outputs, axis=-1)

                    # final fc layer logits
                    classification_logits = _inception_logits(multipath_outputs, NUM_CLASSES, dropout_keep_prob)
                    end_points['stn/classification/logits'] = classification_logits

        return classification_logits, end_points


    #####################################
    # Select the preprocessing function #
    #####################################
    # preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    # image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    #     preprocessing_name,
    #     is_training=True)

    def image_preprocessing_fn(image, out_height, out_width):
        """Training-time preprocessing: crop, resize, random flip, scale to [-1, 1]."""
        if image.dtype != tf.float32:
            # convert_image_dtype also rescales integer images into [0, 1].
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        # Slight central crop to trim border pixels.
        image = tf.image.central_crop(image, central_fraction=0.975)
        # resize_bilinear operates on batches, so add and then remove a batch dim.
        batched = tf.expand_dims(image, 0)
        batched = tf.image.resize_bilinear(batched, [out_height, out_width], align_corners=False)
        image = tf.squeeze(batched, [0])
        # Random horizontal flip for augmentation.
        image = tf.image.random_flip_left_right(image)
        # Map [0, 1] -> [-1, 1].
        image = tf.subtract(image, 0.5)
        image = tf.multiply(image, 2.0)
        image.set_shape((out_height, out_width, 3))
        return image

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    def _get_filename_list(file_dir, file):
        filename_path = os.path.join(file_dir, file)
        filename_list = []
        cls_label_list = []
        with open(filename_path, 'r') as f:
            for line in f:
                filename, label, nid, attr = line.strip().split(',')
                filename_list.append(filename)
                cls_label_list.append(int(label))

        return filename_list, cls_label_list

    with tf.device(deploy_config.inputs_device()):
      # create the filename and label example
      # NOTE(review): filename_dir, file and num_epochs are defined elsewhere
      # in this file (not visible here) — confirm before refactoring.
      filename_list, label_list = _get_filename_list(filename_dir, file)
      num_samples = len(filename_list)
      filename, label = tf.train.slice_input_producer([filename_list, label_list], num_epochs)

      # decode and preprocess the image
      file_content = tf.read_file(filename)
      image = tf.image.decode_jpeg(file_content, channels=3)

      train_image_size = FLAGS.train_image_size or default_image_size
      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      # Batch, one-hot encode the labels, and stage batches in a prefetch queue.
      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, NUM_CLASSES - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      with tf.device(deploy_config.inputs_device()):
        images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      # The auxiliary head (if present) contributes a down-weighted loss term.
      if 'AuxLogits' in end_points:
        tf.losses.softmax_cross_entropy(
            logits=end_points['AuxLogits'], onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
      tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels,
          label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    # Two optimizers: one for the localization sub-net and one for the
    # classification sub-net, each with its own learning-rate schedule.
    with tf.device(deploy_config.optimizer_device()):
      learning_rate_loc = _configure_learning_rate_loc(num_samples, global_step)
      learning_rate_cls = _configure_learning_rate_cls(num_samples, global_step)
      optimizer_loc = _configure_optimizer(learning_rate_loc)
      optimizer_cls = _configure_optimizer(learning_rate_cls)
      summaries.add(tf.summary.scalar('learning_rate_loc', learning_rate_loc))
      summaries.add(tf.summary.scalar('learning_rate_cls', learning_rate_cls))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer_loc = tf.train.SyncReplicasOptimizer(
          opt=optimizer_loc,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
      optimizer_cls = tf.train.SyncReplicasOptimizer(
          opt=optimizer_cls,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    # variables_to_train = _get_variables_to_train()

    # NOTE(review): these two helpers and the *_train_vars_scope values are
    # defined elsewhere in this file — confirm which scopes they select.
    loc_vars_to_train = _get_localization_vars_to_train(loc_train_vars_scope)
    cls_vars_to_train = _get_classification_vars_to_train(cls_train_vars_scope)

    #  and returns a train_tensor and summary_op
    _, clones_gradients_loc = model_deploy.optimize_clones(
        clones,
        optimizer_loc,
        var_list=loc_vars_to_train)
    total_loss, clones_gradients_cls = model_deploy.optimize_clones(
        clones,
        optimizer_cls,
        var_list=cls_vars_to_train)

    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    # Only the classification update advances global_step, so the step is not
    # double-counted when both optimizers apply gradients.
    grad_updates_loc = optimizer_loc.apply_gradients(clones_gradients_loc)
    grad_updates_cls = optimizer_cls.apply_gradients(clones_gradients_cls, global_step=global_step)
    update_ops.append(grad_updates_loc)
    update_ops.append(grad_updates_cls)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')


    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=None)
# Esempio n. 6  (scraper artifact — kept as a comment so the file parses)
# 0
  def main(self):
    """Entry point: train the model, or run eval/export/output modes.

    Branches on FLAGS.is_training. This training branch clones the model
    across devices via model_deploy and runs slim.learning.train; the
    else-branch below handles eval, SavedModel export, and raw output.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set session_config to allow some operations to be run on cpu.
    session_config = tf.ConfigProto(allow_soft_placement=True, )

    with tf.Graph().as_default():
      ######################
      # Select the dataset #
      ######################
      dataset = self._select_dataset()

      ######################
      # Select the network #
      ######################
      networks = self._select_network()

      #####################################
      # Select the preprocessing function #
      #####################################
      image_preprocessing_fn = self._select_image_preprocessing_fn()

      #######################
      # Config model_deploy #
      #######################
      deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

      global_step = slim.create_global_step()

      ##############################################################
      # Create a dataset provider that loads data from the dataset #
      ##############################################################
      data = self._prepare_data(dataset, image_preprocessing_fn, deploy_config, )
      # NOTE(review): _get_batch presumably returns an ordered mapping —
      # keys()/values() must stay aligned for _clone_fn. Confirm.
      data_batched = self._get_batch(data)
      batch_names = data_batched.keys()
      batch = data_batched.values()

      ###############
      # Is Training #
      ###############
      if FLAGS.is_training:
        # Persist the flags used for this run next to the checkpoints.
        if not os.path.isdir(FLAGS.train_dir):
          util_io.touch_folder(FLAGS.train_dir)
        if not os.path.exists(os.path.join(FLAGS.train_dir, FLAGS_FILE_NAME)):
          FLAGS.append_flags_into_file(os.path.join(FLAGS.train_dir, FLAGS_FILE_NAME))

        # Prefetching can fail for some batch structures; fall back to the
        # raw batch tensors in that case.
        try:
          batch_queue = slim.prefetch_queue.prefetch_queue(
            batch, capacity=4 * deploy_config.num_clones)
        except ValueError as e:
          tf.logging.warning('Cannot use batch_queue due to error %s', e)
          batch_queue = batch
        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, self._clone_fn,
                                            GeneralModel._dtype_string_to_dtype(FLAGS.variable_dtype),
                                            [networks, batch_queue, batch_names],
                                            {'global_step': global_step,
                                             'is_training': FLAGS.is_training})
        first_clone_scope = deploy_config.clone_scope(0)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        self._end_points_for_debugging = end_points
        self._add_end_point_summaries(end_points, summaries)
        # Add summaries for images, if there are any.
        self._add_image_summaries(end_points, summaries)
        # Add summaries for losses.
        self._add_loss_summaries(first_clone_scope, summaries, end_points)
        # Add summaries for variables.
        for variable in slim.get_model_variables():
          summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
          moving_average_variables = slim.get_model_variables()
          variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)
        else:
          moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by generator_network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
          learning_rate = self._configure_learning_rate(self.num_samples, global_step)
          optimizer = self._configure_optimizer(learning_rate)

        if FLAGS.sync_replicas:
          # If sync_replicas is enabled, the averaging will be done in the chief
          # queue runner.
          optimizer = tf.train.SyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_aggregate,
            total_num_replicas=FLAGS.worker_replicas,
            variable_averages=variable_averages,
            variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
          # Update ops executed locally by trainer.
          update_ops.append(variable_averages.apply(moving_average_variables))

        summaries.add(tf.summary.scalar('learning_rate', learning_rate))
        # Define optimization process.
        train_tensor = self._add_optimization(clones, optimizer, summaries, update_ops, global_step)

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Define train_step with eval every `eval_every_n_steps`.
        def train_step_fn(session, *args, **kwargs):
          # Hook: run any model-specific extra step before the standard one.
          self.do_extra_train_step(session, end_points, global_step)
          total_loss, should_stop = slim.learning.train_step(session, *args, **kwargs)
          return [total_loss, should_stop]

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
          train_tensor,
          train_step_fn=train_step_fn,
          logdir=FLAGS.train_dir,
          master=FLAGS.master,
          is_chief=(FLAGS.task == 0),
          init_fn=self._get_init_fn(FLAGS.checkpoint_path, FLAGS.checkpoint_exclude_scopes),
          summary_op=summary_op,
          number_of_steps=FLAGS.max_number_of_steps,
          log_every_n_steps=FLAGS.log_every_n_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs,
          sync_optimizer=optimizer if FLAGS.sync_replicas else None,
          session_config=session_config)
      ##########################
      # Eval, Export or Output #
      ##########################
      else:
        # Write flags file.
        if not os.path.isdir(FLAGS.eval_dir):
          util_io.touch_folder(FLAGS.eval_dir)
        if not os.path.exists(os.path.join(FLAGS.eval_dir, FLAGS_FILE_NAME)):
          FLAGS.append_flags_into_file(os.path.join(FLAGS.eval_dir, FLAGS_FILE_NAME))

        # Build a single (non-cloned) forward pass for evaluation.
        with tf.variable_scope(tf.get_variable_scope(),
                               custom_getter=model_deploy.get_custom_getter(
                                 GeneralModel._dtype_string_to_dtype(FLAGS.variable_dtype)),
                               reuse=False):
          end_points = self._clone_fn(networks, batch_queue=None, batch_names=batch_names, data_batched=data_batched,
                                      is_training=False, global_step=global_step)

        num_batches = int(math.ceil(self.num_samples / float(FLAGS.batch_size)))

        checkpoint_path = util_misc.get_latest_checkpoint_path(FLAGS.checkpoint_path)

        # Restore EMA shadow variables when moving averages were trained.
        if FLAGS.moving_average_decay:
          variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)
          variables_to_restore = variable_averages.variables_to_restore(
            slim.get_model_variables())
          variables_to_restore[global_step.op.name] = global_step
        else:
          variables_to_restore = slim.get_variables_to_restore()

        saver = None
        if variables_to_restore is not None:
          saver = tf.train.Saver(variables_to_restore)

        session_creator = tf.train.ChiefSessionCreator(
          scaffold=tf.train.Scaffold(saver=saver),
          checkpoint_filename_with_path=checkpoint_path,
          master=FLAGS.master,
          config=session_config)

        ##########
        # Output #
        ##########
        if FLAGS.do_output:
          tf.logging.info('Output mode.')
          output_ops = self._maybe_encode_output_tensor(self._define_outputs(end_points, data_batched))
          start_time = time.time()
          with tf.train.MonitoredSession(
              session_creator=session_creator) as session:
            for i in range(num_batches):
              output_results = session.run([item[-1] for item in output_ops])
              self._write_outputs(output_results, output_ops)
              # Periodically report throughput and ETA.
              if i % FLAGS.log_every_n_steps == 0:
                current_time = time.time()
                speed = (current_time - start_time) / (i + 1)
                time_left = speed * (num_batches - i + 1)
                tf.logging.info('%d / %d done. Time left: %f', i + 1, num_batches, time_left)


        ################
        # Export Model #
        ################
        elif FLAGS.do_export:
          tf.logging.info('Exporting trained model to %s', FLAGS.export_path)
          with tf.Session(config=session_config) as session:
            saver.restore(session, checkpoint_path)
            builder = tf.saved_model.builder.SavedModelBuilder(FLAGS.export_path)
            signature_def_map = self._build_signature_def_map(end_points, data_batched)
            assets_collection = self._build_assets_collection(end_points, data_batched)
            legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
            builder.add_meta_graph_and_variables(
              session, [tf.saved_model.tag_constants.SERVING],
              signature_def_map=signature_def_map,
              legacy_init_op=legacy_init_op,
              assets_collection=assets_collection,
            )
          builder.save()
          tf.logging.info('Done exporting!')

        ########
        # Eval #
        ########
        else:
          tf.logging.info('Eval mode.')
          # Add summaries for images, if there are any.
          self._add_image_summaries(end_points, None)

          # Define the metrics:
          metric_map = self._define_eval_metrics(end_points, data_batched)

          names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(metric_map)
          names_to_values = collections.OrderedDict(**names_to_values)
          names_to_updates = collections.OrderedDict(**names_to_updates)

          # Print the summaries to screen.
          for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            # Tensor-valued metrics need tensor_summary; scalars use scalar.
            if len(value.shape):
              op = tf.summary.tensor_summary(summary_name, value, collections=[])
            else:
              op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

          if not (FLAGS.do_eval_debug or FLAGS.do_custom_eval):
            tf.logging.info('Evaluating %s' % checkpoint_path)

            slim.evaluation.evaluate_once(
              master=FLAGS.master,
              checkpoint_path=checkpoint_path,
              logdir=FLAGS.eval_dir,
              num_evals=num_batches,
              eval_op=list(names_to_updates.values()),
              variables_to_restore=variables_to_restore,
              session_config=session_config)
            return

          ################################
          # `do_eval_debug` flag is true.#
          ################################
          if FLAGS.do_eval_debug:
            eval_ops = list(names_to_updates.values())
            eval_names = list(names_to_updates.keys())

            # Items to write to a html page.
            encode_ops = self._maybe_encode_output_tensor(self.get_items_to_encode(end_points, data_batched))

            with tf.train.MonitoredSession(session_creator=session_creator) as session:
              if eval_ops is not None:
                for i in range(num_batches):
                  eval_result = session.run(eval_ops, None)
                  print('; '.join(('%s:%s' % (name, str(eval_result[i])) for i, name in enumerate(eval_names))))

              # Write to HTML
              if encode_ops:
                for i in range(num_batches):
                  encode_ops_feed_dict = self._get_encode_op_feed_dict(end_points, encode_ops, i)
                  encoded_items = session.run([item[-1] for item in encode_ops], encode_ops_feed_dict)
                  encoded_list = []
                  for j in range(len(encoded_items)):
                    encoded_list.append((encode_ops[j][0], encode_ops[j][1], encoded_items[j].tolist()))

                  eval_items = self.save_images(encoded_list, os.path.join(FLAGS.eval_dir, 'images'))
                  eval_items = self.to_human_friendly(eval_items, )
                  self._write_eval_html(eval_items)
                  if i % 10 == 0:
                    tf.logging.info('%d/%d' % (i, num_batches))
          if FLAGS.do_custom_eval:
            extra_eval = self._define_extra_eval_actions(end_points, data_batched)
            with tf.train.MonitoredSession(session_creator=session_creator) as session:
              self._do_extra_eval_actions(session, extra_eval)
# Esempio n. 7  (scraper artifact — kept as a comment so the file parses)
# 0
def train():
    """Train MobileNet on the Food-101 dataset.

    Builds the input pipeline and loss, optimizes with RMSProp, logs
    summaries every 500 steps, checkpoints every 1000 steps, and stops once
    cfg.FLAGS.max_iters is reached or the input queue is exhausted.
    """
    dataset_dir = cfg.FLAGS.TFrecord_dir
    num_classes = cfg.FLAGS.num_classes
    batch_size = cfg.FLAGS.batch_size
    num_food_images = 75750  # size of the Food-101 training split

    meta_dir = 'data/food-101/meta/'
    # FIX: use a context manager so the class-list file is always closed
    # (the original leaked the file handle).
    with open(meta_dir + 'classes.txt', 'r') as fp:
        classes = [l.strip() for l in fp.readlines()]

    # Create global_step
    with tf.device('/cpu:0'):
        global_step = slim.create_global_step()

    """ load data """
    with tf.device('/cpu:0'):
        image, target_labels = datasets.get_dataset('train', dataset_dir, True, batch_size)

    labels_onehot = tf.one_hot(target_labels, depth=num_classes, on_value=1.0, off_value=0.0)

    nets = model.MobileNet(image,
                           num_classes=num_classes,
                           is_training=True,
                           weight_decay=cfg.FLAGS.weight_decay)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    logits = nets.end_points['logits']
    print(target_labels.shape, labels_onehot.shape, logits.shape)
    tf.losses.softmax_cross_entropy(onehot_labels=labels_onehot,
                                    logits=logits, weights=1.0)

    """ compute total loss """
    # total = cross-entropy + L2 regularization.
    all_losses = []
    cross_entropy = tf.get_collection(tf.GraphKeys.LOSSES)
    cross_loss = tf.add_n(cross_entropy, name='cross_loss')
    all_losses.append(cross_loss)

    regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    regularization_loss = tf.add_n(regularization_losses, name='sum_regularization_loss')
    all_losses.append(regularization_loss)

    total_loss = tf.add_n(all_losses)


    """ Configure the optimization procedure. """
    with tf.device('/cpu:0'):
        learning_rate = _get_learning_rate(num_food_images, global_step)
        optimizer = tf.train.RMSPropOptimizer(learning_rate,
                                              decay=cfg.FLAGS.rmsprop_decay,
                                              momentum=cfg.FLAGS.rmsprop_momentum,
                                              epsilon=cfg.FLAGS.opt_epsilon)

    """ Variables to train """
    train_vars = tf.trainable_variables()
    grad_op = optimizer.minimize(total_loss, global_step=global_step, var_list=train_vars)
    update_ops.append(grad_op)
    update_op = tf.group(*update_ops)

    """ estimate Accuracy """
    predictions = tf.argmax(logits, 1)
    labels = tf.squeeze(target_labels)

    # Define the streaming (running-average) metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        'Recall_5': slim.metrics.streaming_recall_at_k(
            logits, labels, 5),
    })

    # Per-batch accuracy from the softmax predictions endpoint.
    pred_cls = tf.cast(tf.argmax(nets.end_points['predictions'], axis=1), tf.int32)
    correct_prediction = tf.equal(pred_cls, target_labels)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    """ set Summary and log info """
    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    """ Add summaries for total loss """
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    """ Add summaries for accuracy. """
    # FIX: dict.iteritems() is Python 2 only and raises AttributeError on
    # Python 3 (the rest of this file is Python 3); use .items().
    for name, value in names_to_values.items():
        summary_name = 'eval/%s' % name
        op = tf.summary.scalar(summary_name, value, collections=[])
        op = tf.Print(op, [value], summary_name)
        tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

    """ Add summaries for end points """
    for i in nets.end_points:
        x = nets.end_points[i]
        summaries.add(tf.summary.histogram('activations/' + i, x))
        summaries.add(tf.summary.scalar('sparsity/' + i,
                                        tf.nn.zero_fraction(x)))

    """ Add summaries for losses. """
    for loss in tf.get_collection(tf.GraphKeys.LOSSES):
        summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    tf.summary.scalar('regularization_loss', regularization_loss)

    """ Add summaries for variables. """
    for variable in slim.get_model_variables():
        summaries.add(tf.summary.histogram(variable.op.name, variable))

    summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    summary_op = tf.summary.merge(list(summaries), name='summary_op')
    logdir = os.path.join(cfg.FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    # FIX: `tf.Session().graph` created and leaked a throwaway session just to
    # reach the default graph; use the default graph directly.
    summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())


    """ create saver and initialize variables """
    saver = tf.train.Saver(max_to_keep=20)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # FIX: the original printed 'Evaluating...' in the training loop.
            print('Training...')
            while not coord.should_stop():
                current_step = sess.run(global_step)
                _, losses, acc, gt_labels, pred_labels = sess.run(
                    [update_op, total_loss, accuracy, target_labels, pred_cls])
                print(""" iter %d: total_loss %.4f, accuracy %.4f """ % (current_step, losses, acc))

                """ write summary """
                if current_step % 500 == 0:
                    summary = sess.run(summary_op)
                    summary_writer.add_summary(summary, current_step)
                    print('gt_lables', gt_labels)
                    print('pred_labels', pred_labels)


                """ save trained model datas """
                if current_step % 1000 == 0:
                    saver.save(sess, cfg.FLAGS.training_model, global_step=current_step)

                # FIX: the original only printed when max_iters was reached and
                # then kept training forever; actually stop the loop (>= also
                # guards against global_step skipping past the exact value).
                if current_step >= cfg.FLAGS.max_iters:
                    print('step is reached the maximum iteration')
                    print('Done training!!!!!')
                    break
        except tf.errors.OutOfRangeError:
            print('Error is occured and stop the training')
        finally:
            # Always checkpoint final weights and shut the queue runners down.
            saver.save(sess, cfg.FLAGS.checkpoint_model, write_meta_graph=False)
            coord.request_stop()

        coord.join(threads)

    print('Done.')
# Esempio n. 8  (scraper artifact — kept as a comment so the file parses)
# 0
def train():
    """Train the tclstm map-net + LSTM classifier.

    Builds the graph on the CPU, restores the latest checkpoint from
    ``tcopts['lstm_train_dir']`` when one exists, then runs the training
    loop with periodic logging, checkpointing and test-set evaluation.
    Whenever both the positive and negative test accuracies clear their
    save thresholds and their mean improves on the best seen so far, an
    extra checkpoint is saved.  The loop stops after ``training_steps``
    global steps.
    """
    model = tclstm()
    training_steps = 50000
    display_step = tcopts['display_step']

    with tf.device('cpu:0'):
        global_step = slim.create_global_step()
        # tf Graph input: response maps plus the LSTM feature sequence.
        maps = tf.placeholder("float", [None, 19, 19, 1])
        map_logits = model.map_net(maps)
        X = tf.placeholder(
            "float", [None, tcopts['time_steps'], tcopts['lstm_num_input']])
        Inputs = tf.concat((X, map_logits), axis=2)
        Y = tf.placeholder("float", [None, tcopts['lstm_num_classes']])
        # Staircase exponential learning-rate decay driven by global_step.
        lrOp = tf.train.exponential_decay(tcopts['lstm_initial_lr'],
                                          global_step,
                                          tcopts['lstm_decay_steps'],
                                          tcopts['lr_decay_factor'],
                                          staircase=True)
        logits, _ = model.net(Inputs)
        # optimizer = tf.train.GradientDescentOptimizer(learning_rate=lrOp)
        optimizer = tf.train.MomentumOptimizer(learning_rate=lrOp,
                                               momentum=0.9)
        loss_op = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
        prediction = tf.nn.softmax(logits)
        grads = optimizer.compute_gradients(loss_op)
        train_op = optimizer.apply_gradients(grads, global_step=global_step)
        # Evaluate model (with test logits, for dropout to be disabled)
        correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        test_accuracy = accuracy

        # Initialize the variables (i.e. assign their default value)
        init = tf.global_variables_initializer()

        pos_data, neg_data = load_training_data('pos_data.npy', 'neg_data.npy')
        test_pos_data, test_neg_data = load_training_data(
            'test_pos_data.npy', 'test_neg_data.npy')
        test_pos_num, test_neg_num, test_batch_X_input, test_batch_map_input, test_labels = get_test_batch_input(
            test_pos_data, test_neg_data)

        # Batches are laid out positives-first, so per-class accuracy is
        # obtained by slicing correct_pred at the positive/negative boundary.
        test_accuracy_pos = tf.reduce_mean(
            tf.cast(correct_pred[:test_pos_num], tf.float32))
        test_accuracy_neg = tf.reduce_mean(
            tf.cast(correct_pred[test_pos_num:], tf.float32))
        accuracy_pos = tf.reduce_mean(
            tf.cast(correct_pred[:tcopts['batch_size']], tf.float32))
        accuracy_neg = tf.reduce_mean(
            tf.cast(correct_pred[tcopts['batch_size']:], tf.float32))

        saver = tf.train.Saver(keep_checkpoint_every_n_hours=tcopts[
            'keep_checkpoint_every_n_hours'])
        # Training summaries.
        tf.summary.scalar('learning_rate', lrOp)
        tf.summary.scalar('loss', loss_op)
        tf.summary.scalar('training_accuracy', accuracy)
        tf.summary.scalar('training_accuracy_pos', accuracy_pos)
        tf.summary.scalar('training_accuracy_neg', accuracy_neg)
        # Gradient histograms.
        for grad, var in grads:
            if grad is not None:
                tf.summary.histogram(var.op.name + '/gradients', grad)
        # Trainable-variable histograms.
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)

        # Merge the training summaries *before* adding the testing scalars,
        # so summary_op excludes them and they only appear in
        # test_merge_summary below.
        summary_op = tf.summary.merge_all()
        tf.summary.scalar('testing_accuracy_pos', test_accuracy_pos)
        tf.summary.scalar('testing_accuracy_neg', test_accuracy_neg)
        tf.summary.scalar('testing_accuracy', test_accuracy)
        test_merge_summary = tf.summary.merge(
            [tf.get_collection(tf.GraphKeys.SUMMARIES, 'testing_accuracy')])

        summary_writer = tf.summary.FileWriter(tcopts['lstm_train_dir'],
                                               graph=tf.get_default_graph())
        tfconfig = tf.ConfigProto()
        # tfconfig.gpu_options.per_process_gpu_memory_fraction = 1
        # Start training
        with tf.Session(config=tfconfig) as sess:

            # Run the initializer
            sess.run(init)
            # Resume from the latest checkpoint when one exists.
            checkpoint = tf.train.latest_checkpoint(tcopts['lstm_train_dir'])
            if checkpoint is not None:
                saver.restore(sess, checkpoint)
            best_model = (tcopts['save_neg_thr'] + tcopts['save_pos_thr']) / 2
            while True:
                batch_X_input, batch_map_input, labels = get_batch_input(
                    pos_data, neg_data, tcopts['batch_size'])
                # Run one optimization step (backprop), feeding only the
                # selected LSTM input channels.
                _, g_step = sess.run(
                    [train_op, global_step],
                    feed_dict={
                        X: batch_X_input[:, :, tcopts['lstm_input']],
                        Y: labels,
                        maps: batch_map_input
                    })
                if g_step % display_step == 0:
                    # Calculate batch loss and accuracy
                    loss, acc, acc_pos, acc_neg, summary_str = sess.run(
                        [
                            loss_op, accuracy, accuracy_pos, accuracy_neg,
                            summary_op
                        ],
                        feed_dict={
                            X: batch_X_input[:, :, tcopts['lstm_input']],
                            Y: labels,
                            maps: batch_map_input
                        })
                    summary_writer.add_summary(summary_str, g_step)
                    print("Step " + str(g_step) + ", Minibatch Loss= " + \
                          "{:.4f}".format(loss) + ", Training Accuracy= " + \
                          "{:.3f}".format(acc) + ", Accuracy pos= " + "{:.3f}".format(
                        acc_pos) + ", Accuracy neg= " + "{:.3f}".format(acc_neg))

                if g_step % tcopts['model_save_interval'] == 0:
                    checkpoint_path = os.path.join(tcopts['lstm_train_dir'],
                                                   'lstm_model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=g_step)
                    print('Save model, global step: %d' % g_step)
                if g_step % tcopts['eval_interval'] == 0:
                    test_pos_num, test_neg_num, test_batch_X_input, test_batch_map_input, test_labels = get_test_batch_input(
                        test_pos_data, test_neg_data)
                    test_acc, test_acc_pos, test_acc_neg, test_summary_str = sess.run(
                        [
                            test_accuracy, test_accuracy_pos,
                            test_accuracy_neg, test_merge_summary
                        ],
                        feed_dict={
                            X: test_batch_X_input[:, :, tcopts['lstm_input']],
                            Y: test_labels,
                            maps: test_batch_map_input
                        })
                    summary_writer.add_summary(test_summary_str, g_step)
                    print("test accuracy:" + "{:.4f}".format(test_acc) +
                          "  test accuracy pos:" +
                          "{:.4f}".format(test_acc_pos) +
                          "  test accuracy neg:" +
                          "{:.4f}".format(test_acc_neg))
                    # Save the best model whenever both class accuracies clear
                    # their thresholds and the mean improves.
                    if test_acc_pos > tcopts[
                            'save_pos_thr'] and test_acc_neg > tcopts[
                                'save_neg_thr'] and (
                                    test_acc_pos +
                                    test_acc_neg) / 2 > best_model:
                        best_model = (test_acc_pos + test_acc_neg) / 2
                        checkpoint_path = os.path.join(
                            tcopts['lstm_train_dir'], 'lstm_model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=g_step)
                        print('Save model, global step: %d' % g_step)

                if g_step >= training_steps:
                    # BUGFIX: `training_steps` was previously assigned but
                    # never used, so the loop ran forever; stop once the
                    # step budget is reached.
                    break
Esempio n. 9
0
def main(_):
    """Build and train the TextBoxes++ detector.

    Loads the dataset from ``FLAGS.dataset_dir``, builds the network and its
    anchors (optionally switching to the 768x768 stage-2 input size),
    assembles the pre-processing/batching pipeline, configures the optimizer,
    moving averages and summaries, then hands off to ``slim.learning.train``.

    Raises:
        ValueError: if ``--dataset_dir`` is missing, or if
            ``--training_image_crop_area`` is not two comma-separated values.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')
    # Sets the threshold for what messages will be logged. (DEBUG / INFO / WARN / ERROR / FATAL)
    tf.logging.set_verbosity(tf.logging.DEBUG)

    with tf.Graph().as_default():
        # Config model_deploy. Keep TF Slim Models structure.
        # Useful if we want to use multiple GPUs and/or servers in the future.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0)

        # Create global_step, the training iteration counter.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # Select the dataset.
        dataset = TFrecords2Dataset.get_datasets(FLAGS.dataset_dir)

        # Get the TextBoxes++ network and its anchors.
        text_net = txtbox_384.TextboxNet()

        # Stage 2 training using the 768x768 input size.
        if FLAGS.large_training:
            # Replace the input image shape and the feature map size
            # extracted from each layer associated with a textbox layer.
            text_net.params = text_net.params._replace(img_shape=(768, 768))
            text_net.params = text_net.params._replace(
                feat_shapes=[(96, 96), (48, 48), (24, 24), (12, 12), (10, 10),
                             (8, 8)])

        img_shape = text_net.params.img_shape
        print('img_shape: ' + str(img_shape))

        # Compute the default anchor boxes with the given image shape, get anchor list.
        text_anchors = text_net.anchors(img_shape)

        # Print the training configuration before training.
        tf_utils.print_configuration(FLAGS.__flags, text_net.params,
                                     dataset.data_sources, FLAGS.train_dir)

        # =================================================================== #
        # Create a dataset provider and batches.
        # =================================================================== #
        with tf.device(deploy_config.inputs_device()):
            # setting the dataset provider
            with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
                provider = slim.dataset_data_provider.DatasetDataProvider(
                    dataset,
                    num_readers=FLAGS.num_readers,
                    common_queue_capacity=1000 * FLAGS.batch_size,
                    common_queue_min=300 * FLAGS.batch_size,
                    shuffle=True)
            # Get for SSD network: image, labels, bboxes, plus the four
            # corner coordinates of each oriented box.
            [image, shape, glabels, gbboxes, x1, x2, x3, x4, y1, y2, y3,
             y4] = provider.get([
                 'image', 'shape', 'object/label', 'object/bbox',
                 'object/oriented_bbox/x1', 'object/oriented_bbox/x2',
                 'object/oriented_bbox/x3', 'object/oriented_bbox/x4',
                 'object/oriented_bbox/y1', 'object/oriented_bbox/y2',
                 'object/oriented_bbox/y3', 'object/oriented_bbox/y4'
             ])
            gxs = tf.transpose(tf.stack([x1, x2, x3, x4]))  # shape = (N, 4)
            gys = tf.transpose(tf.stack([y1, y2, y3, y4]))
            image = tf.identity(image, 'input_image')
            # NOTE: a premature, unused `tf.global_variables_initializer()`
            # was removed here; `slim.learning.train` handles initialization.

            # Pre-processing image, labels and bboxes.
            # Parse "--training_image_crop_area" as "min,max" floats.
            training_image_crop_area = FLAGS.training_image_crop_area
            area_split = training_image_crop_area.split(',')
            # BUGFIX: validate with an explicit exception instead of
            # `assert`, which is silently stripped under `python -O`.
            if len(area_split) != 2:
                raise ValueError(
                    '--training_image_crop_area must be two comma-separated '
                    'values, got %r' % training_image_crop_area)
            training_image_crop_area = [
                float(area_split[0]),
                float(area_split[1])
            ]

            image, glabels, gbboxes, gxs, gys= \
                ssd_vgg_preprocessing.preprocess_for_train(image, glabels, gbboxes, gxs, gys,
                                                        img_shape,
                                                        data_format='NHWC', crop_area_range=training_image_crop_area)

            # Encode groundtruth labels and bboxes.
            image = tf.identity(image, 'processed_image')

            glocalisations, gscores, glabels = \
                text_net.bboxes_encode( glabels, gbboxes, text_anchors, gxs, gys)
            batch_shape = [1] + [len(text_anchors)] * 3

            # Training batches and queue.
            r = tf.train.batch(tf_utils.reshape_list(
                [image, glocalisations, gscores, glabels]),
                               batch_size=FLAGS.batch_size,
                               num_threads=FLAGS.num_preprocessing_threads,
                               capacity=5 * FLAGS.batch_size)

            b_image, b_glocalisations, b_gscores, b_glabels= \
                tf_utils.reshape_list(r, batch_shape)

            # Intermediate queueing: unique batch computation pipeline for all
            # GPUs running the training.
            batch_queue = slim.prefetch_queue.prefetch_queue(
                tf_utils.reshape_list(
                    [b_image, b_glocalisations, b_gscores, b_glabels]),
                capacity=2 * deploy_config.num_clones)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        def clone_fn(batch_queue):
            """Build one clone: dequeue a batch, run the network and attach
            its losses.  Allows data parallelism across clones."""
            b_image, b_glocalisations, b_gscores, b_glabels = \
                tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)

            # Construct TextBoxes network.
            arg_scope = text_net.arg_scope(weight_decay=FLAGS.weight_decay)
            with slim.arg_scope(arg_scope):
                predictions,localisations, logits, end_points = \
                    text_net.net(b_image, is_training=True)
            # Add loss function.
            text_net.losses(logits,
                            localisations,
                            b_glabels,
                            b_glocalisations,
                            b_gscores,
                            match_threshold=FLAGS.match_threshold,
                            negative_ratio=FLAGS.negative_ratio,
                            alpha=FLAGS.loss_alpha,
                            label_smoothing=FLAGS.label_smoothing,
                            batch_size=FLAGS.batch_size)
            return end_points

        # Gather initial tensorboard summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # =================================================================== #
        # Add summaries from first clone.
        # =================================================================== #
        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))
        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES):
            summaries.add(tf.summary.scalar(loss.op.name, loss))
        # Add summaries for extra losses.
        for loss in tf.get_collection('EXTRA_LOSSES'):
            summaries.add(tf.summary.scalar(loss.op.name, loss))
        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # =================================================================== #
        # Configure the moving averages.
        # =================================================================== #
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # =================================================================== #
        # Configure the optimization procedure.
        # =================================================================== #
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, dataset.num_samples, global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            # Add summaries for learning_rate.
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # Compute the per-clone losses/gradients; returns the total loss and
        # the averaged gradients to apply.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)

        config = tf.ConfigProto(log_device_placement=False,
                                allow_soft_placement=True,
                                gpu_options=gpu_options)

        saver = tf.train.Saver(max_to_keep=100,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)

        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master='',
            is_chief=True,
            init_fn=tf_utils.get_init_fn(FLAGS),
            summary_op=summary_op,  ##output variables to logdir
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            saver=saver,
            save_interval_secs=FLAGS.save_interval_secs,
            session_config=config,
            sync_optimizer=None)
Esempio n. 10
0
def train():
    """Run the main FPN detection/mask training loop.

    Builds the model via ``build_model``, restores weights, starts the input
    queue-runner threads, and iterates ``FLAGS.max_iters`` steps: every step
    prints losses and batch statistics, every 100 steps writes summaries,
    and every 3000 steps (plus the final step) saves a checkpoint under a
    timestamped log directory.

    Raises:
        ValueError: if the total loss becomes NaN or Inf (divergence).
    """
    ret = build_model()
    outputs, gt_masks, total_loss, regular_loss, img_id, losses, gt_boxes, batch_info, input_image, \
            final_box, final_cls, final_prob, final_gt_cls, gt, tmp_0, tmp_1, tmp_2, tmp_3, tmp_4 = ret

    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95,
                                allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    # Timestamped log directory so each run keeps its own events/checkpoints.
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore pretrained / previous weights
    restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    threads = []
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(
            qr.create_threads(sess, coord=coord, daemon=True, start=True))
    # BUGFIX: the runners are already started above; the original additional
    # call to tf.train.start_queue_runners() spawned a duplicate set of
    # threads for the same queue runners and has been removed.
    saver = tf.train.Saver(max_to_keep=20)

    for step in range(FLAGS.max_iters):

        start_time = time.time()

        s_, tot_loss, reg_lossnp, img_id_str, \
        rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss, \
        gt_boxesnp, \
        rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch, mask_batch_pos, mask_batch, \
        input_imagenp, final_boxnp, final_clsnp, final_probnp, final_gt_clsnp, gtnp, tmp_0np, tmp_1np, tmp_2np, tmp_3np, tmp_4np= \
                     sess.run([update_op, total_loss, regular_loss, img_id] +
                              losses +
                              [gt_boxes] +
                              batch_info +
                              [input_image] + [final_box] + [final_cls] + [final_prob] + [final_gt_cls] + [gt] + [tmp_0] + [tmp_1] + [tmp_2] + [tmp_3] + [tmp_4])

        duration_time = time.time() - start_time
        if step % 1 == 0:
            print(
                """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.6f, """
                """total-loss %.4f(%.4f, %.4f, %.6f, %.4f, %.4f), """
                """instances: %d, """
                """batch:(%d|%d, %d|%d, %d|%d)""" %
                (step, img_id_str, duration_time, reg_lossnp, tot_loss,
                 rpn_box_loss, rpn_cls_loss, refined_box_loss,
                 refined_cls_loss, mask_loss, gt_boxesnp.shape[0],
                 rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch,
                 mask_batch_pos, mask_batch))

            print("labels")
            print(
                cat_id_to_cls_name(
                    np.unique(np.argmax(np.asarray(tmp_3np), axis=1)))[1:])
            print("classes")
            print(
                cat_id_to_cls_name(
                    np.unique(np.argmax(np.array(tmp_4np), axis=1))))

            if np.isnan(tot_loss) or np.isinf(tot_loss):
                print(gt_boxesnp)
                # BUGFIX: a bare `raise` with no active exception produces a
                # confusing "No active exception to re-raise" RuntimeError;
                # raise an explicit, descriptive error instead.
                raise ValueError('total loss is NaN or Inf, aborting training')

        if step % 100 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()

        if (step % 3000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(
                logdir,
                FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
Esempio n. 11
0
def train(
    create_input_dict_fn,
    create_model_fn,
    train_config,
    train_dir,
    task,
    num_clones,
    worker_replicas,
    clone_on_cpu,
    ps_tasks,
    worker_job_name,
    is_chief,
):
    """Build the multi-clone detection training graph and run slim training.

    Args:
        create_input_dict_fn: callable returning the input tensor dict,
            forwarded to the dataset reader.
        create_model_fn: callable returning a fresh detection model.
        train_config: dict of training options; keys used here are
            "batch_size", "optimizer", "add_regularization_loss",
            "keep_checkpoint_every_n_hours" and "num_steps".
        train_dir: directory for checkpoints and summaries.
        task: replica id of this worker (forwarded to DeploymentConfig).
        num_clones: number of model clones per worker.
        worker_replicas: total number of worker replicas.
        clone_on_cpu: whether to place clones on the CPU.
        ps_tasks: number of parameter-server tasks.
        worker_job_name: name of the worker job.
        is_chief: whether this process is the chief replica.
    """

    # Create the model instance up front; it is used again later when
    # loading the pretrained model.
    detection_model = create_model_fn()

    with tf.Graph().as_default():
        # Deployment configuration (multi-GPU / multi-worker placement).
        deploy_config = Deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # The configured batch is split evenly across the clones.
        batch_size = train_config["batch_size"] // num_clones

        with tf.device(deploy_config.inputs_device()):
            # Read examples from tfrecords, assemble batches, and build the
            # sample input queue.
            input_queue = get_inputs.read_and_transform_dataset(
                per_clone_batch_size=batch_size,
                create_tensor_dict_fn=create_input_dict_fn)
        # Forward pass: builds the per-clone losses.
        model_fn = functools.partial(_create_losses,
                                     model_fn=create_model_fn,
                                     train_config=train_config)

        clones = Deploy.create_clones(deploy_config, model_fn, input_queue)

        first_clone_scope = clones[0].scope

        # Update ops (e.g. batch-norm moving averages) from the first clone.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        with tf.device(deploy_config.optimizer_device()):

            # Build the optimizer; returns the optimizer instance and the
            # variables (e.g. learning rate) to summarize.
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config["optimizer"])
            for var in optimizer_summary_vars:
                global_summaries.add(
                    tf.summary.scalar(var.op.name, var, family='LearningRate'))

        with tf.device(deploy_config.optimizer_device()):
            regularization_losses = None if train_config[
                "add_regularization_loss"] else []
            # Gradients averaged over all clones.
            total_loss, grads_and_vars = Deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=regularization_losses)
            # Fail fast if the loss becomes NaN or Inf.
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')
            # Back-propagation: apply the gradients.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Histogram summaries for all model variables.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram('ModelVars/' + model_var.op.name,
                                     model_var))
        # Scalar summaries for every registered loss plus the total loss.
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                  loss_tensor))
        global_summaries.add(
            tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        keep_checkpoint_every_n_hours = train_config[
            "keep_checkpoint_every_n_hours"]
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
        # num_steps == 0/None means train indefinitely.
        if train_config["num_steps"]:
            step = train_config["num_steps"]
        else:
            step = None
        # Start training.
        slim.learning.train(train_tensor,
                            logdir=train_dir,
                            is_chief=is_chief,
                            session_config=session_config,
                            startup_delay_steps=0,
                            summary_op=summary_op,
                            number_of_steps=step,
                            save_summaries_secs=120,
                            saver=saver)
Esempio n. 12
0
def test():
    """Run inference over a dataset split and collect Mask R-CNN detections.

    Builds the input pipeline and the backbone + pyramid (FPN) heads in
    inference mode, restores weights, then loops over the dataset running
    the mask heads, periodically drawing predicted and ground-truth boxes,
    and accumulating per-image results via _collectData.

    Relies on module-level helpers/FLAGS: datasets, network,
    pyramid_network, restore, draw_bbox, cat_id_to_cls_name, _collectData.
    """

    ## data
    # Dataset tensors: image batch plus per-image geometry, GT annotations
    # and the image id (GT is only used here for visual comparison).
    image, original_image_height, original_image_width, image_height, image_width, gt_boxes, gt_masks, num_instances, image_id = \
        datasets.get_dataset(FLAGS.dataset_name,
                             FLAGS.dataset_split_name_test,
                             FLAGS.dataset_dir,
                             FLAGS.im_batch,
                             is_training=False)

    # Force a static channel dimension of 3 while keeping N/H/W dynamic.
    im_shape = tf.shape(image)
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network
    # NOTE(review): is_training=True on the backbone at test time looks
    # suspicious (batch-norm runs in training mode) while the pyramid below
    # uses is_training=False -- confirm this mismatch is intended.
    logits, end_points, pyramid_map = network.get_network(FLAGS.network,
                                                          image,
                                                          weight_decay=0.0,
                                                          batch_norm_decay=0.0,
                                                          is_training=True)
    # Inference-only pyramid build: no GT, all loss weights zeroed.
    outputs = pyramid_network.build(end_points,
                                    im_shape[1],
                                    im_shape[2],
                                    pyramid_map,
                                    num_classes=81,
                                    base_anchors=3,
                                    is_training=False,
                                    gt_boxes=None,
                                    gt_masks=None,
                                    loss_weights=[0.0, 0.0, 0.0, 0.0, 0.0])

    input_image = end_points['input']

    # Final mask-head outputs: ROIs, binary masks, class ids and scores.
    testing_mask_rois = outputs['mask_ordered_rois']
    testing_mask_final_mask = outputs['mask_final_mask']
    testing_mask_final_clses = outputs['mask_final_clses']
    testing_mask_final_scores = outputs['mask_final_scores']

    ## solvers
    global_step = slim.create_global_step()

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    # init_op = tf.group(
    #         tf.global_variables_initializer(),
    #         tf.local_variables_initializer()
    #         )
    # sess.run(init_op)

    # summary_op = tf.summary.merge_all()
    # Timestamped log directory under the train dir.
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore
    # restore() initializes/loads the model weights (defined elsewhere).
    restore(sess)
    tf.train.start_queue_runners(sess=sess)

    ## main loop
    # for step in range(FLAGS.max_iters):
    # Hard-coded number of test images (COCO train2014 size = 82783).
    for step in range(82783):  #range(40503):

        start_time = time.time()

        # One fetch per tensor; the unpack order on the left must match the
        # concatenated fetch list on the right exactly.
        image_id_str, original_image_heightnp, original_image_widthnp, image_heightnp, image_widthnp, \
        gt_boxesnp, gt_masksnp,\
        input_imagenp,\
        testing_mask_roisnp, testing_mask_final_masknp, testing_mask_final_clsesnp, testing_mask_final_scoresnp = \
                     sess.run([image_id] + [original_image_height] + [original_image_width] + [image_height] + [image_width] +\
                              [gt_boxes] + [gt_masks] +\
                              [input_image] + \
                              [testing_mask_rois] + [testing_mask_final_mask] + [testing_mask_final_clses] + [testing_mask_final_scores])

        duration_time = time.time() - start_time
        if step % 1 == 0:
            print("""iter %d: image-id:%07d, time:%.3f(sec), """
                  """instances: %d, """ %
                  (step, image_id_str, duration_time, gt_boxesnp.shape[0]))

        if step % 1 == 0:
            # Input images were normalized to [-1, 1]; map back to uint8 RGB
            # before drawing.
            draw_bbox(step,
                      np.uint8(
                          (np.array(input_imagenp[0]) / 2.0 + 0.5) * 255.0),
                      name='test_est',
                      bbox=testing_mask_roisnp,
                      label=testing_mask_final_clsesnp,
                      prob=testing_mask_final_scoresnp,
                      mask=testing_mask_final_masknp,
                      vis_th=0.5)

            # Ground-truth overlay for side-by-side comparison; dummy
            # all-ones probabilities since GT has no scores.
            draw_bbox(
                step,
                np.uint8((np.array(input_imagenp[0]) / 2.0 + 0.5) * 255.0),
                name='test_gt',
                bbox=gt_boxesnp[:, 0:4],
                label=gt_boxesnp[:, 4].astype(np.int32),
                prob=np.ones((gt_boxesnp.shape[0], 81), dtype=np.float32),
            )

            print("predict")
            # LOG (cat_id_to_cls_name(np.unique(np.argmax(np.array(training_rcnn_clsesnp),axis=1))))
            print(cat_id_to_cls_name(testing_mask_final_clsesnp))
            print(np.max(np.array(testing_mask_final_scoresnp), axis=1))

        # Accumulate this image's detections for later export/evaluation.
        _collectData(image_id_str, testing_mask_final_clsesnp,
                     testing_mask_roisnp, testing_mask_final_scoresnp,
                     original_image_heightnp, original_image_widthnp,
                     image_heightnp, image_widthnp, testing_mask_final_masknp)
Esempio n. 13
0
def train():
    """Multi-GPU synchronous training loop.

    Splits each input batch across FLAGS.num_gpus tower replicas, averages
    the per-tower gradients on the CPU, maintains exponential moving
    averages of the trainable variables, and periodically logs epoch
    statistics, writes summaries, and saves a final checkpoint.

    Relies on module-level helpers/FLAGS: utils, network,
    _build_training_graph.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        assert FLAGS.batch_size % FLAGS.num_gpus == 0, (
            'Batch size must be divisible by number of GPUs')

        bs_l = FLAGS.batch_size
        num_iter_per_epoch = int(FLAGS.num_train_l / bs_l)
        max_steps = int(FLAGS.num_epochs * num_iter_per_epoch)
        num_classes = FLAGS.num_classes

        global_step = slim.create_global_step()
        # The learning rate is fed each step so it can be decayed in Python.
        lr = tf.placeholder(tf.float32, shape=[], name="learning_rate")
        opt = tf.train.MomentumOptimizer(learning_rate=lr,
                                         momentum=0.9,
                                         use_nesterov=True)

        images, labels = utils.prepare_traindata(FLAGS.dataset_dir_l,
                                                 int(bs_l))
        # One sub-batch per GPU tower.
        images_splits = tf.split(images, FLAGS.num_gpus, 0)
        labels_splits = tf.split(labels, FLAGS.num_gpus, 0)

        tower_grads = []
        top_1_op = []
        reuse_variables = None
        for i in range(FLAGS.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (network.TOWER_NAME, i)) as scope:
                    # Pin the (shared) model variables to the CPU so all
                    # towers read/write a single copy.  The original passed
                    # slim.get_model_variables(scope=scope) here, which is a
                    # list of Variables rather than of @add_arg_scope
                    # functions, making the placement a silent no-op.
                    with slim.arg_scope([slim.model_variable, slim.variable],
                                        device='/cpu:0'):
                        loss, logits = \
                            _build_training_graph(images_splits[i], labels_splits[i], num_classes, reuse_variables)
                        top_1_op.append(
                            tf.nn.in_top_k(logits, labels_splits[i], 1))

                    # Reuse variables for every tower after the first.
                    reuse_variables = True
                    # Summaries and batch-norm updates are taken from the
                    # last tower only (they are equivalent across towers).
                    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                  scope)
                    batchnorm_updates = tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS, scope)
                    grads = opt.compute_gradients(loss)
                    tower_grads.append(grads)

        # Average gradients across towers and apply them once.
        grads = network.average_gradients(tower_grads)
        gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Track EMAs of both trainable and moving-average variables.
        var_averages = tf.train.ExponentialMovingAverage(
            FLAGS.ema_decay, global_step)
        var_op = var_averages.apply(tf.trainable_variables() +
                                    tf.moving_average_variables())

        batchnorm_op = tf.group(*batchnorm_updates)
        train_op = tf.group(gradient_op, var_op, batchnorm_op)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)
        summary_op = tf.summary.merge(summaries)
        init_op = tf.global_variables_initializer()

        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        if FLAGS.gpu_memory:
            config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory
        sess = tf.Session(config=config)

        # Piecewise-constant learning-rate schedule.
        boundaries, values = utils.config_lr(max_steps)
        sess.run([init_op], feed_dict={lr: values[0]})

        tf.train.start_queue_runners(sess=sess)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)

        iter_count = epoch = sum_loss = sum_top_1 = 0
        start = time.time()

        for step in range(max_steps):

            decayed_lr = utils.decay_lr(step, boundaries, values, max_steps)
            _, loss_value, top_1_value = \
                sess.run([train_op, loss, top_1_op], feed_dict={lr: decayed_lr})

            sum_loss += loss_value
            # Fraction of the full batch predicted correctly at top-1.
            top_1_value = np.sum(top_1_value) / bs_l
            sum_top_1 += top_1_value
            iter_count += 1

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            # Log once per completed epoch.  The original triggered on
            # `step % num_iter_per_epoch == 0`, which fired at step 0 and
            # divided a single iteration's sums by num_iter_per_epoch.
            if (step + 1) % num_iter_per_epoch == 0:
                end = time.time()
                sum_loss = sum_loss / num_iter_per_epoch
                sum_top_1 = min(sum_top_1 / num_iter_per_epoch, 1.0)
                time_per_iter = float(end - start) / iter_count
                format_str = (
                    'epoch %d, L = %.2f, top_1 = %.2f, lr = %.4f (time_per_iter: %.4f s)'
                )
                print(format_str % (epoch, sum_loss, sum_top_1 * 100,
                                    decayed_lr, time_per_iter))
                epoch += 1
                sum_loss = sum_top_1 = 0

            if step % 100 == 0:
                summary_str = sess.run(summary_op, feed_dict={lr: decayed_lr})
                summary_writer.add_summary(summary_str, step)

            # Save a single checkpoint at the very end of training.
            if (step + 1) == max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=epoch)
Esempio n. 14
0
def train(train_dataset, model, config, lr, train_layers, epochs):
    """Train a Mask R-CNN style model with a hand-assembled batch loop.

    Args:
        train_dataset: dataset object exposing `image_ids` (used for epoch
            shuffling); images/GT are loaded via coco_train.load_image_gt.
        model: model object exposing the input placeholders and `log_dir`.
        config: configuration object (BATCH_SIZE, MAX_GT_INSTANCES,
            USE_MINI_MASK, LEARNING_MOMENTUM, ...).
        lr: initial learning rate (float), rescheduled during training.
        train_layers: spec passed to set_trainable() selecting which
            variables receive gradients.
        epochs: number of passes over the dataset before stopping.

    Relies on module-level names: coco_train, gen_cocodb, utils, anchors,
    set_trainable.
    """
    # ----- Solver for the losses -----
    global_step = slim.create_global_step()
    learning_rate = tf.placeholder(dtype=tf.float32, shape=(), name='learning_rate')

    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=config.LEARNING_MOMENTUM, name='Momentum')

    # The five model losses (rpn cls/bbox, rcnn cls/bbox, mask) plus L2
    # regularization terms; their order in `losses` is relied on below.
    losses = tf.get_collection(tf.GraphKeys.LOSSES)
    model_loss = tf.add_n(losses)
    regular_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    total_loss = model_loss + regular_loss

    # ----- Update operations for training -----
    update_ops = []
    variables_to_train = set_trainable(train_layers)
    update_opt = optimizer.minimize(total_loss, global_step=global_step, var_list=variables_to_train)
    update_ops.append(update_opt)

    # Batch-norm moving-average updates must run alongside the train op.
    update_bns = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if len(update_bns):
        update_bn = tf.group(*update_bns)
        update_ops.append(update_bn)
    update_op = tf.group(*update_ops)

    # ----- Summary and log info -----
    tf.summary.scalar('total_loss', total_loss)
    tf.summary.scalar('model_loss', model_loss)
    tf.summary.scalar('regular_loss', regular_loss)
    tf.summary.scalar('learning_rate', learning_rate)

    summary_op = tf.summary.merge_all()
    # The original used tf.Session().graph here, creating and leaking a
    # throwaway Session only to read the default graph.
    summary_writer = tf.summary.FileWriter(model.log_dir, graph=tf.get_default_graph())

    # ----- Saver for the final model (backbone restore kept for reference) -----
    # variables_to_restore = _get_restore_vars('FeatureExtractor/MobilenetV1')
    # re_saver = tf.train.Saver(var_list=variables_to_restore)

    saver = tf.train.Saver(max_to_keep=3)
    # ----- GPU environment -----
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # ----- Start training -----
    gpu_opt = tf.GPUOptions(per_process_gpu_memory_fraction=0.9, allow_growth=True)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_opt)) as sess:
        sess.run(init_op)
        # re_saver.restore(sess, 'data/pretrained_models/mobilenet_v1_coco/model.ckpt')
        ckpt = tf.train.get_checkpoint_state("output/training")
        # Resume from the latest checkpoint if one exists.
        if ckpt:
            lastest_ckpt = tf.train.latest_checkpoint("output/training")
            print('lastest', lastest_ckpt)
            saver.restore(sess, lastest_ckpt)

        b=0 # batch index
        num_epoch = 0
        batch_size = config.BATCH_SIZE
        image_index = -1
        image_ids = np.copy(train_dataset.image_ids)
        # NOTE(review): despite the name this is a *step* count
        # (images-per-epoch * epochs), used as the LR decay interval below.
        num_epochs_per_decay = len(image_ids)*epochs
        print('num_epochs_per_decay : ', num_epochs_per_decay)
        print("============ Start for ===================")
        try:
            while True:
                image_index = (image_index + 1) % len(image_ids)
                # shuffle images if at the start of an epoch.
                if image_index == 0:
                    np.random.shuffle(image_ids)
                    num_epoch +=1

                # Get gt_boxes and gt_masks for image.
                image_id = image_ids[image_index]
                image, image_meta, gt_class_ids, gt_boxes, gt_masks = coco_train.load_image_gt(coco_train,
                                                                                               config, image_id,
                                                                                               augment=True,
                                                                                               use_mini_mask=config.USE_MINI_MASK)

                # Skip images that have no instances. This can happen in cases
                # where we train on a subset of classes and the image doesn't
                # have any of the classes we care about.
                if not np.any(gt_class_ids > 0):
                    continue

                # RPN Targets
                rpn_match, rpn_bbox = utils.build_rpn_targets(image.shape, anchors, gt_class_ids, gt_boxes, config)

                # Init batch arrays (lazily, once the first image's shapes are known).
                if b == 0:
                    batch_image_meta = np.zeros( (batch_size,) + image_meta.shape, dtype=image_meta.dtype)
                    batch_rpn_match = np.zeros( [batch_size, anchors.shape[0], 1], dtype=rpn_match.dtype)
                    batch_rpn_bbox = np.zeros( [batch_size, config.RPN_TRAIN_ANCHORS_PER_IMAGE, 4], dtype=rpn_bbox.dtype)
                    batch_images = np.zeros( (batch_size,) + image.shape, dtype=np.float32)
                    batch_gt_class_ids = np.zeros( (batch_size, config.MAX_GT_INSTANCES), dtype=np.int32)
                    batch_gt_boxes = np.zeros((batch_size, config.MAX_GT_INSTANCES, 4), dtype=np.int32)
                    if config.USE_MINI_MASK:
                        batch_gt_masks = np.zeros((batch_size, config.MINI_MASK_SHAPE[0], config.MINI_MASK_SHAPE[1],
                                                   config.MAX_GT_INSTANCES))
                    else:
                        batch_gt_masks = np.zeros((batch_size, image.shape[0], image.shape[1], config.MAX_GT_INSTANCES))

                # If more instances than fits in the array, sub-sample from them.
                if gt_boxes.shape[0] > config.MAX_GT_INSTANCES:
                    print("Gt is too much!!")
                    ids = np.random.choice(
                        np.arange(gt_boxes.shape[0]), config.MAX_GT_INSTANCES, replace=False)
                    gt_class_ids = gt_class_ids[ids]
                    gt_boxes = gt_boxes[ids]
                    gt_masks = gt_masks[:, :, ids]

                # Add to batch
                batch_images[b] = gen_cocodb.mold_image(image.astype(np.float32), config)
                batch_image_meta[b] = image_meta
                batch_rpn_match[b] = rpn_match[:, np.newaxis]
                batch_rpn_bbox[b] = rpn_bbox
                batch_gt_class_ids[b, :gt_class_ids.shape[0]] = gt_class_ids
                batch_gt_boxes[b, :gt_boxes.shape[0]] = gt_boxes
                batch_gt_masks[b, :, :, :gt_masks.shape[-1]] = gt_masks
                b += 1

                # Batch full?
                if b >= batch_size:
                    feed_dict={model.input_image: batch_images,  model.input_image_meta: batch_image_meta,
                               model.input_rpn_match: batch_rpn_match, model.input_rpn_bbox: batch_rpn_bbox,
                               model.input_gt_class_ids: batch_gt_class_ids, model.input_gt_boxes: batch_gt_boxes,
                               model.input_gt_masks: batch_gt_masks, learning_rate: lr}

                    _, loss, rpn_cls_loss, rpn_bbox_loss, cls_loss, bbox_loss, mask_loss, r_loss, current_step, summary = \
                        sess.run([update_op, total_loss, losses[0], losses[1], losses[2], losses[3], losses[4], regular_loss, global_step, summary_op], feed_dict=feed_dict)
                    print ("""iter %d : total-loss %.4f (r_c : %.4f, r_b : %.4f, cls : %.4f, box : %.4f, mask : %.4f, reglur : %.4f)""" %
                           (current_step, loss, rpn_cls_loss, rpn_bbox_loss, cls_loss, bbox_loss, mask_loss, r_loss))

                    if np.isnan(loss) or np.isinf(loss):
                        print('isnan or isinf', loss)
                        # A bare `raise` (as in the original) has no active
                        # exception here and would itself fail with
                        # "RuntimeError: No active exception to re-raise".
                        raise ValueError('Loss diverged (NaN or Inf): %r' % loss)

                    if current_step % num_epochs_per_decay == 0:
                        m = current_step // num_epochs_per_decay
                        # NOTE(review): pow(m, 0.94) * lr *grows* as m grows,
                        # so this increases the learning rate over time --
                        # confirm the intended decay schedule.
                        lr = pow(m, 0.94) * lr
                        print(lr)
                    if current_step % 1000 == 0:
                        # write summary
                        # summary = sess.run(summary_op, feed_dict=feed_dict)
                        summary_writer.add_summary(summary, current_step)
                        summary_writer.flush()

                    if current_step % 3000 == 0:
                        # Save a checkpoint
                        save_path = 'output/training/mrcnn.ckpt'
                        saver.save(sess, save_path, global_step=current_step)

                    if num_epoch > epochs:
                        print("num epoch : %d and training End!!!" % num_epoch)
                        break

                    b = 0
                    # break
        except Exception as ex:
            # Best-effort: report and fall through to the final save below.
            print('Error occured!!!! => ', ex)
            # 40040
        finally:
            print("Final!!")
            saver.save(sess, 'output/models/mrcnn_final.ckpt', write_meta_graph=False)
Esempio n. 15
0
def main(_):
    """Build a multi-clone training graph, restore an ImageNet backbone
    checkpoint for the selected model, and run slim training.

    Relies on module-level names: num_samples, num_epoches, logdir,
    get_split, process_image, prefetch_queue, find_class_by_name, models,
    model_deploy, learning_schedules, variables_helper.
    """

    with tf.Graph().as_default() as graph:
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        num_batches_epoch = num_samples // (FLAGS.batch_size *
                                            FLAGS.num_clones)
        print(num_batches_epoch)

        #######################
        # Config model_deploy #
        #######################
        config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.ps_tasks)

        # Create global_step on the variables device.
        with tf.device(config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        with tf.device(config.inputs_device()):
            # Train Process
            dataset = get_split('train', FLAGS.dataset_dir)
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=FLAGS.batch_size * 20,
                common_queue_min=FLAGS.batch_size * 10)
            [image_a, image_b,
             label] = provider.get(['image_a', 'image_b', 'label'])

            image_a = process_image(image_a)
            image_b = process_image(image_b)
            image_a.set_shape([FLAGS.target_height, FLAGS.target_width, 3])
            image_b.set_shape([FLAGS.target_height, FLAGS.target_width, 3])
            images_a, images_b, labels = tf.train.batch(
                [image_a, image_b, label],
                batch_size=FLAGS.batch_size,
                num_threads=8,
                capacity=FLAGS.batch_size * 10)

            inputs_queue = prefetch_queue([images_a, images_b, labels])

        ######################
        # Select the network #
        ######################
        def model_fn(inputs_queue):
            """Clone body: dequeue one batch and register the model loss."""
            images_a, images_b, labels = inputs_queue.dequeue()
            model = find_class_by_name(FLAGS.model, [models])()
            if 'ContrastiveModel' in FLAGS.model:
                vec_a, vec_b = model.create_model(images_a,
                                                  images_b,
                                                  reuse=False,
                                                  is_training=True)
                contrastive_loss = tf.contrib.losses.metric_learning.contrastive_loss(
                    labels, vec_a, vec_b)
                tf.losses.add_loss(contrastive_loss)
            else:
                logits = model.create_model(images_a,
                                            images_b,
                                            reuse=False,
                                            is_training=True)
                label_onehot = tf.one_hot(labels, 2)
                # softmax_cross_entropy adds itself to tf.GraphKeys.LOSSES.
                tf.losses.softmax_cross_entropy(
                    onehot_labels=label_onehot, logits=logits)

        clones = model_deploy.create_clones(config, model_fn, [inputs_queue])
        first_clone_scope = clones[0].scope

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(config.optimizer_device()):
            # Piecewise-constant schedule: drop the LR by 10x at 50%, 75%
            # and 90% of the total step budget.
            learning_rate_step_boundaries = [
                int(num_batches_epoch * num_epoches * 0.50),
                int(num_batches_epoch * num_epoches * 0.75),
                int(num_batches_epoch * num_epoches * 0.90)
            ]
            learning_rate_sequence = [
                FLAGS.learning_rate, FLAGS.learning_rate * 0.1,
                FLAGS.learning_rate * 0.01, FLAGS.learning_rate * 0.001
            ]
            learning_rate = learning_schedules.manual_stepping(
                global_step, learning_rate_step_boundaries,
                learning_rate_sequence)
            if FLAGS.optimizer == 'adam':
                opt = tf.train.AdamOptimizer(learning_rate)
            elif FLAGS.optimizer == 'momentum':
                opt = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
            else:
                # Originally an unknown flag surfaced later as a NameError
                # on `opt`; fail fast with a clear message instead.
                raise ValueError('Unknown optimizer: %s' % FLAGS.optimizer)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)
        training_optimizer = opt

        ############################################################
        # Initialize the backbone from a pretrained checkpoint.    #
        # (The original repeated one near-identical block per      #
        # model name; consolidated into a single mapping here.)    #
        ############################################################
        def _backbone_init_fn():
            """Return a slim init_fn restoring this model's ImageNet
            backbone, or None when nothing should be restored."""
            # Models whose backbone is InceptionV1 (exact-name matches).
            inception_v1_models = (
                'DCSL_inception_v1', 'CoAttention', 'AttentionBaseModel',
                'CoAttentionBaseModel', 'MultiHeadCoAttention',
                'MultiHeadAttentionBaseModel',
                'MultiHeadAttentionBaseModel_fixed',
                'MultiHeadAttentionBaseModel_res',
                'MultiHeadAttentionBaseModel_set_share_softmax',
                'CoAttentionBaseModel_v2', 'MultiLayerMultiHeadCoAttention')

            if FLAGS.model == 'DCSL_NAS':
                # NAS weights come from FLAGS.weights via a Saver mapping
                # each 'NAS/*' variable to its EMA entry in the checkpoint.
                def restore_map():
                    variables_to_restore = {}
                    for variable in tf.global_variables():
                        for scope_name in ['NAS']:
                            if variable.op.name.startswith(scope_name):
                                var_name = variable.op.name.replace(
                                    scope_name + '/', '')
                                variables_to_restore[
                                    var_name +
                                    '/ExponentialMovingAverage'] = variable
                    return variables_to_restore

                var_map = restore_map()
                available_var_map = (
                    variables_helper.get_variables_available_in_checkpoint(
                        var_map, FLAGS.weights))
                init_saver = tf.train.Saver(available_var_map)

                def initializer_fn(sess):
                    init_saver.restore(sess, FLAGS.weights)

                return initializer_fn

            # All other models only restore the ImageNet checkpoint when no
            # explicit weights were supplied.
            if FLAGS.weights is not None:
                return None

            if FLAGS.model == 'DCSL':
                scope, ckpt = 'InceptionResnetV2', 'inception_resnet_v2.ckpt'
            elif FLAGS.model in ('MultiHeadCoAttention_inv4',
                                 'DCSL_inception_v4'):
                scope, ckpt = 'InceptionV4', 'inception_v4.ckpt'
            elif (FLAGS.model in inception_v1_models
                  or 'ParallelAttentionBaseModel' in FLAGS.model
                  or 'ContrastiveModel' in FLAGS.model):
                scope, ckpt = 'InceptionV1', 'inception_v1.ckpt'
            else:
                return None

            return slim.assign_from_checkpoint_fn(
                os.path.join(FLAGS.checkpoints_dir, ckpt),
                slim.get_model_variables(scope))

        init_fn = _backbone_init_fn()

        # compute and update gradients
        with tf.device(config.optimizer_device()):
            if FLAGS.moving_average_decay:
                update_ops.append(
                    variable_averages.apply(moving_average_variables))

            # Variables to train.
            all_trainable = tf.trainable_variables()

            # Sum the clone losses and compute gradients over them.
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=None,
                var_list=all_trainable)

            # Optionally clip gradients
            # with tf.name_scope('clip_grads'):
            #     grads_and_vars = slim.learning.clip_gradient_norms(grads_and_vars, 10)

            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # GPU settings
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = False
        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = 2.0

        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(train_tensor,
                            logdir=logdir,
                            master=FLAGS.master,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=10,
                            summary_op=summary_op,
                            init_fn=init_fn,
                            number_of_steps=num_batches_epoch *
                            FLAGS.num_epoches,
                            save_summaries_secs=240,
                            sync_optimizer=None,
                            saver=saver)
Esempio n. 16
0
def main(_):
    """Build and train an SSD detection model with TF-Slim.

    Assembles the input pipeline from a slim Dataset, replicates the model
    across devices via ``model_deploy``, wires up losses/summaries/moving
    averages, restores a checkpoint, and runs ``slim.learning.train``.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)
    with tf.Graph().as_default():
        # Config model_deploy. Keep TF Slim Models structure.
        # Useful if want to need multiple GPUs and/or servers in the future.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0)
        # Create global_step on the device chosen for variables.
        with tf.device(deploy_config.variables_device()):  # device placement
            global_step = slim.create_global_step()

        # Select the dataset (slim Dataset descriptor for the records on disk).
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        # Get the SSD network and its anchors.
        ssd_class = nets_factory.get_network(
            FLAGS.model_name)  # e.g. returns ssd_vgg_300.SSDNet
        ssd_params = ssd_class.default_params._replace(
            num_classes=FLAGS.num_classes)
        ssd_net = ssd_class(ssd_params)
        ssd_shape = ssd_net.params.img_shape
        ssd_anchors = ssd_net.anchors(ssd_shape)  # anchors per feature map

        # Select the preprocessing function.
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        # Training-time preprocessing (augmentation etc. is inside the factory).
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        tf_utils.print_configuration(FLAGS.__flags, ssd_params,
                                     dataset.data_sources, FLAGS.train_dir)
        # =================================================================== #
        # Create a dataset provider and batches.
        # =================================================================== #
        with tf.device(deploy_config.inputs_device()):
            with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
                provider = slim.dataset_data_provider.DatasetDataProvider(
                    dataset,
                    num_readers=FLAGS.num_readers,
                    common_queue_capacity=20 * FLAGS.batch_size,
                    common_queue_min=10 * FLAGS.batch_size,
                    shuffle=True)
            # Get for SSD network: image, labels, bboxes.
            [image, glabels,
             gbboxes] = provider.get(['image', 'object/label', 'object/bbox'])

            # Pre-processing image, labels and bboxes.
            image, glabels, gbboxes = image_preprocessing_fn(
                image,
                glabels,
                gbboxes,
                out_shape=ssd_shape,
                data_format=DATA_FORMAT)

            # Encode groundtruth labels and bboxes against the anchor grid
            # (anchor matching), producing per-anchor class/location/score
            # targets — one tensor per feature-map level.
            gclasses, glocalisations, gscores = ssd_net.bboxes_encode(
                glabels, gbboxes, ssd_anchors)
            batch_shape = [1] + [len(ssd_anchors)] * 3

            # Training batches and queue.  The nested per-level tensors are
            # flattened for tf.train.batch, then restored via batch_shape.
            r = tf.train.batch(tf_utils.reshape_list(
                [image, gclasses, glocalisations, gscores]),
                               batch_size=FLAGS.batch_size,
                               num_threads=FLAGS.num_preprocessing_threads,
                               capacity=5 * FLAGS.batch_size)
            b_image, b_gclasses, b_glocalisations, b_gscores = \
                tf_utils.reshape_list(r, batch_shape)

            # Intermediate queueing: unique batch computation pipeline for all
            # GPUs running the training.
            batch_queue = slim.prefetch_queue.prefetch_queue(
                tf_utils.reshape_list(
                    [b_image, b_gclasses, b_glocalisations, b_gscores]),
                capacity=2 * deploy_config.num_clones)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple
            clones of network_fn."""
            # Dequeue batch.
            b_image, b_gclasses, b_glocalisations, b_gscores = \
                tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)

            # Construct SSD network.
            arg_scope = ssd_net.arg_scope(weight_decay=FLAGS.weight_decay,
                                          data_format=DATA_FORMAT)
            with slim.arg_scope(arg_scope):
                predictions, localisations, logits, end_points = \
                    ssd_net.net(b_image, is_training=True,DSSD_FLAG = FLAGS.DSSD_FLAG)
            # Add loss function (losses land in graph collections; only
            # end_points is returned for summaries).
            ssd_net.losses(logits,
                           localisations,
                           b_gclasses,
                           b_glocalisations,
                           b_gscores,
                           match_threshold=FLAGS.match_threshold,
                           negative_ratio=FLAGS.negative_ratio,
                           alpha=FLAGS.loss_alpha,
                           label_smoothing=FLAGS.label_smoothing)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # =================================================================== #
        # Add summaries from first clone.
        # =================================================================== #
        # create_clones runs clone_fn once per clone on its assigned device;
        # summaries/update ops are collected only from the first clone's scope.
        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))
        # Add summaries for losses and extra losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))
        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # =================================================================== #
        # Configure the moving averages.
        # =================================================================== #
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # =================================================================== #
        # Configure the optimization procedure.
        # =================================================================== #
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, dataset.num_samples, global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # and returns a train_tensor and summary_op
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        # Evaluating train_tensor forces all update ops (grads, batch norm,
        # EMA) to run, then yields the scalar total loss.
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
        config = tf.ConfigProto(log_device_placement=False,
                                gpu_options=gpu_options)
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)
        # n = tf.all_variables()
        if FLAGS.DSSD_FLAG:
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            # reader = tf.train.NewCheckpointReader(ckpt.model_checkpoint_path)

            # Names whose substring at offset 18 or 19 is "_box" — presumably
            # the box-prediction layers, excluded from restoration so they are
            # trained from scratch.  TODO(review): confirm the offsets match
            # the actual variable-name layout of this model.
            variables_to_restore = [
                var.name for var in tf.all_variables()
                if var.name.startswith("_box", 18)
                or var.name.startswith("_box", 19)
            ]

            variables_to_restore = slim.get_variables_to_restore(
                exclude=variables_to_restore)
            #
            # restore = tf.train.Saver(variables_to_restore)
            init_fn = slim.assign_from_checkpoint_fn(
                ckpt.model_checkpoint_path,
                variables_to_restore,
                ignore_missing_vars=True,
                reshape_variables=False)
        else:
            init_fn = tf_utils.get_init_fn(FLAGS)

        # with tf.Session() as sess:
        #     # init_fn(sess)
        #     ckpt_filename = './checkpoints_fpn/model.ckpt-87149'
        #     saver.restore(sess, ckpt_filename)
        #     print(".................................")

        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            init_fn=init_fn,
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            saver=saver,
                            save_interval_secs=FLAGS.save_interval_secs,
                            session_config=config,
                            sync_optimizer=None)
Esempio n. 17
0
def train():
    """The main function that runs training.

    Builds a shuffled input queue from the dataset, constructs the backbone
    plus pyramid network, then runs a manual session loop that optimizes,
    logs per-step losses, optionally draws detections, writes summaries and
    periodic checkpoints.

    Fixes vs. the original:
      * guard ``sys.argv[1]`` access so running without CLI arguments no
        longer raises ``IndexError``;
      * replace the bare ``raise`` on NaN/Inf loss (which itself raised
        ``RuntimeError: No active exception to re-raise``) with an explicit
        ``RuntimeError`` carrying a message.
    """
    ## data
    # Placeholders/tensors read from tfrecords by the dataset module.
    image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = datasets.get_dataset(FLAGS.dataset_name,  FLAGS.dataset_split_name, FLAGS.dataset_dir, FLAGS.im_batch,is_training=True)

    data_queue = tf.RandomShuffleQueue(capacity=32, min_after_dequeue=16,dtypes=(
                image.dtype, ih.dtype, iw.dtype, 
                gt_boxes.dtype, gt_masks.dtype, 
                num_instances.dtype, img_id.dtype)) 
    enqueue_op = data_queue.enqueue((image, ih, iw, gt_boxes, gt_masks, num_instances, img_id))
    data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 4)
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)
    (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id) =  data_queue.dequeue()
    im_shape = tf.shape(image)
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network
    logits, end_points, pyramid_map = network.get_network(FLAGS.network, image,weight_decay=FLAGS.weight_decay, is_training=True)
    outputs = pyramid_network.build(end_points, im_shape[1], im_shape[2], pyramid_map,num_classes=2, base_anchors=9,is_training=True, gt_boxes=gt_boxes, gt_masks=gt_masks,loss_weights=[1.0, 1.0, 1.0, 1.0, 1.0])


    total_loss = outputs['total_loss']
    losses  = outputs['losses']

    regular_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    
    input_image = end_points['input']
    final_box = outputs['final_boxes']['box']
    final_cls = outputs['final_boxes']['cls']
    final_prob = outputs['final_boxes']['prob']
    final_gt_cls = outputs['final_boxes']['gt_cls']

    #this flag is used for including the mask or not. initally I trained the network without the mask branch, because I wanted to train better the region proposal network
    # so that the network proposes better boxes. If the boxes are better proposed, the branch network will learn easier. Initially I thought that this is the problem
    # for the model memory issue. The idea is that at some point the network was proposing too many regions, like 120, and the Tensor for the mask branch would cause an out of memory error
    # because the shape of tensor would be [120,112,112,7]
    print ("FLAGS INCLUDE MASK IS ",FLAGS.INCLUDE_MASK)
    if FLAGS.INCLUDE_MASK:
        final_mask = outputs['mask']['final_mask_for_drawing']
    gt = outputs['gt']

    #############################
    # NOTE(review): tmp_0..tmp_2 all alias outputs['losses'] — looks like
    # leftover debug plumbing; kept as-is to preserve the sess.run signature.
    tmp_0 = outputs['losses']
    tmp_1 = outputs['losses']
    tmp_2 = outputs['losses']
    tmp_3 = outputs['tmp_3']
    tmp_4 = outputs['tmp_4']
    ############################


    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)
    
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
            )
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore pretrained / previous weights into the session
    restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    threads = []
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))

    tf.train.start_queue_runners(sess=sess, coord=coord)
    saver = tf.train.Saver(max_to_keep=20)

    for step in range(FLAGS.max_iters):
        
        start_time = time.time()
        if FLAGS.INCLUDE_MASK:
            s_, tot_loss, reg_lossnp, img_id_str, rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss,mask_loss, gt_boxesnp, input_imagenp, final_boxnp, final_clsnp, final_probnp, final_gt_clsnp, gtnp, tmp_0np, tmp_1np, tmp_2np, tmp_3np, tmp_4np, final_masknp,gt_masksnp= sess.run([update_op, total_loss, regular_loss, img_id] + losses + [gt_boxes] + [input_image] + [final_box] + [final_cls] + [final_prob] + [final_gt_cls] + [gt] + [tmp_0] + [tmp_1] + [tmp_2] + [tmp_3] + [tmp_4]+[final_mask]+[gt_masks])
        else:
            s_, tot_loss, reg_lossnp, img_id_str,\
            rpn_box_loss, rpn_cls_loss,refined_box_loss, refined_cls_loss,\
            gt_boxesnp, input_imagenp, final_boxnp,\
            final_clsnp, final_probnp, final_gt_clsnp, gtnp=\
                sess.run([update_op, total_loss, regular_loss, img_id] +\
                         losses +\
                         [gt_boxes] + [input_image] + [final_box] + \
                         [final_cls] + [final_prob] + [final_gt_cls] + [gt])

        duration_time = time.time() - start_time
        if step % 1 == 0:
            if FLAGS.INCLUDE_MASK:
                print ( """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.9f, """
                        """total-loss %.10f(%.4f, %.4f, %.6f, %.4f,%.5f), """ #%.4f
                        """instances: %d, proposals: %d """
                       % (step, img_id_str, duration_time, reg_lossnp,
                          tot_loss, rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss,
                          gt_boxesnp.shape[0],len(final_boxnp)))
            else:
                print ( """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.9f, """
                        """total-loss %.4f(%.4f, %.4f, %.6f, %.4f), """ #%.4f
                        """instances: %d, proposals: %d """
                       % (step, img_id_str, duration_time, reg_lossnp,
                          tot_loss, rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, #mask_loss,
                          gt_boxesnp.shape[0],len(final_boxnp)))

            # Guard the argv access: the original indexed sys.argv[1]
            # unconditionally and crashed with IndexError when the script
            # was launched without extra CLI arguments.
            if len(sys.argv) > 1 and sys.argv[1] == '--draw':
                if FLAGS.INCLUDE_MASK:
                    input_imagenp = np.uint8((np.array(input_imagenp[0])/2.0+0.5)*255.0)
                    final_gt_clsnp = np.argmax(np.asarray(final_gt_clsnp),axis=1)
                    draw_human_body_parts(step, input_imagenp,  bbox=final_boxnp, label=final_clsnp, gt_label=final_gt_clsnp, prob=final_probnp,final_mask=final_masknp)

                else:
                    save(step,input_imagenp,final_boxnp,gt_boxesnp,final_clsnp,final_probnp,final_gt_clsnp,None,None)

            if np.isnan(tot_loss) or np.isinf(tot_loss):
                print (gt_boxesnp)
                # A bare `raise` here has no active exception to re-raise;
                # raise an explicit, descriptive error instead.
                raise RuntimeError(
                    'Training diverged at step %d: total loss is NaN/Inf' % step)

        if step % 1000 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()

        if (step % 1000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(FLAGS.train_dir, 
                                           FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
def train():
    """Multi-GPU training loop with tower-averaged gradients.

    Splits each batch across ``FLAGS.num_gpus`` towers, computes per-tower
    losses/gradients, averages them, applies EMA to trainable variables,
    and runs a plain session loop with periodic logging and a final
    checkpoint.

    Fix vs. the original: an unrecognized ``FLAGS.optimizer`` value used to
    leave ``opt``/``lr`` unbound and fail later with a confusing
    ``NameError``; it now raises ``ValueError`` immediately.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # split the batch across GPUs.
        assert FLAGS.batch_size % FLAGS.num_gpus == 0, (
            'Batch size must be divisible by number of GPUs')

        # Feed-controlled switch (e.g. to delay part of the loss until after
        # the warm-up epochs — see the feed_dict in the loop below).
        start_sign_placeholder = tf.placeholder(tf.bool, name='start_sign')

        images, labels, cams = utils.prepare_data('train')

        # Split the batch of images and labels for towers.
        images_splits = tf.split(images, FLAGS.num_gpus, 0)
        labels_splits = tf.split(labels, FLAGS.num_gpus, 0)
        cams_splits = tf.split(cams, FLAGS.num_gpus, 0)

        num_classes = FLAGS.num_classes + 1
        global_step = slim.create_global_step()

        # Create an optimizer that performs gradient descent.
        if FLAGS.optimizer == 'rmsprop':
            # Calculate the learning rate schedule.
            num_batches_per_epoch = (FLAGS.num_samples / FLAGS.batch_size)
            decay_steps = int(num_batches_per_epoch *
                              FLAGS.num_epochs_per_decay)
            # Decay the learning rate exponentially based on the number of steps.
            lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                            global_step,
                                            decay_steps,
                                            FLAGS.learning_rate_decay_factor,
                                            staircase=True)
            opt = tf.train.RMSPropOptimizer(lr,
                                            RMSPROP_DECAY,
                                            momentum=RMSPROP_MOMENTUM,
                                            epsilon=RMSPROP_EPSILON)
        elif FLAGS.optimizer == 'sgd':
            # Step schedule: drop lr by 10x at the halfway point.
            boundaries = [int(1 / 2 * float(FLAGS.max_steps))]
            boundaries = list(np.array(boundaries, dtype=np.int64))
            values = [0.01, 0.001]
            lr = tf.train.piecewise_constant(global_step, boundaries, values)
            opt = tf.train.MomentumOptimizer(learning_rate=lr,
                                             momentum=0.9,
                                             use_nesterov=True)
        else:
            # Fail fast: otherwise `opt`/`lr` are unbound and the code
            # crashes much later with an opaque NameError.
            raise ValueError(
                "Unsupported --optimizer %r (expected 'rmsprop' or 'sgd')"
                % FLAGS.optimizer)

        tower_grads = []
        anchors_op = []
        reuse_variables = None
        for i in range(FLAGS.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (network.TOWER_NAME, i)) as scope:
                    # NOTE(review): slim.arg_scope normally takes a list of
                    # ops, not the result of slim.get_model_variables —
                    # verify this call does what is intended (placing model
                    # variables on the CPU).
                    with slim.arg_scope(slim.get_model_variables(scope=scope),
                                        device='/cpu:0'):
                        # Calculate the loss for one tower of the model.
                        loss, anchors = \
                            _tower_loss(images_splits[i], labels_splits[i], cams_splits[i],
                                        num_classes, reuse_variables, start_sign_placeholder)

                        anchors_op.append(anchors)

                    # Reuse variables for the next tower.
                    reuse_variables = True

                    # Batch-norm update ops for this tower (last tower's
                    # ops win after the loop; 'Logits' updates excluded).
                    batchnorm = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                  scope)
                    batchnorm = [
                        var for var in batchnorm if not 'Logits' in var.name
                    ]

                    trainable_var = tf.trainable_variables()
                    trainable_var = [
                        var for var in trainable_var
                        if not 'Logits' in var.name
                    ]

                    grads = opt.compute_gradients(loss, var_list=trainable_var)
                    tower_grads.append(grads)

        # synchronize gradients across all towers
        grads = network.average_gradients(tower_grads)
        gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Exponential moving average over all non-Logits trainables.
        var_averages = tf.train.ExponentialMovingAverage(
            FLAGS.ema_decay, global_step)
        var_average = tf.trainable_variables()
        var_average = [var for var in var_average if not 'Logits' in var.name]
        var_op = var_averages.apply(var_average)

        batchnorm_op = tf.group(*batchnorm)
        train_op = tf.group(gradient_op, var_op, batchnorm_op)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)
        init = tf.global_variables_initializer()

        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # continue training from existing model
        if FLAGS.pretrained_model_checkpoint_path:
            var_to_restore = [
                var for var in trainable_var if not 'Logits' in var.name
            ]
            restorer = tf.train.Saver(var_to_restore)
            restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

        tf.train.start_queue_runners(sess=sess)
        step_1_epoch = int(float(FLAGS.num_samples) / float(FLAGS.batch_size))

        for step in range(FLAGS.max_steps):
            start_time = time.time()
            # start_sign flips to True once warm-up epochs have elapsed.
            _, _, loss_value = \
                sess.run([train_op, anchors_op, loss],
                         feed_dict={start_sign_placeholder:
                                    step>=step_1_epoch*FLAGS.warm_up_epochs})

            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 10 == 0:
                examples_per_sec = FLAGS.batch_size / float(duration)
                format_str = ('%s: step %d, loss = %.4f '
                              '(%.1f examples/sec; %.3f sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, duration))

            if (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step + 1)
Esempio n. 19
0
def train():
    """The main function that runs training.

    Builds the shuffled input queue, the backbone + FPN pyramid network
    (81 COCO classes), then runs a manual session loop with per-step loss
    logging, periodic box visualizations, summaries and checkpoints.

    Fix vs. the original: the bare ``raise`` on a NaN/Inf loss had no
    active exception to re-raise (it produced
    ``RuntimeError: No active exception to re-raise``); it now raises an
    explicit ``RuntimeError`` with a message.
    """
    ## data

    image, original_image_height, original_image_width, image_height, image_width, gt_boxes, gt_masks, num_instances, image_id = \
        datasets.get_dataset(FLAGS.dataset_name, 
                             FLAGS.dataset_split_name, 
                             FLAGS.dataset_dir, 
                             FLAGS.im_batch,
                             is_training=True)

    ## queuing data
    data_queue = tf.RandomShuffleQueue(capacity=32, min_after_dequeue=16,
            dtypes=(
                image.dtype, original_image_height.dtype, original_image_width.dtype, image_height.dtype, image_width.dtype,
                gt_boxes.dtype, gt_masks.dtype, 
                num_instances.dtype, image_id.dtype)) 
    enqueue_op = data_queue.enqueue((image, original_image_height, original_image_width, image_height, image_width, gt_boxes, gt_masks, num_instances, image_id))

    data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 4)
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)
    (image, original_image_height, original_image_width, image_height, image_width, gt_boxes, gt_masks, num_instances, image_id) =  data_queue.dequeue()

    im_shape = tf.shape(image)
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network
    logits, end_points, pyramid_map = network.get_network(FLAGS.network, image,
            weight_decay=FLAGS.weight_decay, batch_norm_decay=FLAGS.batch_norm_decay, is_training=True)
    outputs = pyramid_network.build(end_points, image_height, image_width, pyramid_map,
            num_classes=81,
            base_anchors=3,#9#15
            is_training=True,
            gt_boxes=gt_boxes, gt_masks=gt_masks,
            loss_weights=[1.0, 1.0, 10.0, 1.0, 10.0])
            # loss_weights=[10.0, 1.0, 0.0, 0.0, 0.0])
            # loss_weights=[100.0, 100.0, 1000.0, 10.0, 100.0])
            # loss_weights=[0.2, 0.2, 1.0, 0.2, 1.0])
            # loss_weights=[0.1, 0.01, 10.0, 0.1, 1.0])

    total_loss = outputs['total_loss']
    losses  = outputs['losses']
    batch_info = outputs['batch_info']
    regular_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    input_image = end_points['input']

    # Training-time RCNN / mask tensors fetched for logging and drawing.
    training_rcnn_rois                  = outputs['training_rcnn_rois']
    training_rcnn_clses                 = outputs['training_rcnn_clses']
    training_rcnn_clses_target          = outputs['training_rcnn_clses_target'] 
    training_rcnn_scores                = outputs['training_rcnn_scores']
    training_mask_rois                  = outputs['training_mask_rois']
    training_mask_clses_target          = outputs['training_mask_clses_target']
    training_mask_final_mask            = outputs['training_mask_final_mask']
    training_mask_final_mask_target     = outputs['training_mask_final_mask_target']
    # Debug: per-level RPN feature-map shapes (P2..P5).
    tmp_0 = outputs['rpn']['P2']['shape']
    tmp_1 = outputs['rpn']['P3']['shape']
    tmp_2 = outputs['rpn']['P4']['shape']
    tmp_3 = outputs['rpn']['P5']['shape']

    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)

    cropped_rois = tf.get_collection('__CROPPED__')[0]
    transposed = tf.get_collection('__TRANSPOSED__')[0]
    
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
    #gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    #sess = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
            )
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore pretrained / previous weights into the session
    restore(sess)

    ## coord settings
    coord = tf.train.Coordinator()
    threads = []
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))
    tf.train.start_queue_runners(sess=sess, coord=coord)

    ## saver init
    saver = tf.train.Saver(max_to_keep=20)

    ## finalize the graph for checking memory leak
    sess.graph.finalize()

    ## main loop
    for step in range(FLAGS.max_iters):
        
        start_time = time.time()

        s_, tot_loss, reg_lossnp, image_id_str, \
        rpn_box_loss, rpn_cls_loss, rcnn_box_loss, rcnn_cls_loss, mask_loss, \
        gt_boxesnp, tmp_0np, tmp_1np, tmp_2np, tmp_3np, \
        rpn_batch_pos, rpn_batch, rcnn_batch_pos, rcnn_batch, mask_batch_pos, mask_batch, \
        input_imagenp, \
        training_rcnn_roisnp, training_rcnn_clsesnp, training_rcnn_clses_targetnp, training_rcnn_scoresnp, training_mask_roisnp, training_mask_clses_targetnp, training_mask_final_masknp, training_mask_final_mask_targetnp  = \
                     sess.run([update_op, total_loss, regular_loss, image_id] + 
                              losses + 
                              [gt_boxes] + [tmp_0] + [tmp_1] + [tmp_2] +[tmp_3] +
                              batch_info + 
                              [input_image] + 
                              [training_rcnn_rois] + [training_rcnn_clses] + [training_rcnn_clses_target] + [training_rcnn_scores] + [training_mask_rois] + [training_mask_clses_target] + [training_mask_final_mask] + [training_mask_final_mask_target])

        duration_time = time.time() - start_time
        if step % 1 == 0: 
            LOG ( """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.6f, """
                    """total-loss %.4f(%.4f, %.4f, %.6f, %.4f, %.4f), """
                    """instances: %d, """
                    """batch:(%d|%d, %d|%d, %d|%d)""" 
                   % (step, image_id_str, duration_time, reg_lossnp, 
                      tot_loss, rpn_box_loss, rpn_cls_loss, rcnn_box_loss, rcnn_cls_loss, mask_loss,
                      gt_boxesnp.shape[0], 
                      rpn_batch_pos, rpn_batch, rcnn_batch_pos, rcnn_batch, mask_batch_pos, mask_batch))

            LOG ("target")
            LOG (cat_id_to_cls_name(np.unique(np.argmax(np.asarray(training_rcnn_clses_targetnp),axis=1))))
            LOG ("predict")
            LOG (cat_id_to_cls_name(np.unique(np.argmax(np.array(training_rcnn_clsesnp),axis=1))))
            LOG (tmp_0np)
            LOG (tmp_1np)
            LOG (tmp_2np)
            LOG (tmp_3np)

        if step % 50 == 0: 
            draw_bbox(step, 
                      np.uint8((np.array(input_imagenp[0])/2.0+0.5)*255.0), 
                      name='train_est', 
                      bbox=training_rcnn_roisnp, 
                      label=np.argmax(np.array(training_rcnn_scoresnp),axis=1), 
                      prob=training_rcnn_scoresnp,
                      # bbox=training_mask_roisnp, 
                      # label=training_mask_clses_targetnp, 
                      # prob=np.zeros((training_mask_final_masknp.shape[0],81), dtype=np.float32)+1.0,
                      # mask=training_mask_final_masknp,
                      vis_all=True)

            draw_bbox(step, 
                      np.uint8((np.array(input_imagenp[0])/2.0+0.5)*255.0), 
                      name='train_gt', 
                      bbox=training_rcnn_roisnp, 
                      label=np.argmax(np.array(training_rcnn_clses_targetnp),axis=1), 
                      prob=np.zeros((training_rcnn_clsesnp.shape[0],81), dtype=np.float32)+1.0,
                      # bbox=training_mask_roisnp, 
                      # label=training_mask_clses_targetnp, 
                      # prob=np.zeros((training_mask_final_masknp.shape[0],81), dtype=np.float32)+1.0,
                      # mask=training_mask_final_mask_targetnp,
                      vis_all=True)
            
            # NOTE(review): this divergence check only runs every 50 steps
            # because it sits inside the drawing branch — confirm intended.
            if np.isnan(tot_loss) or np.isinf(tot_loss):
                LOG (gt_boxesnp)
                # A bare `raise` here has no active exception to re-raise;
                # raise an explicit, descriptive error instead.
                raise RuntimeError(
                    'Training diverged at step %d: total loss is NaN/Inf' % step)
          
        if step % 100 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()

        if (step % 500 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(FLAGS.train_dir, 
                                           FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
        gc.collect()
def train(data_list, label_list, k=-1):
    """Train the model on one fold and save a checkpoint at the final step.

    Builds a balanced tf.data input pipeline over (data_list, label_list),
    constructs the network selected by Flags.scope, optionally warm-starts
    from an ImageNet checkpoint, then runs Flags.max_step optimisation steps.

    Args:
        data_list: list of training images (fed to utils.balanced_train_sampling).
        label_list: list of labels aligned with data_list.
        k: zero-based fold index; only used for logging and the checkpoint
           directory name (k + 1). Defaults to -1 (no fold).
    """
    tf.reset_default_graph()

    # Re-sample the training set so every class is equally represented.
    sampled_images, sampled_labels = utils.balanced_train_sampling(data_list, label_list)
    num_samples = len(label_list)
    # Enough epochs to cover max_step batches of size batch_train.
    num_epochs = (Flags.max_step * Flags.batch_train) // num_samples + 1

    pipeline = tf.data.Dataset.from_tensor_slices((sampled_images, sampled_labels)).map(prepprocess)
    pipeline = pipeline.repeat(num_epochs).shuffle(num_samples).batch(Flags.batch_train)
    image_batch, label_batch = pipeline.make_one_shot_iterator().get_next()

    slim.create_global_step()
    x = tf.placeholder(tf.float32, shape=[Flags.batch_train, Flags.img_height, Flags.img_width, Flags.input_channel])
    y = tf.placeholder(tf.int16, shape=[Flags.batch_train, Flags.num_classes])

    model = network.Model(imgs=x,
                          num_classes=Flags.num_classes,
                          scope=Flags.scope,
                          img_height=Flags.img_height,
                          img_width=Flags.img_width,
                          is_training=True)
    excitation_list, output = model.build()

    accuracy = utils.accuracy(output, y)
    loss = slim.losses.softmax_cross_entropy(output, y)
    optimizer = tf.train.AdamOptimizer(learning_rate=Flags.init_learning_rate)

    # Force pending UPDATE_OPS (e.g. batch-norm moving averages) to run
    # before the loss value is produced.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
        loss = control_flow_ops.with_dependencies([tf.group(*update_ops)], loss)
    train_op = slim.learning.create_train_op(loss, optimizer)

    saver = tf.train.Saver(tf.global_variables())
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    sess.run(tf.global_variables_initializer())

    if Flags.load_ckpt:
        print('loading pretrained weights')
        # Restore ImageNet weights for the chosen backbone, skipping the
        # classifier layers (and the stale global step) so they are retrained.
        if Flags.scope == 'vgg_16':
            utils.load_ckpt_with_skip(Flags.vgg_16_ckpt_dir, sess,
                                      skip_layer=['global_step ', 'vgg_16/fc8/biases', 'vgg_16/fc8/weights'])
            print('loading success')
        elif Flags.scope == 'resnet_v1_50':
            utils.load_ckpt_with_skip(Flags.resnet_v1_50_ckpt_dir, sess,
                                      skip_layer=['global_step ', 'resnet_v1_50/logits/weights',
                                                  'resnet_v1_50/logits/biases'])
            print('loading success')
        elif Flags.scope == 'InceptionV3':
            utils.load_ckpt_with_skip(Flags.InceptionV3_ckpt_dir, sess,
                                      skip_layer=['global_step ', 'InceptionV3/AuxLogits/Conv2d_1b_1x1/weights',
                                                  'InceptionV3/AuxLogits/Conv2d_2b_1x1/weights',
                                                  'InceptionV3/Logits/Conv2d_1c_1x1/weights'])

    print('training...')
    start_time = time.process_time()
    for step in np.arange(Flags.max_step):
        batch_images, batch_labels = sess.run([image_batch, label_batch])
        _, step_loss, step_acc = sess.run([train_op, loss, accuracy],
                                          feed_dict={x: batch_images, y: batch_labels})
        elapsed = round((time.process_time() - start_time) / 60, 2)
        print('Fold: %d, Step: %d, Loss: %.4f, Accuracy: %.4f%%, Time: %.2fmin' % (k + 1, step, step_loss, step_acc, elapsed))
        # Save exactly once: at the very last step of this fold.
        if (step + 1) == Flags.max_step:
            checkpoint_path = Flags.log_dir + 'fold/{}/'.format(k + 1)
            if not os.path.exists(checkpoint_path):
                os.makedirs(checkpoint_path)
            saver.save(sess, checkpoint_path + 'mode.ckpt', global_step=step + 1)
    sess.close()
Esempio n. 21
0
def train():
    """The main function that runs training.

    Builds a shuffled input queue over the dataset, constructs the backbone
    network plus the pyramid (FPN) heads, restores pretrained weights, then
    runs the optimisation loop for FLAGS.max_iters steps with per-step
    logging, summaries every 100 steps, and checkpoints every 10000 steps
    and at the final step.

    Raises:
        ValueError: if the total loss becomes NaN or Inf during training.
          (Was previously a bare `raise` outside any except block, which
          itself fails with "No active exception to re-raise".)
    """

    ## data
    image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
        datasets.get_dataset(FLAGS.dataset_name,
                             FLAGS.dataset_split_name,
                             FLAGS.dataset_dir,
                             FLAGS.im_batch,
                             is_training=True)

    # Decouple data loading from the training step: a shuffling queue fed by
    # 4 parallel enqueue threads.
    data_queue = tf.RandomShuffleQueue(
        capacity=32,
        min_after_dequeue=16,
        dtypes=(image.dtype, ih.dtype, iw.dtype, gt_boxes.dtype,
                gt_masks.dtype, num_instances.dtype, img_id.dtype))
    enqueue_op = data_queue.enqueue(
        (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id))
    data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 4)
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)
    (image, ih, iw, gt_boxes, gt_masks, num_instances,
     img_id) = data_queue.dequeue()
    im_shape = tf.shape(image)
    # Force a rank-4 NHWC tensor with 3 channels for the backbone.
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network
    logits, end_points, pyramid_map = network.get_network(
        FLAGS.network,
        image,
        weight_decay=FLAGS.weight_decay,
        is_training=True)
    outputs = pyramid_network.build(end_points,
                                    im_shape[1],
                                    im_shape[2],
                                    pyramid_map,
                                    num_classes=81,
                                    base_anchors=9,
                                    is_training=True,
                                    gt_boxes=gt_boxes,
                                    gt_masks=gt_masks,
                                    loss_weights=[0.2, 0.2, 1.0, 0.2, 1.0])

    total_loss = outputs['total_loss']
    losses = outputs['losses']
    batch_info = outputs['batch_info']
    regular_loss = tf.add_n(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    input_image = end_points['input']
    final_box = outputs['final_boxes']['box']
    final_cls = outputs['final_boxes']['cls']
    final_prob = outputs['final_boxes']['prob']
    final_gt_cls = outputs['final_boxes']['gt_cls']
    gt = outputs['gt']

    #############################
    # Debug fetches: tmp_0..tmp_2 default to the loss list; tmp_3/tmp_4 are
    # then overridden with dedicated debug tensors from the pyramid build.
    tmp_0 = outputs['losses']
    tmp_1 = outputs['losses']
    tmp_2 = outputs['losses']
    tmp_3 = outputs['losses']
    tmp_4 = outputs['losses']

    # tmp_0 = outputs['tmp_0']
    # tmp_1 = outputs['tmp_1']
    # tmp_2 = outputs['tmp_2']
    tmp_3 = outputs['tmp_3']
    tmp_4 = outputs['tmp_4']
    ############################

    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)

    cropped_rois = tf.get_collection('__CROPPED__')[0]
    transposed = tf.get_collection('__TRANSPOSED__')[0]

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    # One timestamped log directory per run so runs never clobber each other.
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore
    restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    threads = []
    # print (tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS))
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(
            qr.create_threads(sess, coord=coord, daemon=True, start=True))

    tf.train.start_queue_runners(sess=sess, coord=coord)
    saver = tf.train.Saver(max_to_keep=20)

    for step in range(FLAGS.max_iters):

        start_time = time.time()

        # One training step: run the update op plus every logging/debug fetch
        # in a single sess.run so all values come from the same batch.
        s_, tot_loss, reg_lossnp, img_id_str, \
        rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss, \
        gt_boxesnp, \
        rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch, mask_batch_pos, mask_batch, \
        input_imagenp, final_boxnp, final_clsnp, final_probnp, final_gt_clsnp, gtnp, tmp_0np, tmp_1np, tmp_2np, tmp_3np, tmp_4np= \
                     sess.run([update_op, total_loss, regular_loss, img_id] +
                              losses +
                              [gt_boxes] +
                              batch_info +
                              [input_image] + [final_box] + [final_cls] + [final_prob] + [final_gt_cls] + [gt] + [tmp_0] + [tmp_1] + [tmp_2] + [tmp_3] + [tmp_4])

        duration_time = time.time() - start_time
        if step % 1 == 0:
            print(
                """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.6f, """
                """total-loss %.4f(%.4f, %.4f, %.6f, %.4f, %.4f), """
                """instances: %d, """
                """batch:(%d|%d, %d|%d, %d|%d)""" %
                (step, img_id_str, duration_time, reg_lossnp, tot_loss,
                 rpn_box_loss, rpn_cls_loss, refined_box_loss,
                 refined_cls_loss, mask_loss, gt_boxesnp.shape[0],
                 rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch,
                 mask_batch_pos, mask_batch))

            # draw_bbox(step,
            #           np.uint8((np.array(input_imagenp[0])/2.0+0.5)*255.0),
            #           name='est',
            #           bbox=final_boxnp,
            #           label=final_clsnp,
            #           prob=final_probnp,
            #           gt_label=np.argmax(np.asarray(final_gt_clsnp),axis=1),
            #           )

            # draw_bbox(step,
            #           np.uint8((np.array(input_imagenp[0])/2.0+0.5)*255.0),
            #           name='gt',
            #           bbox=gtnp[:,0:4],
            #           label=np.asarray(gtnp[:,4], dtype=np.uint8),
            #           )

            print("labels")
            # print (cat_id_to_cls_name(np.unique(np.argmax(np.asarray(final_gt_clsnp),axis=1)))[1:])
            # print (cat_id_to_cls_name(np.unique(np.asarray(gt_boxesnp, dtype=np.uint8)[:,4])))
            print(
                cat_id_to_cls_name(
                    np.unique(np.argmax(np.asarray(tmp_3np), axis=1)))[1:])
            #print (cat_id_to_cls_name(np.unique(np.argmax(np.asarray(gt_boxesnp)[:,4],axis=1))))
            print("classes")
            print(
                cat_id_to_cls_name(
                    np.unique(np.argmax(np.array(tmp_4np), axis=1))))
            # print (np.asanyarray(tmp_3np))

            #print ("ordered rois")
            #print (np.asarray(tmp_0np)[0])
            #print ("pyramid_feature")
            #print ()
            #print(np.unique(np.argmax(np.array(final_probnp),axis=1)))
            #for var, val in zip(tmp_2, tmp_2np):
            #    print(var.name)
            #print(np.argmax(np.array(tmp_0np),axis=1))

            if np.isnan(tot_loss) or np.isinf(tot_loss):
                print(gt_boxesnp)
                # FIX: a bare `raise` here has no active exception and would
                # itself fail with RuntimeError; raise an explicit error so
                # the divergence is reported with context.
                raise ValueError(
                    'Training diverged at step %d: total loss is NaN/Inf' % step)

        if step % 100 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()

        if (step % 10000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(
                FLAGS.train_dir,
                FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
Esempio n. 22
0
def train_multigpu(losses, init_fn, hparams):
    """Wraps slim.learning.train to run a multi-GPU training loop.

  Args:
    losses: iterable of (loss_tensor, clone_index) pairs, one per GPU clone;
      each loss is scaled by 1/num_clones, and clone 0 additionally absorbs
      the regularization losses.
    init_fn: A callable to be executed after all other initialization is done.
    hparams: a model hyper parameters
  """
    # Keep the global step and optimizer state on the CPU so all GPU clones
    # share them.
    with tf.device("/cpu:0"):
        global_step = slim.create_global_step()

    with tf.device("/cpu:0"):
        optimizer = create_optimizer(hparams)

    if FLAGS.sync_replicas:
        # NOTE(review): tf.LegacySyncReplicasOptimizer usually lives under
        # tf.train in TF1 — confirm this top-level alias exists in the
        # pinned TensorFlow release.
        replica_id = tf.constant(FLAGS.task, tf.int32, shape=())
        optimizer = tf.LegacySyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_aggregate,
            replica_id=replica_id,
            total_num_replicas=FLAGS.total_num_replicas)
        sync_optimizer = optimizer
        startup_delay_steps = 0
    else:
        startup_delay_steps = 0
        sync_optimizer = None

    #train_op = slim.learning.create_train_op(
    #    loss,
    #    optimizer,
    #    summarize_gradients=True,
    #    clip_gradient_norm=FLAGS.clip_gradient_norm)
    #with tf.device("/cpu:0"):
    #  tf.summary.scalar('TotalLoss_all', total_loss)
    #  grad = optimizer.compute_gradients(total_loss)
    #with tf.device("/cpu:0"):
    #  with ops.name_scope('summarize_grads'):
    #    add_gradients_summaries(grad)
    #  clipped_grad = tf.contrib.training.clip_gradient_norms(grad, FLAGS.clip_gradient_norm)
    #  update = optimizer.apply_gradients(clipped_grad, global_step=global_step)
    #with tf.control_dependencies([update]):
    #  train_op = tf.identity(total_loss, name='train_op')

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, "clone_0")

    # Per-clone: scale the loss, compute gradients on that clone's GPU, and
    # collect everything for cross-clone summation below.
    grads = []
    total_loss = []
    for loss, i in losses:
        with tf.device("/gpu:{0}".format(i)):
            scaled_loss = tf.div(loss, 1.0 * FLAGS.num_clones)
            if i == 0:
                # Regularization is graph-wide, so add it exactly once.
                regularization_loss = tf.add_n(
                    tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
                scaled_loss = scaled_loss + regularization_loss
            total_loss.append(scaled_loss)
            grad = optimizer.compute_gradients(scaled_loss)
            #if i == 0:
            #  with tf.device("/cpu:0"):
            #    with ops.name_scope("summarize_grads_{0}".format(i)):
            #      add_gradients_summaries(grad)
            grads.append(grad)
    total_loss = tf.add_n(total_loss)
    with tf.device("/cpu:0"):
        tf.summary.scalar('Total_Loss', total_loss)
    # Sum gradients across clones, clip, and apply as one update.
    sum_grad = _sum_clones_gradients(grads)
    clipped_grad = tf.contrib.training.clip_gradient_norms(
        sum_grad, FLAGS.clip_gradient_norm)
    update = optimizer.apply_gradients(clipped_grad, global_step=global_step)
    update_ops.append(update)

    # The train op returns total_loss only after all updates (batch-norm and
    # the gradient application) have run.
    with tf.control_dependencies([tf.group(*update_ops)]):
        train_op = tf.identity(total_loss, name='train_op')

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    #session_config.log_device_placement = True

    with tf.device("/cpu:0"):
        slim.learning.train(train_op=train_op,
                            logdir=FLAGS.train_log_dir,
                            graph=total_loss.graph,
                            master=FLAGS.master,
                            is_chief=(FLAGS.task == 0),
                            number_of_steps=FLAGS.max_number_of_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            trace_every_n_steps=1000,
                            save_interval_secs=FLAGS.save_interval_secs,
                            startup_delay_steps=startup_delay_steps,
                            sync_optimizer=sync_optimizer,
                            init_fn=init_fn,
                            session_config=session_config)
def main(_):
    """Entry point: validates flags, builds the multi-clone training graph,
    and runs the training step loop with periodic logging and checkpoints.

    Raises:
        ValueError: if --dataset_dir, --aug_mode, or the training image
            height/width flags are missing.
    """
    if not os.path.isdir(FLAGS.train_dir):
        os.makedirs(FLAGS.train_dir)

    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    if not FLAGS.aug_mode:
        raise ValueError('aug_mode need to be speficied.')

    if (not FLAGS.train_image_height) or (not FLAGS.train_image_width):
        raise ValueError(
            'The image height and width must be define explicitly.')

    # hd_data forces a fixed 400x200 input size, overriding the flags.
    if FLAGS.hd_data:
        if FLAGS.train_image_height != 400 or FLAGS.train_image_width != 200:
            FLAGS.train_image_height, FLAGS.train_image_width = 400, 200
            print("set the image size to (%d, %d)" % (400, 200))

    # config and print log
    config_and_print_log(FLAGS)

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        #####################################
        # Select the preprocessing function #
        #####################################
        img_func = get_img_func()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.DataLoader(FLAGS.model_name,
                                             FLAGS.dataset_name,
                                             FLAGS.dataset_dir, FLAGS.set,
                                             FLAGS.hd_data, img_func,
                                             FLAGS.batch_size, FLAGS.batch_k,
                                             FLAGS.max_number_of_steps,
                                             get_pair_type())

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            sample_number=FLAGS.sample_number)

        ####################
        # Define the model #
        ####################
        def clone_fn(tf_batch_queue):
            # One model replica per clone, fed from the shared batch queue.
            return build_graph(tf_batch_queue, network_fn)

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [dataset.tf_batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Map each collected loss tensor to a stable logging key, matched by
        # the loss op's name pattern.
        loss_dict = {}
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            if loss.name == 'softmax_cross_entropy_loss/value:0':
                loss_dict['clf'] = loss
            elif 'softmax_cross_entropy_loss' in loss.name:
                loss_dict['sample_clf_' +
                          str(loss.name.split('/')[0].split('_')[-1])] = loss
            elif 'entropy' in loss.name:
                loss_dict['entropy'] = loss
            else:
                raise Exception('Loss type error')

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step, FLAGS)
            optimizer = _configure_optimizer(learning_rate)

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        #  and returns a train_tensor and summary_op
        # total_loss is the sum of all LOSSES and REGULARIZATION_LOSSES in tf.GraphKeys
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        # train_tensor evaluates to total_loss, but only after every update
        # op (batch-norm + gradient application) has run.
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Fetch list and matching printf-style format: train op first, then
        # each individual loss in sorted-key order.
        train_tensor_list = [train_tensor]
        format_str = 'step %d, loss = %.2f'

        for loss_key in sorted(loss_dict.keys()):
            train_tensor_list.append(loss_dict[loss_key])
            format_str += (', %s_loss = ' % loss_key + '%.8f')

        format_str += ' (%.1f examples/sec; %.3f sec/batch)'

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')

        ###########################
        # Kicks off the training. #
        ###########################
        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # load pretrained weights
        if FLAGS.checkpoint_path is not None:
            print("Load the pretrained weights")
            weight_ini_fn = _get_init_fn()
            weight_ini_fn(sess)
        else:
            print("Train from the scratch")

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        # for step in xrange(FLAGS.max_number_of_steps):
        # NOTE(review): xrange is Python-2 only — presumably imported via
        # six.moves at the top of this file; confirm before running on Py3.
        for step in xrange(FLAGS.max_number_of_steps + 1):
            start_time = time.time()
            loss_value_list = sess.run(train_tensor_list,
                                       feed_dict=dataset.get_feed_dict())

            duration = time.time() - start_time

            if step % FLAGS.log_every_n_steps == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = duration

                print(format_str % tuple([step] + loss_value_list +
                                         [examples_per_sec, sec_per_batch]))

            # Save the model checkpoint periodically.
            # if step % FLAGS.model_snapshot_steps == 0 or (step + 1) == FLAGS.max_number_of_steps:
            if step % FLAGS.model_snapshot_steps == 0:
                saver.save(sess, checkpoint_path, global_step=step)

        print('OK...')
Esempio n. 24
0
def main(_):
    """Entry point: builds a multi-clone COCO segmentation training graph
    and hands the training loop to slim.learning.train.

    The graph per clone: input batch -> resseg_model logits -> masked
    sparse-softmax cross-entropy against remapped COCO labels.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            iterator = coco.get_dataset(FLAGS.train_data_file,
                                        batch_size=FLAGS.batch_size,
                                        num_epochs=500,
                                        buffer_size=250 * FLAGS.num_clones,
                                        num_parallel_calls=4 *
                                        FLAGS.num_clones,
                                        crop_height=FLAGS.height,
                                        crop_width=FLAGS.width,
                                        resize_shape=FLAGS.width,
                                        data_augment=True)

        def clone_fn(iterator):
            """Builds one clone: model forward pass plus its loss."""
            with tf.device(deploy_config.inputs_device()):
                batch_image, batch_labels = iterator.get_next()

            # Pin the (possibly dynamic) batch dimension so downstream
            # shape inference succeeds.
            s = batch_labels.get_shape().as_list()
            batch_labels.set_shape([FLAGS.batch_size, s[1], s[2], s[3]])

            s = batch_image.get_shape().as_list()
            batch_image.set_shape([FLAGS.batch_size, s[1], s[2], s[3]])

            num_classes = coco.num_classes()

            logits, end_points = resseg_model(
                batch_image,
                FLAGS.height,
                FLAGS.width,
                FLAGS.scale,
                FLAGS.weight_decay,
                FLAGS.use_seperable_convolution,
                num_classes,
                is_training=True,
                use_batch_norm=FLAGS.use_batch_norm,
                num_units=FLAGS.num_units,
                filter_depth_multiplier=FLAGS.filter_depth_multiplier)

            s = logits.get_shape().as_list()
            with tf.device(deploy_config.inputs_device()):
                # Remap raw COCO category ids to contiguous train ids
                # (shifted by +1; 0 is left as background).
                lmap_size = 256
                lmap = np.array([0] * lmap_size)
                for k, v in coco.id2trainid_objects.items():
                    lmap[k] = v + 1
                lmap = tf.constant(lmap, tf.uint8)
                down_labels = tf.cast(batch_labels, tf.int32)
                # 255 marks ignore pixels; mask them out of labels/weights.
                label_mask = tf.squeeze((down_labels < 255))
                down_labels = tf.gather(lmap, down_labels)
                down_labels = tf.cast(down_labels, tf.int32)
                down_labels = tf.reshape(
                    down_labels, tf.TensorShape([FLAGS.batch_size, s[1],
                                                 s[2]]))
                down_labels = tf.cast(label_mask, tf.int32) * down_labels

                # Constant foreground weight everywhere except ignored pixels.
                fg_weights = tf.constant(FLAGS.foreground_weight,
                                         dtype=tf.int32,
                                         shape=label_mask.shape)
                label_weights = tf.cast(label_mask, tf.int32) * fg_weights

            # Specify the loss
            cross_entropy = tf.losses.sparse_softmax_cross_entropy(
                down_labels, logits, weights=label_weights, scope='xentropy')
            tf.losses.add_loss(cross_entropy)

            return end_points, batch_image, down_labels, logits

        # Gather initial summaries
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [iterator])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(
                FLAGS.num_samples_per_epoch, global_step,
                deploy_config.num_clones)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        end_points, batch_image, down_labels, logits = clones[0].outputs

        # Colourized label / prediction maps for TensorBoard.
        cmap = np.array(coco.id2color)
        cmap = tf.constant(cmap, tf.uint8)
        seg_map = tf.gather(cmap, down_labels)

        predictions = tf.argmax(logits, axis=3)
        pred_map = tf.gather(cmap, predictions)

        summaries.add(tf.summary.image('labels', seg_map))
        summaries.add(tf.summary.image('predictions', pred_map))
        summaries.add(tf.summary.image('images', batch_image))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Returns a train_tensor and summary_op
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        if FLAGS.sync_replicas:
            # FIX: was `sync_optimizer = opt` — `opt` is undefined in this
            # scope (NameError whenever sync_replicas is set); the wrapped
            # optimizer built above is what must be passed through.
            sync_optimizer = optimizer
            startup_delay_steps = 0
        else:
            sync_optimizer = None
            startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        ###########################
        # Kick off the training.  #
        ###########################
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master=FLAGS.master,
                            is_chief=(FLAGS.task == 0),
                            init_fn=_get_init_fn(),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            startup_delay_steps=startup_delay_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs,
                            sync_optimizer=sync_optimizer)
Esempio n. 25
0
def train():
    """Run the FPN / Mask-RCNN style training loop.

    Builds the input queue, the backbone network and the pyramid heads, then
    iterates ``FLAGS.max_iters`` solver steps: losses are printed every step,
    summaries are written every 100 steps, and a checkpoint is saved every
    10000 steps and at the final step.

    Relies on module-level collaborators defined elsewhere in this file:
    ``FLAGS``, ``datasets``, ``network``, ``pyramid_network``, ``solve()``
    and ``restore()``.
    """

    ## data: a single example -- image tensor, its height/width, ground-truth
    ## boxes and masks, instance count, and the image id.
    image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
        datasets.get_dataset(FLAGS.dataset_name,
                             FLAGS.dataset_split_name,
                             FLAGS.dataset_dir,
                             FLAGS.im_batch,
                             is_training=True)

    # Shuffle examples through a queue fed by 4 enqueue threads.
    data_queue = tf.RandomShuffleQueue(
        capacity=32,
        min_after_dequeue=16,
        dtypes=(image.dtype, ih.dtype, iw.dtype, gt_boxes.dtype,
                gt_masks.dtype, num_instances.dtype, img_id.dtype))
    enqueue_op = data_queue.enqueue(
        (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id))
    data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 4)
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)
    (image, ih, iw, gt_boxes, gt_masks, num_instances,
     img_id) = data_queue.dequeue()
    im_shape = tf.shape(image)
    # Re-attach the static channel dimension lost through the queue round-trip.
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network: backbone + pyramid heads (81 = 80 COCO classes + background).
    logits, end_points, pyramid_map = network.get_network(
        FLAGS.network, image, weight_decay=FLAGS.weight_decay)
    # NOTE(review): a later variant of this function in the file passes
    # im_shape[1]/im_shape[2] instead of the dataset-provided ih/iw here --
    # confirm which size source is intended.
    outputs = pyramid_network.build(end_points,
                                    ih,
                                    iw,
                                    pyramid_map,
                                    num_classes=81,
                                    base_anchors=9,
                                    is_training=True,
                                    gt_boxes=gt_boxes,
                                    gt_masks=gt_masks,
                                    loss_weights=[0.2, 0.2, 1.0, 0.2, 1.0])

    total_loss = outputs['total_loss']
    losses = outputs['losses']
    batch_info = outputs['batch_info']
    regular_loss = tf.add_n(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)

    # Debug tensors published by the pyramid network via graph collections.
    cropped_rois = tf.get_collection('__CROPPED__')[0]
    transposed = tf.get_collection('__TRANSPOSED__')[0]

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    # Time-stamped log directory so repeated runs do not clobber each other.
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore pretrained / previously checkpointed weights
    restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    threads = []
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(
            qr.create_threads(sess, coord=coord, daemon=True, start=True))

    tf.train.start_queue_runners(sess=sess, coord=coord)
    saver = tf.train.Saver(max_to_keep=20)

    for step in range(FLAGS.max_iters):

        start_time = time.time()

        # One solver step; fetch every individual loss plus batch statistics
        # for logging.
        s_, tot_loss, reg_lossnp, img_id_str, \
        rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss, \
        gt_boxesnp, \
        rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch, mask_batch_pos, mask_batch = \
                     sess.run([update_op, total_loss, regular_loss, img_id] +
                              losses +
                              [gt_boxes] +
                              batch_info)

        duration_time = time.time() - start_time
        if step % 1 == 0:
            print(
                """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.6f, """
                """total-loss %.4f(%.4f, %.4f, %.6f, %.4f, %.4f), """
                """instances: %d, """
                """batch:(%d|%d, %d|%d, %d|%d)""" %
                (step, img_id_str, duration_time, reg_lossnp, tot_loss,
                 rpn_box_loss, rpn_cls_loss, refined_box_loss,
                 refined_cls_loss, mask_loss, gt_boxesnp.shape[0],
                 rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch,
                 mask_batch_pos, mask_batch))

            if np.isnan(tot_loss) or np.isinf(tot_loss):
                print(gt_boxesnp)
                # Fix: a bare `raise` outside an except block is invalid in
                # Python 3 (it raises RuntimeError: "No active exception to
                # re-raise"); raise a self-describing error instead.
                raise RuntimeError(
                    'total loss diverged (NaN/Inf) at step %d' % step)

        if step % 100 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)

        # Checkpoint periodically and at the very last step (never at step 0).
        if (step % 10000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(
                FLAGS.train_dir,
                FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
Esempio n. 26
0
def train():
    """The main function that runs training"""
    # NOTE(review): near-duplicate of the train() defined earlier in this
    # file. The visible differences: the backbone here is built with
    # is_training=True, and the pyramid is built from im_shape[1]/im_shape[2]
    # rather than the dataset-provided ih/iw tensors.

    ## data
    # One example: image, its height/width, gt boxes/masks, instance count
    # and image id.
    image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
        datasets.get_dataset(FLAGS.dataset_name, 
                             FLAGS.dataset_split_name, 
                             FLAGS.dataset_dir, 
                             FLAGS.im_batch,
                             is_training=True)
    
    # Shuffle examples through a queue fed by 4 enqueue threads.
    data_queue = tf.RandomShuffleQueue(capacity=32, min_after_dequeue=16,
            dtypes=(
                image.dtype, ih.dtype, iw.dtype, 
                gt_boxes.dtype, gt_masks.dtype, 
                num_instances.dtype, img_id.dtype)) 
    enqueue_op = data_queue.enqueue((image, ih, iw, gt_boxes, gt_masks, num_instances, img_id))
    data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 4)
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)
    (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id) =  data_queue.dequeue()
    im_shape = tf.shape(image)
    # Re-attach the static channel dimension lost through the queue.
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network
    # Backbone + pyramid heads; 81 classes = 80 COCO classes + background.
    logits, end_points, pyramid_map = network.get_network(FLAGS.network, image,
            weight_decay=FLAGS.weight_decay, is_training=True)
    outputs = pyramid_network.build(end_points, im_shape[1], im_shape[2], pyramid_map,
            num_classes=81,
            base_anchors=9,
            is_training=True,
            gt_boxes=gt_boxes, gt_masks=gt_masks,
            loss_weights=[0.2, 0.2, 1.0, 0.2, 1.0])


    total_loss = outputs['total_loss'] 
    losses  = outputs['losses']
    batch_info = outputs['batch_info']
    regular_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)

    # Debug tensors published by the pyramid network via graph collections.
    cropped_rois = tf.get_collection('__CROPPED__')[0]
    transposed = tf.get_collection('__TRANSPOSED__')[0]
    
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
            )
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    # Time-stamped log directory so repeated runs do not clobber each other.
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore
    restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    threads = []
    # print (tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS))
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))

    tf.train.start_queue_runners(sess=sess, coord=coord)
    saver = tf.train.Saver(max_to_keep=20)

    for step in range(FLAGS.max_iters):
        
        start_time = time.time()

        # One solver step; fetch the individual losses and batch statistics
        # for logging.
        s_, tot_loss, reg_lossnp, img_id_str, \
        rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss, \
        gt_boxesnp, \
        rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch, mask_batch_pos, mask_batch = \
                     sess.run([update_op, total_loss, regular_loss, img_id] + 
                              losses + 
                              [gt_boxes] + 
                              batch_info )

        duration_time = time.time() - start_time
        if step % 1 == 0: 
            print ( """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.6f, """
                    """total-loss %.4f(%.4f, %.4f, %.6f, %.4f, %.4f), """
                    """instances: %d, """
                    """batch:(%d|%d, %d|%d, %d|%d)""" 
                   % (step, img_id_str, duration_time, reg_lossnp, 
                      tot_loss, rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss,
                      gt_boxesnp.shape[0], 
                      rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch, mask_batch_pos, mask_batch))

            if np.isnan(tot_loss) or np.isinf(tot_loss):
                print (gt_boxesnp)
                # NOTE(review): a bare `raise` outside an except block raises
                # RuntimeError("No active exception to re-raise") in Python 3;
                # an explicit exception with a message would be clearer.
                raise
          
        if step % 100 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)

        # Checkpoint every 10000 steps and at the last step (never at step 0).
        if (step % 10000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(FLAGS.train_dir, 
                                           FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
Esempio n. 27
0
def main(_):
  """Train a MobileNet-based detector with tf-slim's multi-clone deployment.

  Assembles the data provider, anchor encoding, per-clone model/loss graph,
  optimizer (optionally synchronous / with moving averages) and hands the
  final train op to slim.learning.train.
  """
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.worker_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
      FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
      FLAGS.model_name,
      num_classes=(dataset.num_classes - FLAGS.labels_offset),
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      width_multiplier=FLAGS.width_multiplier)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
      preprocessing_name,
      is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=FLAGS.num_readers,
        common_queue_capacity=20 * FLAGS.batch_size,
        common_queue_min=10 * FLAGS.batch_size)

      # gt_bboxes format [ymin, xmin, ymax, xmax]
      [image, img_shape, gt_labels, gt_bboxes] = provider.get(['image', 'shape',
                                                               'object/label',
                                                               'object/bbox'])

      # Preprocesing
      # gt_bboxes = scale_bboxes(gt_bboxes, img_shape)  # bboxes format [0,1) for tf draw

      image, gt_labels, gt_bboxes = image_preprocessing_fn(image,
                                                           config.IMG_HEIGHT,
                                                           config.IMG_WIDTH,
                                                           labels=gt_labels,
                                                           bboxes=gt_bboxes,
                                                           )

      #############################################
      # Encode annotations for losses computation #
      #############################################

      # anchors format [cx, cy, w, h]
      anchors = tf.convert_to_tensor(config.ANCHOR_SHAPE, dtype=tf.float32)

      # encode annos, box_input format [cx, cy, w, h]
      input_mask, labels_input, box_delta_input, box_input = encode_annos(gt_labels,
                                                                          gt_bboxes,
                                                                          anchors,
                                                                          config.NUM_CLASSES)

      # Batch the encoded targets; prefetch so clones never starve.
      images, b_input_mask, b_labels_input, b_box_delta_input, b_box_input = tf.train.batch(
        [image, input_mask, labels_input, box_delta_input, box_input],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

      batch_queue = slim.prefetch_queue.prefetch_queue(
        [images, b_input_mask, b_labels_input, b_box_delta_input, b_box_input], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, b_input_mask, b_labels_input, b_box_delta_input, b_box_input = batch_queue.dequeue()
      anchors = tf.convert_to_tensor(config.ANCHOR_SHAPE, dtype=tf.float32)
      end_points = network_fn(images)
      end_points["viz_images"] = images
      # Attach the detection head on top of the backbone's conv_ds_14 layer.
      conv_ds_14 = end_points['MobileNet/conv_ds_14/depthwise_conv']
      dropout = slim.dropout(conv_ds_14, keep_prob=0.5, is_training=True)
      # Per anchor: class scores, 1 confidence, 4 box-delta coordinates.
      num_output = config.NUM_ANCHORS * (config.NUM_CLASSES + 1 + 4)
      predict = slim.conv2d(dropout, num_output, kernel_size=(3, 3), stride=1, padding='SAME',
                            activation_fn=None,
                            weights_initializer=tf.truncated_normal_initializer(stddev=0.0001),
                            scope="MobileNet/conv_predict")

      with tf.name_scope("Interpre_prediction") as scope:
        pred_box_delta, pred_class_probs, pred_conf, ious, det_probs, det_boxes, det_class = \
          interpre_prediction(predict, b_input_mask, anchors, b_box_input)
        end_points["viz_det_probs"] = det_probs
        end_points["viz_det_boxes"] = det_boxes
        end_points["viz_det_class"] = det_class

      with tf.name_scope("Losses") as scope:
        # Registers the losses in tf.GraphKeys.LOSSES; no return value needed.
        losses(b_input_mask, b_labels_input, ious, b_box_delta_input, pred_class_probs, pred_conf, pred_box_delta)

      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      # Skip the visualization-only entries stashed by clone_fn.
      if end_point not in ["viz_images", "viz_det_probs", "viz_det_boxes", "viz_det_class"]:
        x = end_points[end_point]
        summaries.add(tf.summary.histogram('activations/' + end_point, x))
        summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                        tf.nn.zero_fraction(x)))

    # Add summaries for det result TODO(shizehao): vizulize prediction


    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
        opt=optimizer,
        replicas_to_aggregate=FLAGS.replicas_to_aggregate,
        variable_averages=variable_averages,
        variables_to_average=moving_average_variables,
        replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
        total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    #  and returns a train_tensor and summary_op
    total_loss, clones_gradients = model_deploy.optimize_clones(
      clones,
      optimizer,
      var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    # train_tensor evaluates total_loss only after all updates have run.
    update_op = tf.group(*update_ops)
    train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                      name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
      train_tensor,
      logdir=FLAGS.train_dir,
      master=FLAGS.master,
      is_chief=(FLAGS.task == 0),
      init_fn=_get_init_fn(),
      summary_op=summary_op,
      number_of_steps=FLAGS.max_number_of_steps,
      log_every_n_steps=FLAGS.log_every_n_steps,
      save_summaries_secs=FLAGS.save_summaries_secs,
      save_interval_secs=FLAGS.save_interval_secs,
      sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Esempio n. 28
0
def main(_):
    """Build and train an SSD-VGG detector with tf-slim.

    Constructs the data provider and ground-truth encoding, clones the model
    per GPU via model_deploy, configures the optimizer (optionally with
    moving averages) and hands the assembled train op to
    slim.learning.train with a custom per-step logging hook.

    Relies on module-level collaborators defined elsewhere in this file:
    FLAGS, model_deploy, dataset_factory, ssd_vgg, ssd_vgg_preprocessing,
    tf_utils, slim, control_flow_ops and DATA_FORMAT.
    """
    # for name, value in FLAGS.__flags.items():
    #     print(name, ': ', value.value)
    if not FLAGS.dataset_dir:
        raise ValueError('Directory of dataset is not found')
    tf.logging.set_verbosity(tf.logging.DEBUG)
    tf.logging.debug("hahahaha %s" %FLAGS.dataset_dir )
    with tf.Graph().as_default():

        # Single-machine deployment: one replica, no parameter servers.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0
        )
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir
        )

        # SSD net and its fixed anchor grid for the configured input shape.
        ssd_class = ssd_vgg.SSDNet
        ssd_params = ssd_class.default_parameters._replace(num_classes=FLAGS.num_classes)
        print("Class numbers", ssd_params.num_classes)
        ssd_net = ssd_class(ssd_params)
        ssd_shape = ssd_net.params.img_shape
        ssd_anchors = ssd_net.anchors(ssd_shape)

        # Preprocessing function (called below with is_training=True).
        image_preprocessing_fun = ssd_vgg_preprocessing.preprocess_image
        tf_utils.print_configs(FLAGS.__flags, ssd_params, dataset.data_sources, FLAGS.train_dir)

        #--------------------------------------------
        #         Data provider and batches
        # -------------------------------------------
        with tf.device(deploy_config.input_device()):
            with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
                provider = slim.dataset_data_provider.DatasetDataProvider(
                    dataset,
                    num_readers=FLAGS.num_readers,
                    common_queue_capacity=20*FLAGS.batch_size,
                    common_queue_min=10*FLAGS.batch_size,
                    shuffle=True
                )
            [image, shape, gtlabels, gtbboxes] = provider.get(['image',
                                                               'shape',
                                                               'object/label',
                                                               'object/bbox'])
            # Pre-process image, labels and bboxes (augmentation + resize).
            image, gtlabels, gtbboxes = image_preprocessing_fun(image, gtlabels,
                                                                gtbboxes, out_shape=ssd_shape,
                                                                data_format=DATA_FORMAT,
                                                                is_training=True)
            # Encode ground-truth labels/bboxes against the anchor grid.
            gtclasses, gtlocations, gtscores = ssd_net.bboxes_encode(gtlabels,
                                                                     gtbboxes,
                                                                     ssd_anchors)
            # One image tensor plus one class/location/score tensor per
            # feature layer: [1, len, len, len].
            batch_shape = [1] + [len(ssd_anchors)] * 3

            # Training batch and queue (flatten for tf.train.batch, then
            # restore the nested structure).
            r = tf.train.batch(
                tf_utils.reshape_list([image, gtclasses, gtlocations, gtscores]),
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size
            )
            b_image, b_gtclasses, b_gtlocations, b_gtscores = \
                tf_utils.reshape_list(r, batch_shape)

            # Intermediate queueing: unique batch computation pipeline for all
            # GPUs running the training.
            batch_queue = slim.prefetch_queue.prefetch_queue(
                tf_utils.reshape_list([b_image, b_gtclasses, b_gtlocations, b_gtscores]),
                capacity=2 * deploy_config.num_clones
            )

        #--------------------------------------------------
        #                 Clone on every GPU
        #--------------------------------------------------
        def clone_fn(batch_queue):
            """Build the SSD forward pass and losses for one clone."""
            b_image, b_gtclasses, b_gtlocations, b_gtscores = \
                tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)

            arg_scope = ssd_net.arg_scope(weight_decay=FLAGS.weight_decay,
                                          data_format=DATA_FORMAT)
            with slim.arg_scope(arg_scope):
                prediction, location, logits, end_points = \
                    ssd_net.net(b_image, is_training=True)

            # Losses are registered in graph collections; no return needed.
            ssd_net.losses(logits, location,
                           b_gtclasses, b_gtlocations, b_gtscores,
                           match_threshold=FLAGS.match_threshold,
                           negative_ratio=FLAGS.negative_ratio,
                           alpha=FLAGS.loss_alpha,
                           label_smoothing=FLAGS.label_smoothing)
            return end_points

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # ---------------------------------------------------------
        #                 Summary for first clone
        # ---------------------------------------------------------
        clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            # zero_fraction = fraction of zero activations (sparsity).
            summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))
        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step
            )
        else:
            moving_average_variables, variable_averages = None, None

        # ----------------------------------------------
        #              Optimization procedure
        # ----------------------------------------------
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(FLAGS,
                                                             dataset.num_samples,
                                                             global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

            if FLAGS.moving_average_decay:
                # Update moving averages locally alongside the train step.
                update_ops.append(variable_averages.apply(moving_average_variables))

            variables_to_train = tf_utils.get_variables_to_train(FLAGS)

            total_loss, clones_gradients = model_deploy.optimize_clones(
                clones,
                optimizer,
                var_list=variables_to_train
            )

            summaries.add(tf.summary.scalar('total_loss', total_loss))

            grad_updates = optimizer.apply_gradients(clones_gradients,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)               # *Operations, not lists
            # train_tensor yields total_loss only after all updates have run.
            train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                              name='train_op')

            summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                               first_clone_scope))

            summary_op = tf.summary.merge(list(summaries), name='summary_op')

            # ---------------------------------------------
            #               Now let's start!
            # ---------------------------------------------
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction,
                                        allow_growth=True)
            config = tf.ConfigProto(log_device_placement=False,
                                    gpu_options=gpu_options)
            saver = tf.train.Saver(max_to_keep=5,
                                   keep_checkpoint_every_n_hours=1.0,
                                   write_version=2,
                                   pad_step_number=False,
                                   name='Model_ssd_vgg')


            def train_step_fn(session, *args, **kwargs):
                """Wrap slim's train_step to print the loss every 2 steps."""
                # Fix: the original forwarded `*kwargs`, which unpacks the
                # dict's *keys* as extra positional arguments; `**kwargs`
                # forwards the keyword arguments as intended.
                total_loss, should_stop = slim.learning.train_step(
                    session, *args, **kwargs)
                if train_step_fn.step % 2 == 0:
                    print('step: %s || loss: %f || gradient: '
                          %(str(train_step_fn.step), total_loss))

                train_step_fn.step += 1
                return [total_loss, should_stop]

            train_step_fn.step = 0


            slim.learning.train(
                train_tensor,
                logdir=FLAGS.train_dir,
                master='',
                is_chief=True,
                init_fn=tf_utils.get_init_fn(FLAGS),
                summary_op=summary_op,
                number_of_steps=FLAGS.max_number_of_steps,
                log_every_n_steps=FLAGS.log_every_n_steps,
                save_summaries_secs=FLAGS.save_summaries_secs,      # save summaries in seconds
                saver=saver,
                save_interval_secs=FLAGS.save_interval_secs,        # save checkpoints in seconds
                session_config=config,
                sync_optimizer=None,
                train_step_fn=train_step_fn
            )