def get_model(images, fids, pids):
    global endpoints, body_prefix
    global dists, neg_dists, pos_dists
    global losses, train_top1, prec_at_k, num_active
    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints,
                              args.embedding_dim,
                              is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'],
                       endpoints['emb'],
                       metric=args.metric)
    dists_origin = loss.cdist(endpoints['emb_joint'],
                              endpoints['emb_joint'],
                              metric=args.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists,
                   dists_origin,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(
        tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses, keep_dims=True)

    return loss_mean
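
# A minimal, hypothetical usage sketch for `get_model` (not part of the
# original code): it assumes `model`, `head`, `loss` and `args` are already
# defined at module level and that the inputs come from the tf.data pipeline.
#
#   images, fids, pids = dataset.make_one_shot_iterator().get_next()
#   loss_mean = get_model(images, fids, pids)   # mean batch-hard triplet loss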
Example #2
def calculate_distances(self, Q, G):
    metric = 'euclidean'
    batch_embs = tf.data.Dataset.from_tensor_slices(
        (Q)).batch(self.batch_size).make_one_shot_iterator().get_next()
    batch_distances = loss.cdist(batch_embs, G, metric=metric)
    distances = np.zeros((len(Q), len(G)), np.float32)
    with tf.Session() as sess:
        for start_idx in count(step=self.batch_size):
            try:
                dist = sess.run(batch_distances)
                distances[start_idx:start_idx + len(dist)] = dist
            except tf.errors.OutOfRangeError:
                print()  # Done!
                break
    print(distances.shape)
    return distances
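
# A small usage sketch (hypothetical names and shapes, not from the original
# code), assuming this method lives on an object that defines `self.batch_size`:
#
#   Q = np.random.rand(100, 128).astype(np.float32)   # query embeddings
#   G = np.random.rand(500, 128).astype(np.float32)   # gallery embeddings
#   distances = evaluator.calculate_distances(Q, G)   # -> (100, 500) matrix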
Example #3
def main():
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    command = input('Would you like to restore it? (yes/no) ')
                    if command == 'yes':
                        args.__dict__[key] = resumed_value
                        print(
                            'For the argument `{}` we are using the loaded value `{}`.'
                            .format(key, args.__dict__[key]))
                    else:
                        print(
                            'For the argument `{}` we are using the provided value `{}`.'
                            .format(key, args.__dict__[key]))
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))
        os.remove(args_file)
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    ###################################################################################
    # Prepare the training dataset.
    # Load the data from the text file; see common.load_dataset for details.
    pids_train, fids_train = common.load_dataset(args.train_set,
                                                 args.image_root)
    max_fid_len = max(map(len,
                          fids_train))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids_train)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids_train, all_pids=pids_train, batch_k=args.batch_k
    ))  # The dataset now yields (selected_fids, pid) tuples, as returned by
    # 'sample_k_fids_for_pid'.
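    # Illustration (numbers are only an example): each element now carries
    # batch_k fids for a single pid. After un-batching and re-batching below,
    # each training batch holds batch_p * batch_k images (e.g. 18 * 4 = 72),
    # which is the PK structure that the batch-hard triplet loss expects.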

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(
        lambda fid, pid: common.fid_to_image(
            fid,
            pid,
            image_root=args.image_root,
            image_size=pre_crop_size if args.crop_augment else net_input_size),
        num_parallel_calls=args.loading_threads
    )  # The dataset now yields (image, fid, pid) tuples, as returned by
    # 'fid_to_image'.

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images_train, fids_train, pids_train = dataset.make_one_shot_iterator(
    ).get_next()
    ########################################################################################################################
    # Prepare the validation set.
    pids_val, fids_val = common.load_dataset(args.validation_set,
                                             args.validation_image_root)
    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids_val = np.unique(pids_val)
    dataset_val = tf.data.Dataset.from_tensor_slices(unique_pids_val)
    dataset_val = dataset_val.shuffle(len(unique_pids_val))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset_val = dataset_val.take(
        (len(unique_pids_val) // args.batch_p) * args.batch_p)
    dataset_val = dataset_val.repeat(
        None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset_val = dataset_val.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids_val, all_pids=pids_val, batch_k=args.batch_k
    ))  # The dataset now yields (selected_fids, pid) tuples, as returned by
    # 'sample_k_fids_for_pid'.

    # Ungroup/flatten the batches for easy loading of the files.
    dataset_val = dataset_val.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset_val = dataset_val.map(
        lambda fid, pid: common.fid_to_image(
            fid,
            pid,
            image_root=args.validation_image_root,
            image_size=pre_crop_size if args.crop_augment else net_input_size),
        num_parallel_calls=args.loading_threads
    )  # The dataset now yields (image, fid, pid) tuples, as returned by
    # 'fid_to_image'.

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset_val = dataset_val.map(lambda im, fid, pid: (
            tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset_val = dataset_val.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    dataset_val = dataset_val.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset_val = dataset_val.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images_val, fids_val, pids_val = dataset_val.make_one_shot_iterator(
    ).get_next()
    ####################################################################################################################
    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    input_images = tf.placeholder(
        dtype=tf.float32,
        shape=[None, args.net_input_height, args.net_input_width, 3],
        name='input')
    pids = tf.placeholder(dtype=tf.string, shape=[
        None,
    ], name='pids')
    fids = tf.placeholder(dtype=tf.string, shape=[
        None,
    ], name='fids')

    endpoints, body_prefix = model.endpoints(input_images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    # dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    # losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
    #     dists, pids, args.margin, batch_precision_at_k=args.batch_k-1)
    # # '_' is the boolean matrix indicating where, within the top K, the
    # # correct identity matches occur; shape=(batch_size, K).


    # Modified
    # loss 1
    dists1 = loss.cdist(endpoints['feature1'],
                        endpoints['feature1'],
                        metric=args.metric)
    losses1, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists1, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists2 = loss.cdist(endpoints['feature2'],
                        endpoints['feature2'],
                        metric=args.metric)
    losses2, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists2, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists3 = loss.cdist(endpoints['feature3'],
                        endpoints['feature3'],
                        metric=args.metric)
    losses3, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists3, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists4 = loss.cdist(endpoints['feature4'],
                        endpoints['feature4'],
                        metric=args.metric)
    losses4, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists4, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists_fu = loss.cdist(endpoints['fusion_layer'],
                          endpoints['fusion_layer'],
                          metric=args.metric)
    losses_fu, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists_fu,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)

    losses = losses1 + losses2 + losses3 + losses4 + losses_fu
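
    # In this modified variant, each intermediate feature branch ('feature1'
    # through 'feature4') and the fused embedding get their own batch-hard
    # triplet loss over the same pids, and the per-branch losses are summed,
    # so every branch is supervised directly.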

    # Modified
    # loss
    # losses_fu, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
    #     endpoints,pids, model_type=args.model_name, metric=args.metric, batch_precision_at_k=args.batch_k - 1
    # )

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))

    # Here `losses` is the hinge value, i.e. how much the positive-pair
    # distance plus the margin exceeds the negative-pair distance.
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    #tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(
        0, name='global_step',
        trainable=False)  # 'global_step' counts the number of batches seen
    # by the graph.
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration
                       ),  # decay every 'lr_decay_steps' after the
            # 'decay_start_iteration'
            # args.train_iterations - args.decay_start_iteration, args.weight_decay_factor)
            args.lr_decay_steps,
            args.lr_decay_factor,
            staircase=True)
    else:
        learning_rate = args.learning_rate  # the case when we set 'decay_start_iteration' as -1
    tf.summary.scalar('learning_rate', learning_rate)
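    # With staircase=True the exponential_decay schedule above follows
    #   lr(step) = learning_rate
    #              * lr_decay_factor ** floor(max(0, step - decay_start_iteration)
    #                                         / lr_decay_steps)
    # i.e. the rate stays constant until `decay_start_iteration` and then drops
    # by a factor of `lr_decay_factor` every `lr_decay_steps` iterations.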
    optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=1e-3)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session(config=config) as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            if args.checkpoint is None:
                last_checkpoint = tf.train.latest_checkpoint(
                    args.experiment_root)
                log.info(
                    'Restoring from checkpoint: {}'.format(last_checkpoint))
                checkpoint_saver.restore(sess, last_checkpoint)
            else:
                ckpt_path = os.path.join(args.experiment_root, args.checkpoint)
                log.info('Restoring from checkpoint: {}'.format(
                    args.checkpoint))
                checkpoint_saver.restore(sess, ckpt_path)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(
                    sess, args.initial_checkpoint
                )  # restore the pre-trained parameter from online model

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproduceable experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                # Fetch one batch in a single run so images, pids and fids
                # stay aligned, then feed it through the placeholders.
                b_images, b_pids, b_fids_in = sess.run(
                    [images_train, pids_train, fids_train])
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids],
                             feed_dict={input_images: b_images,
                                        pids: b_pids,
                                        fids: b_fids_in})
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                        step, float(np.min(b_loss)), float(np.mean(b_loss)),
                        float(np.max(b_loss)), args.batch_k - 1,
                        float(b_prec_at_k),
                        timedelta(seconds=int(seconds_todo)), elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                #get validation results
                if (args.validation_frequency > 0
                        and step % args.validation_frequency == 0):
                    b_images_v, b_pids_v, b_fids_v = sess.run(
                        [images_val, pids_val, fids_val])
                    b_prec_at_k_val, b_loss, b_fids = \
                        sess.run([prec_at_k, losses, fids],
                                 feed_dict={input_images: b_images_v,
                                            pids: b_pids_v,
                                            fids: b_fids_v})
                    log.info(
                        'Validation @:{:6d} iteration, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                        'batch-p@{}: {:.2%}'.format(step,
                                                    float(np.min(b_loss)),
                                                    float(np.mean(b_loss)),
                                                    float(np.max(b_loss)),
                                                    args.batch_k - 1,
                                                    float(b_prec_at_k_val)))
                    sys.stdout.flush()
                    sys.stderr.flush()
                    summary3 = tf.Summary()
                    summary3.value.add(tag='validation_loss',
                                       simple_value=float(np.mean(b_loss)))
                    summary_writer.add_summary(summary3, step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
Example #4
def main():
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)
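
    # The layers below (added on top of the original triplet-loss setup) build
    # several small convolutional decoders: each maps the embedding back to an
    # image via a dense layer, nearest-neighbour up-sampling and conv layers,
    # and contributes a mean-squared reconstruction error (recLoss, recLoss1,
    # recLoss2, recLossl, each scaled by 0.01) against the input images at
    # different resolutions. These terms act as an auto-encoder-style
    # regulariser on the embedding and are added to the triplet loss in the
    # training op further down.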

    decDense = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDense')  #  ,activation = tf.nn.relu  ################
    unflat = tf.reshape(decDense, shape=[tf.shape(decDense)[0], 32, 16, 10])
    unp3shape = tf.TensorShape(
        [2 * di for di in unflat.get_shape().as_list()[1:-1]])
    unPool3 = tf.image.resize_nearest_neighbor(unflat,
                                               unp3shape,
                                               name='unpool3')

    deConv3 = tf.layers.conv2d(inputs=unPool3,
                               filters=64,
                               kernel_size=[5, 5],
                               strides=(1, 1),
                               padding='same',
                               activation=tf.nn.relu,
                               name='deConv3')
    unp2shape = tf.TensorShape(
        [2 * di for di in deConv3.get_shape().as_list()[1:-1]])
    unPool2 = tf.image.resize_nearest_neighbor(deConv3,
                                               unp2shape,
                                               name='unpool2')

    deConv2 = tf.layers.conv2d(inputs=unPool2,
                               filters=32,
                               kernel_size=[5, 5],
                               strides=(1, 1),
                               padding='same',
                               activation=tf.nn.relu,
                               name='deConv2')
    unp1shape = tf.TensorShape(
        [2 * di for di in deConv2.get_shape().as_list()[1:-1]])
    unPool1 = tf.image.resize_nearest_neighbor(deConv2,
                                               unp1shape,
                                               name='unpool1')

    deConv1 = tf.layers.conv2d(inputs=unPool1,
                               filters=3,
                               kernel_size=[5, 5],
                               strides=(1, 1),
                               padding='same',
                               activation=None,
                               name='deConv1')
    imClip = deConv1  #tf.clip_by_value(t = deConv1,clip_value_min = -1.0,clip_value_max = 1.0,name='clipRelu')
    print('Reconstructed image: ', imClip.name)

    recLoss = tf.multiply(
        0.01, tf.losses.mean_squared_error(
            labels=images,
            predictions=imClip,
        ))
    print('recLoss :  ', recLoss.name)

    decDense1 = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDense1')  #  ,activation = tf.nn.relu  ################
    unflat1 = tf.reshape(decDense1, shape=[tf.shape(decDense1)[0], 32, 16, 10])
    unp3shape1 = tf.TensorShape(
        [2 * di for di in unflat1.get_shape().as_list()[1:-1]])
    unPool3_new = tf.image.resize_nearest_neighbor(unflat1,
                                                   unp3shape1,
                                                   name='unpool3_new')

    deConv3_new = tf.layers.conv2d(inputs=unPool3_new,
                                   filters=64,
                                   kernel_size=[5, 5],
                                   strides=(1, 1),
                                   padding='same',
                                   activation=tf.nn.relu,
                                   name='deConv3_new')
    unp2shape_new = tf.TensorShape(
        [2 * di for di in deConv3_new.get_shape().as_list()[1:-1]])
    unPool2_new = tf.image.resize_nearest_neighbor(deConv3_new,
                                                   unp2shape_new,
                                                   name='unpool2_new')

    deConv2_new = tf.layers.conv2d(inputs=unPool2_new,
                                   filters=3,
                                   kernel_size=[5, 5],
                                   strides=(1, 1),
                                   padding='same',
                                   activation=tf.nn.relu,
                                   name='deConv2_new')
    unp1shape_new = tf.TensorShape(
        [2 * di for di in deConv2_new.get_shape().as_list()[1:-1]])
    unPool1_new = tf.image.resize_nearest_neighbor(deConv2_new,
                                                   unp1shape_new,
                                                   name='unpool1_new')

    deConv1_new = tf.layers.conv2d(inputs=unPool1_new,
                                   filters=3,
                                   kernel_size=[5, 5],
                                   strides=(1, 1),
                                   padding='same',
                                   activation=None,
                                   name='deConv1_new')
    imClip1 = deConv2_new
    print('Reconstructed image: ', imClip1.name)
    print(imClip1.shape)

    images2 = tf.image.resize_images(images, [128, 64])
    print(images2.shape)

    recLoss1 = tf.multiply(
        0.01,
        tf.losses.mean_squared_error(
            labels=images2,
            predictions=imClip1,
        ))
    print('recLoss_new :  ', recLoss1.name)

    decDense2 = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDense2')  #  ,activation = tf.nn.relu  ################
    unflat12 = tf.reshape(decDense2,
                          shape=[tf.shape(decDense2)[0], 32, 16, 10])
    unp3shape12 = tf.TensorShape(
        [2 * di for di in unflat12.get_shape().as_list()[1:-1]])
    unPool3_new2 = tf.image.resize_nearest_neighbor(unflat12,
                                                    unp3shape12,
                                                    name='unpool3_new2')

    deConv3_new2 = tf.layers.conv2d(inputs=unPool3_new2,
                                    filters=3,
                                    kernel_size=[5, 5],
                                    strides=(1, 1),
                                    padding='same',
                                    activation=tf.nn.relu,
                                    name='deConv3_new2')
    unp2shape_new2 = tf.TensorShape(
        [2 * di for di in deConv3_new2.get_shape().as_list()[1:-1]])
    unPool2_new2 = tf.image.resize_nearest_neighbor(deConv3_new2,
                                                    unp2shape_new2,
                                                    name='unpool2_new2')

    imClip11 = deConv3_new2
    images21 = tf.image.resize_images(images, [64, 32])

    recLoss2 = tf.multiply(
        0.01,
        tf.losses.mean_squared_error(
            labels=images21,
            predictions=imClip11,
        ))
    print('recLoss_new :  ', recLoss2.name)

    decDensel = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDensel')  #  ,activation = tf.nn.relu  ################
    unflatl = tf.reshape(decDensel, shape=[tf.shape(decDensel)[0], 32, 16, 10])
    unp3shapel = tf.TensorShape(
        [2 * di for di in unflatl.get_shape().as_list()[1:-1]])
    unPool3l = tf.image.resize_nearest_neighbor(unflatl,
                                                unp3shapel,
                                                name='unpool3l')

    deConv3l = tf.layers.conv2d(inputs=unPool3l,
                                filters=64,
                                kernel_size=[5, 5],
                                strides=(1, 1),
                                padding='same',
                                activation=tf.nn.relu,
                                name='deConv3l')
    unp2shapel = tf.TensorShape(
        [2 * di for di in deConv3l.get_shape().as_list()[1:-1]])
    unPool2l = tf.image.resize_nearest_neighbor(deConv3l,
                                                unp2shapel,
                                                name='unpool2l')

    deConv2l = tf.layers.conv2d(inputs=unPool2l,
                                filters=32,
                                kernel_size=[5, 5],
                                strides=(1, 1),
                                padding='same',
                                activation=tf.nn.relu,
                                name='deConv2l')
    unp1shapel = tf.TensorShape(
        [2 * di for di in deConv2l.get_shape().as_list()[1:-1]])
    unPool1l = tf.image.resize_nearest_neighbor(deConv2l,
                                                unp1shapel,
                                                name='unpool1l')

    deConv1l = tf.layers.conv2d(inputs=unPool1l,
                                filters=3,
                                kernel_size=[5, 5],
                                strides=(1, 1),
                                padding='same',
                                activation=None,
                                name='deConv1l')
    imClipl = deConv1l  #tf.clip_by_value(t = deConv1,clip_value_min = -1.0,clip_value_max = 1.0,name='clipRelu')
    print('Reconstructed image: ', imClipl.name)

    recLossl = tf.multiply(
        0.01, tf.losses.mean_squared_error(
            labels=images,
            predictions=imClipl,
        ))
    print('recLoss :  ', recLossl.name)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            args.train_iterations - args.decay_start_iteration, 0.001)
    else:
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(tf.add(
            loss_mean,
            tf.add(recLoss, tf.add(recLoss1, tf.add(recLoss2, recLossl)))),
                                      global_step=global_step)
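
    # The objective minimised here is therefore
    #   loss_mean + recLoss + recLoss1 + recLoss2 + recLossl,
    # i.e. the mean batch-hard triplet loss plus all reconstruction terms.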

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproduceable experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids, b_rec, b_rec1 = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids,
                              recLoss, recLoss1])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'recLoss: {:.3f} batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.
                    format(step, float(np.min(b_loss)), float(np.mean(b_loss)),
                           float(np.max(b_loss)), b_rec, args.batch_k - 1,
                           float(b_prec_at_k),
                           timedelta(seconds=int(seconds_todo)), elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
Example #5
def evaluate_embs(args, query_pids, query_fids, query_embs, gallery_pids,
                  gallery_fids, gallery_embs):
    # Just a quick sanity check that both have the same embedding dimension!
    query_dim = query_embs.shape[1]
    gallery_dim = gallery_embs.shape[1]
    if query_dim != gallery_dim:
        raise ValueError('Shape mismatch between query ({}) and gallery ({}) '
                         'dimension'.format(query_dim, gallery_dim))

    # Setup the dataset specific matching function
    if 'vehicleid' in args.dataset:
        gallery_pids = np.asarray(gallery_pids)
        gallery_fids = np.asarray(gallery_fids)
    else:
        excluder = import_module('excluders.' +
                                 args.excluder).Excluder(gallery_fids)

    # We go through the queries in batches, but we always need the whole gallery
    batch_pids, batch_fids, batch_embs = tf.data.Dataset.from_tensor_slices(
        (query_pids, query_fids, query_embs)).batch(
            args.batch_size).make_one_shot_iterator().get_next()

    batch_distances = loss.cdist(batch_embs, gallery_embs, metric=args.metric)
    # Loop over the query embeddings and compute their APs and the CMC curve.
    aps = []
    correct_rank = []
    results_images = []
    scores_of_queries = []
    pid_matches_all = np.zeros(shape=(1, len(gallery_fids)))
    num_of_NaN = 0
    cmc = np.zeros(len(gallery_pids), dtype=np.int32)
    num_of_paired_img = len(query_pids)
    if args.output_name is not None:
        text_file = open(
            os.path.join(args.experiment_root,
                         "accuracy_summary_" + args.output_name + ".txt"), "w")
    else:
        text_file = open(
            os.path.join(args.experiment_root, "accuracy_summary.txt"), "w")
    with tf.Session() as sess:
        for start_idx in count(step=args.batch_size):
            try:
                # Compute distance to all gallery embeddings for the batch of queries
                distances, pids, fids = sess.run(
                    [batch_distances, batch_pids, batch_fids])
                print('\rEvaluating batch {}-{}/{}'.format(
                    start_idx, start_idx + len(fids), len(query_fids)),
                      flush=True,
                      end='')
            except tf.errors.OutOfRangeError:
                print()  # Done!
                break

            # Convert the array of objects back to array of strings
            pids, fids = np.array(pids, '|U'), np.array(fids, '|U')

            # Compute the pid matches
            pid_matches = gallery_pids[None] == pids[:, None]

            # Get a mask indicating True for those gallery entries that should
            # be ignored for whatever reason (same camera, junk, ...) and
            # exclude those in a way that doesn't affect CMC and mAP.
            if 'vehicleid' not in args.dataset:
                mask = excluder(fids)
                distances[mask] = np.inf
                pid_matches[mask] = False
            pid_matches_all = np.concatenate((pid_matches_all, pid_matches),
                                             axis=0)

            # Keep track of statistics. Invert distances to scores using any
            # arbitrary inversion, as long as it's monotonic and well-behaved,
            # it won't change anything.
            scores = 1 / (1 + distances)
            num_of_col = 10
            for i in range(len(distances)):
                ap = average_precision_score(pid_matches[i], scores[i])
                sorted_distances_inds = np.argsort(distances[i])

                if np.isnan(ap):
                    print()
                    print(
                        str(num_of_NaN) +
                        ". WARNING: encountered an AP of NaN!")
                    print("This usually means a person only appears once.")
                    print("In this case, it's because of {}.".format(fids[i]))
                    print(
                        "I'm excluding this person from eval and carrying on.")
                    print()
                    text = (
                        str(num_of_NaN) +
                        ". WARNING: encountered an AP of NaN! Probably a person only appears once - {}\n"
                        .format(fids[i]))
                    text_file.write(text)
                    correct_rank.append(-1)
                    results_images.append(
                        gallery_fids[sorted_distances_inds[0:num_of_col]])
                    num_of_NaN += 1
                    num_of_paired_img -= 1
                    scores_of_queries.append(-1)
                    continue

                aps.append(ap)
                scores_of_queries.append(ap)
                # Find the first true match and increment the cmc data from there on.
                rank_k = np.where(pid_matches[i, sorted_distances_inds])[0][0]
                cmc[rank_k:] += 1
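                # CMC at rank r counts the fraction of queries whose first
                # correct match appears at rank <= r. Example: if the first
                # true match for a query sits at sorted position 2
                # (rank_k == 2), that query contributes to cmc[2], cmc[3], ...
                # but not to cmc[0] or cmc[1].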
                # Save five more similar images to each of image and correct rank of each image
                if (len(gallery_fids) < num_of_col):
                    num_of_col = len(gallery_fids)
                correct_rank.append(rank_k)
                results_images.append(
                    gallery_fids[sorted_distances_inds[0:num_of_col]])

    # Compute the actual cmc and mAP values
    cmc = cmc / num_of_paired_img
    mean_ap = np.mean(aps)

    # Save important data
    saveResults(args, results_images,
                np.argsort(scores_of_queries)[::-1], query_fids, 10)
    if args.output_name is not None:
        out_file = open(
            os.path.join(args.experiment_root,
                         "evaluation_" + args.output_name + ".json"), "w")
        json.dump({
            'mAP': mean_ap,
            'CMC': list(cmc),
            'aps': list(aps)
        }, out_file)
        out_file.close()
    else:
        out_file = open(os.path.join(args.experiment_root, "evaluation.json"),
                        "w")
        json.dump({
            'mAP': mean_ap,
            'CMC': list(cmc),
            'aps': list(aps)
        }, out_file)
        out_file.close()

    # Print out a short summary and save summary accuracy.
    if len(cmc) > 9:
        print(
            'mAP: {:.2%} | top-1: {:.2%} top-2: {:.2%} | top-5: {:.2%} | top-10: {:.2%}'
            .format(mean_ap, cmc[0], cmc[1], cmc[4], cmc[9]))
        text_file.write(
            'mAP: {:.2%} | top-1: {:.2%} top-2: {:.2%} | top-5: {:.2%} | top-10: {:.2%}'
            .format(mean_ap, cmc[0], cmc[1], cmc[4], cmc[9]))
    elif len(cmc) > 5:
        print(
            'mAP: {:.2%} | top-1: {:.2%} top-2: {:.2%} | top-5: {:.2%}'.format(
                mean_ap, cmc[0], cmc[1], cmc[4]))
        text_file.write(
            'mAP: {:.2%} | top-1: {:.2%} top-2: {:.2%} | top-5: {:.2%}'.format(
                mean_ap, cmc[0], cmc[1], cmc[4]))
    elif len(cmc) > 2:
        print('mAP: {:.2%} | top-1: {:.2%} top-2: {:.2%}'.format(
            mean_ap, cmc[0], cmc[1]))
        text_file.write('mAP: {:.2%} | top-1: {:.2%} top-2: {:.2%}'.format(
            mean_ap, cmc[0], cmc[1]))
    else:
        print('mAP: {:.2%} | top-1: {:.2%}'.format(mean_ap, cmc[0]))
        text_file.write('mAP: {:.2%} | top-1: {:.2%}'.format(mean_ap, cmc[0]))
    text_file.close()

    return [mean_ap, cmc[0]]
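
# A hypothetical call sketch for `evaluate_embs` (all names below are made up
# for illustration): pids/fids are 1-D arrays of strings and the embeddings are
# float32 arrays whose first dimensions match the corresponding pids/fids.
#
#   mean_ap, top1 = evaluate_embs(args,
#                                 query_pids, query_fids, query_embs,
#                                 gallery_pids, gallery_fids, gallery_embs)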
Example #6
def main():
    # my_devices = tf.config.experimental.list_physical_devices(device_type='CPU')
    # tf.config.experimental.set_visible_devices(devices= my_devices, device_type='CPU')

    # # To find out which devices your operations and tensors are assigned to
    # tf.debugging.set_log_device_placement(True)
    args = parser.parse_args(args=[])

    show_all_parameters(args)

    if not args.train_set:
        parser.print_help()
        print("You didn't specify the 'train_set' argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        print("You didn't specify the 'image_root' argument!")
        sys.exit(1)

    pids, fids = common.load_dataset(args.train_set, args.image_root)

    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Take the dataset size equal to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat indefinitely.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches
    dataset = dataset.unbatch()

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size))

    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.image.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group the data into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    dataset = dataset.prefetch(1)
    dataiter = iter(dataset)

    model = Trinet(args.embedding_dim)
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        args.learning_rate, args.train_iterations - args.decay_start_iteration,
        0.001)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    writer = tf.summary.create_file_writer(args.experiment_root)
    ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                               optimizer=optimizer,
                               net=model)
    manager = tf.train.CheckpointManager(ckpt,
                                         args.experiment_root,
                                         max_to_keep=10)

    if args.resume:
        ckpt.restore(manager.latest_checkpoint)

    for epoch in range(args.train_iterations):

        # for images,fids,pids in dataset:
        images, fids, pids = next(dataiter)
        with tf.GradientTape() as tape:
            emb = model(images)
            dists = loss.cdist(emb, emb)
            losses, top1, prec, topksame, negdist, posdist = loss.batch_hard(
                dists, pids, args.margin, args.batch_k)
            lossavg = tf.reduce_mean(losses)
            lossnp = losses.numpy()
        with writer.as_default():
            tf.summary.scalar("loss", lossavg, step=epoch)
            tf.summary.scalar('batch_top1', top1, step=epoch)
            tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1),
                              prec,
                              step=epoch)
            tf.summary.histogram('losses', losses, step=epoch)
            tf.summary.histogram('embedding_dists', dists, step=epoch)
            tf.summary.histogram('embedding_pos_dists', posdist, step=epoch)
            tf.summary.histogram('embedding_neg_dists', negdist, step=epoch)

        print('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
              ' batch-p@{}: {:.2%}'.format(epoch, float(np.min(lossnp)),
                                           float(np.mean(lossnp)),
                                           float(np.max(lossnp)),
                                           args.batch_k - 1, float(prec)))
        grad = tape.gradient(lossavg, model.trainable_variables)
        optimizer.apply_gradients(zip(grad, model.trainable_variables))
        ckpt.step.assign_add(1)
        if epoch % args.checkpoint_frequency == 0:
            manager.save()
예제 #7
def main():
    # Verify that parameters are set correctly.
    args = parser.parse_args()

    # Load the query and gallery data from the CSV files.
    # query_pids, query_fids = common.load_dataset_query(args.query_dataset, args.image_root)
    # gallery_pids, gallery_fids = common.load_dataset_test(args.gallery_dataset, args.image_root)

    query_pids, query_fids = common.load_dataset(args.query_dataset, None)
    gallery_pids, gallery_fids = common.load_dataset(args.gallery_dataset,
                                                     None)

    # query_pids, query_fids = embed_partial_reid.load_dataset_partial_reid(args.query_dataset, '/data1/chenyf/PartialREID/partial_body_images')
    # gallery_pids, gallery_fids = embed_partial_reid.load_dataset_partial_reid(args.gallery_dataset, '/data1/chenyf/PartialREID/whole_body_images')

    # Load the two datasets fully into memory.
    with h5py.File(args.query_embeddings, 'r') as f_query:
        query_embs = np.array(f_query['emb'])
    with h5py.File(args.gallery_embeddings, 'r') as f_gallery:
        gallery_embs = np.array(f_gallery['emb'])

    # Just a quick sanity check that both have the same embedding dimension!
    query_dim = query_embs.shape[1]
    gallery_dim = gallery_embs.shape[1]
    if query_dim != gallery_dim:
        raise ValueError('Shape mismatch between query ({}) and gallery ({}) '
                         'dimension'.format(query_dim, gallery_dim))

    # Setup the dataset specific matching function
    excluder = import_module('excluders.' +
                             args.excluder).Excluder(gallery_fids)

    # We go through the queries in batches, but we always need the whole gallery
    batch_pids, batch_fids, batch_embs = tf.data.Dataset.from_tensor_slices(
        (query_pids, query_fids, query_embs)).batch(
            args.batch_size).make_one_shot_iterator().get_next()

    batch_distances = loss.cdist(batch_embs, gallery_embs, metric=args.metric)

    # Loop over the query embeddings and compute their APs and the CMC curve.
    aps = []
    cmc = np.zeros(len(gallery_pids), dtype=np.int32)
    with tf.Session() as sess:
        for start_idx in count(step=args.batch_size):
            try:
                # Compute distance to all gallery embeddings
                distances, pids, fids = sess.run(
                    [batch_distances, batch_pids, batch_fids])
                print('\rEvaluating batch {}-{}/{}'.format(
                    start_idx, start_idx + len(fids), len(query_fids)),
                      flush=True,
                      end='')
            except tf.errors.OutOfRangeError:
                print()  # Done!
                break

            # Convert the array of objects back to array of strings
            pids, fids = np.array(pids, '|U'), np.array(fids, '|U')

            # Compute the pid matches
            pid_matches = gallery_pids[None] == pids[:, None]

            # Get a mask indicating True for those gallery entries that should
            # be ignored for whatever reason (same camera, junk, ...) and
            # exclude those in a way that doesn't affect CMC and mAP.
            mask = excluder(fids)
            distances[mask] = np.inf
            pid_matches[mask] = False

            # Keep track of statistics. Invert distances to scores using any
            # arbitrary inversion, as long as it's monotonic and well-behaved,
            # it won't change anything.
            scores = 1 / (1 + distances)
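            # Any other monotonically decreasing mapping (e.g. scores = -distances)
            # would give identical AP and CMC results, since both only depend on
            # the ranking of the scores.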
            for i in range(len(distances)):
                ap = average_precision_score(pid_matches[i], scores[i])

                if np.isnan(ap):
                    print()
                    print("WARNING: encountered an AP of NaN!")
                    print("This usually means a person only appears once.")
                    print("In this case, it's because of {}.".format(fids[i]))
                    print(
                        "I'm excluding this person from eval and carrying on.")
                    print()
                    continue

                aps.append(ap)
                # Find the first true match and increment the cmc data from there on.
                k = np.where(pid_matches[i, np.argsort(distances[i])])[0][0]
                cmc[k:] += 1

    # Compute the actual cmc and mAP values
    cmc = cmc / len(query_pids)
    mean_ap = np.mean(aps)

    # Save important data
    if args.filename is not None:
        json.dump({
            'mAP': mean_ap,
            'CMC': list(cmc),
            'aps': list(aps)
        }, args.filename)

    # Print out a short summary.
    print(
        'mAP: {:.2%} | top-1: {:.2%} top-2: {:.2%} | top-3: {:.2%} | top-10: {:.2%}'
        .format(mean_ap, cmc[0], cmc[1], cmc[2], cmc[9]))
예제 #8
def main():
    # Verify that parameters are set correctly.
    args = parser.parse_args()

    # Load the query and gallery data from the CSV files.
    query_pids, query_fids = common.load_dataset(args.query_dataset, None)
    gallery_pids, gallery_fids = common.load_dataset(args.gallery_dataset, None)

    # Load the two datasets fully into memory.
    with h5py.File(args.query_embeddings, 'r') as f_query:
        query_embs = np.array(f_query['emb'])
    with h5py.File(args.gallery_embeddings, 'r') as f_gallery:
        gallery_embs = np.array(f_gallery['emb'])

    print('gallery shape: {}'.format(gallery_embs.shape))
    # Just a quick sanity check that both have the same embedding dimension!
    query_dim = query_embs.shape[1]
    gallery_dim = gallery_embs.shape[1]
    gallery_amount = gallery_embs.shape[0]
    if query_dim != gallery_dim:
        raise ValueError('Shape mismatch between query ({}) and gallery ({}) '
                         'dimension'.format(query_dim, gallery_dim))

    # Setup the dataset specific matching function
    excluder = import_module('excluders.' + args.excluder).Excluder(gallery_fids)

    # We go through the queries in batches, but we always need the whole gallery
    batch_pids, batch_fids, batch_embs = tf.data.Dataset.from_tensor_slices(
        (query_pids, query_fids, query_embs)
    ).batch(args.batch_size).make_one_shot_iterator().get_next()

    gallery_iter = tf.data.Dataset.from_tensor_slices(gallery_embs).batch(gallery_amount).make_initializable_iterator()
    gallery_embs = gallery_iter.get_next()

    batch_distances = loss.cdist(batch_embs, gallery_embs, metric=args.metric)
    print('batch distance: {}'.format(batch_distances))

    # Loop over the query embeddings and compute their APs and the CMC curve.
    aps = []
    cmc = np.zeros(len(gallery_pids), dtype=np.int32)
    with tf.Session(config=config) as sess:
        for start_idx in count(step=args.batch_size):
            try:
                sess.run(gallery_iter.initializer)
                # Compute distance to all gallery embeddings
                distances, pids, fids = sess.run([
                    batch_distances, batch_pids, batch_fids])
                if args.display:
                    _att_maps = sess.run([attention_map])
                print('\rEvaluating batch {}-{}/{}'.format(
                        start_idx, start_idx + len(fids), len(query_fids)),
                      flush=True, end='')
            except tf.errors.OutOfRangeError:
                print()  # Done!
                break

            # Convert the array of objects back to array of strings
            pids, fids = np.array(pids, '|U'), np.array(fids, '|U')

            # Compute the pid matches
            pid_matches = gallery_pids[None] == pids[:,None]
            # print('pid match matrix shape: {}'.format(pid_matches.shape))
            # print('pid_matchs[0]: {}'.format(pid_matches[0]))
            # print('fids: {}'.format(fids))

            # Get a mask indicating True for those gallery entries that should
            # be ignored for whatever reason (same camera, junk, ...) and
            # exclude those in a way that doesn't affect CMC and mAP.
            mask = excluder(fids)
            distances[mask] = np.inf
            pid_matches[mask] = False
            # print('distance matrix: {}'.format(distances.shape))

            # Keep track of statistics. Invert distances to scores using any
            # arbitrary inversion, as long as it's monotonic and well-behaved,
            # it won't change anything.
            scores = 1 / (1 + distances)
            print('variables: {}'.format(tf.GraphKeys.VARIABLES))
            if args.display:
                # `_att_maps` only exists when args.display was set above.
                print('attention map')
                print(_att_maps)
                print(np.array(_att_maps).shape)
            for i in range(len(distances)):
                # for each query instance
                ap = average_precision_score(pid_matches[i], scores[i])

                if np.isnan(ap):
                    print("\nWARNING: encountered an AP of NaN!")
                    print("This usually means a person only appears once.")
                    print("In this case, it's because of {}.".format(fids[i]))
                    print("I'm excluding this person from eval and carrying on.\n")
                    continue

                aps.append(ap)
                # Find the first true match and increment the cmc data from there on.
                # print('match shape: {}'.format(pid_matches[i, np.argsort(distances[i])].shape))
                k = np.where(pid_matches[i, np.argsort(distances[i])])[0][0]
                # if args.display:
                if False:
                    if not os.path.exists('./mismatched-5'):
                        os.mkdir('mismatched-5')
                    mismatched_idx = np.argsort(distances[i])[:k]
                    if k > 5:
                        print('Mismatch | fid: {}'.format(fids[i]))
                        os.system('cp {} {}'.format(os.path.join('/data2/wangq/VD1/', fids[i]), os.path.join('./mismatched-5', 'batch-{}_query-{}_{}'.format(start_idx / args.batch_size, i, fids[i].split('/')[-1]))))
                        for l in range(k):
                            os.system('cp {} {}'.format(os.path.join('/data2/wangq/VD1/', gallery_fids[mismatched_idx[l]]), os.path.join('./mismatched-5', 'batch-{}_query-{}_gallary-{}_{}'.format(start_idx / args.batch_size, i, l, gallery_fids[mismatched_idx[l]].split('/')[-1]))))
                cmc[k:] += 1

    # Compute the actual cmc and mAP values
    cmc = cmc / len(query_pids)
    mean_ap = np.mean(aps)

    # Save important data
    if args.filename is not None:
        json.dump({'mAP': mean_ap, 'CMC': list(cmc), 'aps': list(aps)}, args.filename)

    # Print out a short summary.
    print('mAP: {:.2%} | top-1: {:.2%} top-2: {:.2%} | top-5: {:.2%} | top-10: {:.2%}'.format(
        mean_ap, cmc[0], cmc[1], cmc[4], cmc[9]))
예제 #9
def main():
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(
        len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.

    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(
        None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.

    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            args.train_iterations - args.decay_start_iteration, 0.001)
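        # With decay_rate=0.001, the learning rate is smoothly annealed from
        # `args.learning_rate` down to 0.1% of it by the last training iteration.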
    else:
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproduceable experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                        step, float(np.min(b_loss)), float(np.mean(b_loss)),
                        float(np.max(b_loss)), args.batch_k - 1,
                        float(b_prec_at_k),
                        timedelta(seconds=int(seconds_todo)), elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
예제 #10
def main():
    args = parser.parse_args([])

    query_pids, query_fids = common.load_dataset(args.query_dataset, None)
    gallery_pids, gallery_fids = common.load_dataset(args.gallery_dataset, None)

    # Load the two datasets fully into memory.
    with h5py.File(args.query_embeddings, 'r') as f_query:
        query_embs = np.array(f_query['emb'])
    with h5py.File(args.gallery_embeddings, 'r') as f_gallery:
        gallery_embs = np.array(f_gallery['emb'])


    dataset = tf.data.Dataset.from_tensor_slices((query_pids, query_fids, query_embs))
    dataset=dataset.batch(args.batch_size)


    aps = []
    cmc = np.zeros(len(gallery_pids), dtype=np.int32)
    fnames={}
    uniqpids = []
    start_idx=0
    for pids,fids,embs in dataset:
        try:
            distances = loss.cdist(embs,gallery_embs)

            print('\rEvaluating batch {}-{}/{}'.format(
                    start_idx, start_idx + len(fids), len(query_fids)))
            start_idx+=len(fids)
        except tf.errors.OutOfRangeError:
            print()
            break

        pids, fids = np.array(pids, '|U'), np.array(fids, '|U')

        pid_matches = gallery_pids[None] == pids[:,None]


        scores = 1.0 / (1 + distances)
        for i in range(len(distances)):
            ap = average_precision(pid_matches[i], scores[i])

            if np.isnan(ap):
                continue

            aps.append(ap)
            sorteddist = np.argsort(distances[i])
            k = np.where(pid_matches[i,sorteddist])[0][0]
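            # k is the zero-based rank of the first correct gallery match for
            # this query; `cmc[k:] += 1` further below accumulates the CMC curve.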
            if len(fnames)<5:
                if pids[i] not in uniqpids:
                    uniqpids.append(pids[i])
                    temp=np.where(pid_matches[i,sorteddist])[0]
                    if len(temp)>=3:
                        fnames[fids[i]]=list(np.array(gallery_fids[sorteddist][temp[:3]],'|U'))
            cmc[k:] += 1

    cmc = cmc / float(len(query_pids))
    mean_ap = np.mean(aps)

    print('mAP: {:.2%} | top-1: {:.2%} top-2: {:.2%} | top-5: {:.2%} | top-10: {:.2%}'.format(
        mean_ap, cmc[0], cmc[1], cmc[4], cmc[9]))

    plt.figure(figsize=(20,15))
    i=1
    for key,val in fnames.items():
        plt.subplot(5,4,4*(i-1)+1)
        img =plt.imread(os.path.join(args.image_root,key))
        plt.imshow(img)
        # print(key)
        # print(img == None)
        if i == 1:
            plt.title('query')
        plt.axis('off')
        for j in range(3):
            plt.subplot(5,4,(i-1)*4+j+2)
            plt.imshow(plt.imread(os.path.join(args.image_root,val[j])))
            if i == 1 and j == 1:
                plt.title('matches')
            plt.axis('off')
        i+=1
    plt.show()

    if args.filename is not None:
        with open(args.filename,'w') as f:
            json.dump({'mAP': mean_ap, 'CMC': list(cmc), 'aps': list(aps)}, f)
예제 #11
def main():
    # Verify that parameters are set correctly.
    args = parser.parse_args()

    # Possibly auto-generate the output filename.
    if args.filename is None:
        basename = os.path.basename(args.dataset)
        args.filename = os.path.splitext(basename)[0] + '_embeddings.h5'
    args.filename = os.path.join(args.experiment_root, args.filename)

    # Load the args from the original experiment.
    args_file = os.path.join(args.experiment_root, 'args.json')

    if os.path.isfile(args_file):
        if not args.quiet:
            print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)

        # Add arguments from training.
        for key, value in args_resumed.items():
            args.__dict__.setdefault(key, value)

        # A couple special-cases and sanity checks
        if (args_resumed['crop_augment']) == (args.crop_augment is None):
            print('WARNING: crop augmentation differs between training and '
                  'evaluation.')

        args.image_root = args.image_root or args_resumed['image_root']
    else:
        raise IOError(
            '`args.json` could not be found in: {}'.format(args_file))

    # Check a proper aggregator is provided if augmentation is used.
    if args.flip_augment or args.crop_augment == 'five':
        if args.aggregator is None:
            print(
                'ERROR: Test time augmentation is performed but no aggregator '
                'was specified.')
            exit(1)
    else:
        if args.aggregator is not None:
            print('ERROR: No test time augmentation that needs aggregating is '
                  'performed but an aggregator was specified.')
            exit(1)

    if not args.quiet:
        print('Evaluating using the following parameters:')
        for key, value in sorted(vars(args).items()):
            print('{}: {}'.format(key, value))

    # Load the data from the CSV file.
    _, data_fids = common.load_dataset_test(args.dataset, args.image_root)

    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    # Setup a tf Dataset containing all images.
    dataset = tf.data.Dataset.from_tensor_slices(data_fids)

    # Convert filenames to actual image tensors.
    dataset = dataset.map(lambda fid: common.fid_to_image(
        fid,
        tf.constant('dummy'),
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    # `modifiers` is a list of strings that keeps track of which augmentations
    # have been applied, so that a human can understand it later on.
    modifiers = ['original']
    if args.flip_augment:
        dataset = dataset.map(flip_augment)
        dataset = dataset.apply(tf.contrib.data.unbatch())
        modifiers = [o + m for m in ['', '_flip'] for o in modifiers]

    if args.crop_augment == 'center':
        dataset = dataset.map(lambda im, fid, pid:
                              (five_crops(im, net_input_size)[0], fid, pid))
        modifiers = [o + '_center' for o in modifiers]
    elif args.crop_augment == 'five':
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.stack(five_crops(im, net_input_size)),
                               tf.stack([fid] * 5), tf.stack([pid] * 5)))
        dataset = dataset.apply(tf.contrib.data.unbatch())
        modifiers = [
            o + m for o in modifiers for m in [
                '_center', '_top_left', '_top_right', '_bottom_left',
                '_bottom_right'
            ]
        ]
    elif args.crop_augment == 'avgpool':
        modifiers = [o + '_avgpool' for o in modifiers]
    else:
        modifiers = [o + '_resize' for o in modifiers]
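    # For example, with flip augmentation and the 'five' crop augmentation,
    # `modifiers` ends up with 10 entries (original/flipped x five crops), so
    # each image yields 10 embeddings that are aggregated further below.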

    # Group it back into PK batches.
    dataset = dataset.batch(args.batch_size)

    # Overlap producing and consuming.
    dataset = dataset.prefetch(1)

    images, _, _ = dataset.make_one_shot_iterator().get_next()
    tf.summary.image('image', images, 12)

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    endpoints, body_prefix = model.endpoints(images, is_training=False)

    ############################################################################################################
    # with tf.name_scope('heatmap'):
    #     heatmap_in = endpoints[args.model_name+'/block4']
    #     heatmap_out = heatmap.hmnet(heatmap_in, 18)
    #     heat = tf.image.resize_images(heatmap_out, net_input_size)
    # tf.summary.image('heatmap', heat, 12)
    ############################################################################################################

    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=False)

    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)

    tf.summary.histogram('embedding_dists', dists)

    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(args.experiment_root)

    with h5py.File(args.filename, 'w') as f_out, tf.Session() as sess:
        # Initialize the network/load the checkpoint.
        if args.checkpoint is None:
            checkpoint = tf.train.latest_checkpoint(args.experiment_root)
        else:
            checkpoint = os.path.join(args.experiment_root, args.checkpoint)
        if not args.quiet:
            print('Restoring from checkpoint: {}'.format(checkpoint))
        tf.train.Saver().restore(sess, checkpoint)

        # Go ahead and embed the whole dataset, with all augmented versions too.
        emb_storage = np.zeros(
            (len(data_fids) * len(modifiers), args.embedding_dim), np.float32)
        for start_idx in count(step=args.batch_size):
            try:
                emb, summary = sess.run([endpoints['emb'], merged])
                summary_writer.add_summary(summary)
                print('\rEmbedded batch {}-{}/{}'.format(
                    start_idx, start_idx + len(emb), len(emb_storage)),
                      flush=True,
                      end='')
                emb_storage[start_idx:start_idx + len(emb)] = emb
            except tf.errors.OutOfRangeError:
                break  # This just indicates the end of the dataset.

        print()
        if not args.quiet:
            print("Done with embedding, aggregating augmentations...",
                  flush=True)

        if len(modifiers) > 1:
            # Pull out the augmentations into a separate first dimension.
            emb_storage = emb_storage.reshape(len(data_fids), len(modifiers),
                                              -1)
            emb_storage = emb_storage.transpose((1, 0, 2))  # (Aug,FID,128D)

            # Store the embedding of all individual variants too.
            emb_dataset = f_out.create_dataset('emb_aug', data=emb_storage)

            # Aggregate according to the specified parameter.
            emb_storage = AGGREGATORS[args.aggregator](emb_storage)
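            # (AGGREGATORS is assumed to map names such as 'mean' to functions
            # that collapse the augmentation axis, e.g. averaging over axis 0.)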

        # Store the final embeddings.
        emb_dataset = f_out.create_dataset('emb', data=emb_storage)

        # Store information about the produced augmentation and in case no crop
        # augmentation was used, if the images are resized or avg pooled.
        f_out.create_dataset('augmentation_types',
                             data=np.asarray(modifiers, dtype='|S'))
예제 #12
def train(args, images, fids, pids, max_fid_len, log):
    '''
    Create the model and train the neural network.

    :param args: all stored arguments
    :param images: prepared images for training
    :param fids: file ids (relative paths from image_root to the images)
    :param pids: person ids (or car ids) for all images
    :param max_fid_len: length of the longest fid, used to size the detailed log arrays
    :param log: logger to which training progress is written
    :return: nothing; checkpoints and the train log file are written to args.experiment_root
    '''
    ###################################################################################################################
    # CREATE MODEL
    ###################################################################################################################
    # Create the model and an embedding head.

    model = import_module('nets.resnet_v1_50')
    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    drops = {}
    if args.dropout is not None:
        drops = getDropoutProbs(args.dropout)
    b4_layers = None
    try:
        b4_layers = int(args.b4_layers)
        if b4_layers not in [1, 2, 3]:
            raise ValueError()
    except (TypeError, ValueError):
        raise ValueError("Argument exception: b4_layers has to be in [1, 2, 3]")

    endpoints, body_prefix = model.endpoints(images,
                                             b4_layers,
                                             drops,
                                             is_training=True,
                                             resnet_stride=int(
                                                 args.resnet_stride))
    endpoints['emb'] = endpoints['emb_raw'] = slim.fully_connected(
        endpoints['model_output'],
        args.embedding_dim,
        activation_fn=None,
        weights_initializer=tf.orthogonal_initializer(),
        scope='emb')

    step_pl = tf.placeholder(dtype=tf.float32)

    features = endpoints['emb']

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(features, features, metric=args.metric)
    losses, train_top1, prec_at_k, _, probe_neg_dists, pos_dists, neg_dists = loss.loss_function(
        dists,
        pids, [args.alpha1, args.alpha2, args.alpha3],
        batch_precision_at_k=args.batch_k - 1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.scalar('embedding_pos_dists', tf.reduce_mean(pos_dists))
    tf.summary.scalar('embedding_probe_neg_dists',
                      tf.reduce_mean(probe_neg_dists))
    tf.summary.scalar('embedding_neg_dists', tf.reduce_mean(neg_dists))
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_probe_neg_dists', probe_neg_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    batch_size = args.batch_p * args.batch_k
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)

    if args.sgdr:
        learning_rate = tf.train.cosine_decay_restarts(
            learning_rate=args.learning_rate,
            global_step=global_step,
            first_decay_steps=4000,
            t_mul=1.5)
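        # SGDR-style schedule: cosine decay with warm restarts, a first cycle of
        # 4000 steps, and each subsequent cycle 1.5x longer than the previous one.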
    else:
        if 0 <= args.decay_start_iteration < args.train_iterations:
            learning_rate = tf.train.exponential_decay(
                args.learning_rate,
                tf.maximum(0, global_step - args.decay_start_iteration),
                args.train_iterations - args.decay_start_iteration,
                float(args.lr_decay))
        else:
            learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(tf.convert_to_tensor(learning_rate))

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)
    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproduceable experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        step = start_step
        log.info('Starting training from iteration {}.'.format(start_step))

        ###################################################################################################################
        # TRAINING
        ###################################################################################################################
        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):
                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, features, losses, fids], feed_dict={step_pl: step})
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress. Maybe steal from here.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it), lr={:.4g}'.
                    format(step, float(np.min(b_loss)), float(np.mean(b_loss)),
                           float(np.max(b_loss)), args.batch_k - 1,
                           float(b_prec_at_k),
                           timedelta(seconds=int(seconds_todo)), elapsed_time,
                           sess.run(optimizer._lr)))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
예제 #13
def main():
    # Verify that parameters are set correctly.
    args = parser.parse_args()

    # Load the query and gallery data from the CSV files.
    query_pids, query_fids = common.load_dataset(args.query_dataset, None)
    gallery_pids, gallery_fids = common.load_dataset(args.gallery_dataset,
                                                     None)

    # Load the two datasets fully into memory.
    with h5py.File(args.query_embeddings, 'r') as f_query:
        query_embs = np.array(f_query['emb'])
    with h5py.File(args.gallery_embeddings, 'r') as f_gallery:
        gallery_embs = np.array(f_gallery['emb'])

    # Just a quick sanity check that both have the same embedding dimension!
    query_dim = query_embs.shape[1]
    gallery_dim = gallery_embs.shape[1]
    if query_dim != gallery_dim:
        raise ValueError('Shape mismatch between query ({}) and gallery ({}) '
                         'dimension'.format(query_dim, gallery_dim))

    # Setup the dataset specific matching function
    excluder = import_module('excluders.' +
                             args.excluder).Excluder(gallery_fids)

    # We go through the queries in batches, but we always need the whole gallery
    batch_pids, batch_fids, batch_embs = tf.data.Dataset.from_tensor_slices(
        (query_pids, query_fids, query_embs)).batch(
            args.batch_size).make_one_shot_iterator().get_next()

    batch_distances = loss.cdist(batch_embs, gallery_embs, metric=args.metric)

    results = []

    # Loop over the query embeddings.
    with tf.Session() as sess:
        for start_idx in count(step=args.batch_size):
            try:
                # Compute distance to all gallery embeddings
                distances, pids, fids = sess.run(
                    [batch_distances, batch_pids, batch_fids])
                print('\rEvaluating batch {}-{}/{}'.format(
                    start_idx, start_idx + len(fids), len(query_fids)),
                      flush=True,
                      end='')
                print()
            except tf.errors.OutOfRangeError:
                print()  # Done!
                break

            # Convert the array of objects back to array of strings
            pids, fids = np.array(pids, '|U'), np.array(fids, '|U')

            # Compute the pid matches
            pid_matches = gallery_pids[None] == pids[:, None]

            # Get a mask indicating True for those gallery entries that should
            # be ignored for whatever reason (same camera, junk, ...) and
            # exclude those in a way that doesn't affect CMC and mAP.
            mask = excluder(fids)
            distances[mask] = np.inf
            pid_matches[mask] = False

            # For each query, simply print the filename of its nearest gallery
            # entry (smallest remaining distance).
            for i in range(len(distances)):
                print(gallery_fids[np.argsort(distances[i])[0]])
예제 #14
    def exe_query(self, query_extention = True):
        aps = []
        ques = []
        cmc = np.zeros(len(self.gallery_pids), dtype=np.int32)
        gallery_views_id, gallery_views_count = np.unique(self.gallery_views, return_counts=True)
        
        metric = 'euclidean'
        print(self.gallery_embs.shape)
        print(self.query_embs.shape)
        batch_pids, batch_fids, batch_embs = tf.data.Dataset.from_tensor_slices(
        (self.query_pids, self.query_fids, self.query_embs)).batch(self.batch_size).make_one_shot_iterator().get_next()
        batch_distances = loss.cdist(batch_embs, self.gallery_embs, metric=metric)
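        # batch_distances has shape (batch_size, num_gallery): one row of
        # gallery distances per query embedding in the current batch.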
        
        self.submission_file = "track2.txt"
        print("Total queries: ", len(self.query_fids))
        print("Results folder: ", self.result_folder)
        print("Submission file: ", self.submission_file)
        
        
        dist_h5_file = "results_dists/test798x798.h5"
        tracklet_dists = load_h5(dist_h5_file)
        tracklet_mapper = load_h5(dist_h5_file, "mapper")
        trklet_dict = {}
        
        for i, trid in enumerate(tracklet_mapper):
            trklet_dict[i] = i
        
        with tf.Session() as sess, open(self.submission_file, "w") as f_sub:
            for start_idx in count(step=self.batch_size):
                try:
                    if (query_extention):
                        top1_view = tf_get_top1_view(batch_distances, self.gallery_views)
                        que_ext_re_ranking = tf_query_extention(top1_view, self.gallery_views, self.gallery_embs)
                        top1_views, distances, pids, fids = sess.run([top1_view, que_ext_re_ranking, batch_pids, batch_fids])
                    else:
                        distances, pids, fids = sess.run([batch_distances, batch_pids, batch_fids])
                        top1_views = np.zeros(fids.shape, dtype=int)  # plural, to match its use further below

                    print('\rCalculating batch {}-{}/{}'.format( start_idx, start_idx + len(fids), len(self.query_fids)), flush=True, end='')

                except tf.errors.OutOfRangeError:
                    print()  # Done!
                    break

                pids, fids = np.array(pids, '|U'), np.array(fids, '|U')
                pid_matches = self.gallery_pids[None] == pids[:,None]
                scores = 1 / (1 + distances)
                
                for i in range(len(distances)):
                    fid = fids[i]
                    pid = pids[i]
                    pid_match = pid_matches[i,:]
                    score = scores[i]
                    top1_view = top1_views[i]
                    top1_view = trklet_dict[top1_view]
                    if(query_extention):
                        selected = self.select_NN(top1_view, tracklet_dists)
                        score = self.track_re_ranking(selected, tracklet_mapper, score)
                    
                    top100 = np.argsort(score)[-100:][::-1]
                    #Save submission file
                    save_submission(f_sub, top100, self.gallery_fids)
                    #Save predict results:
                    save_predict_results(top100, self.result_folder,score, pid, fid, pid_match, self.gallery_pids, self.gallery_fids, self.gallery_views, self.gal_root.split("/")[-2])
                    
                    #Calculate AP:
                    ap, k = calculate_ap(fid, pid_match, score)
                    cmc[k:] += 1
                    aps.append(ap)
                    ques.append(fid)

            # Save index.csv
            save_test_img_index(self.result_folder,ques,aps, self.query_root)  

            # Compute the actual cmc and mAP values
            cmc = cmc / len(self.query_pids)
            mean_ap = np.mean(aps)
            print('mAP: {:.2%} | top-1: {:.2%} top-2: {:.2%} | top-5: {:.2%} | top-10: {:.2%}'.format(
                mean_ap, cmc[0], cmc[1], cmc[4], cmc[9]))
            
            
            
예제 #15
def tf_query_imgs(ins, gallery_embs, metric):
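    # Average the distance from every instance embedding in `ins` to each
    # gallery embedding, giving one aggregated distance per gallery entry
    # (presumably used for query-expansion style re-ranking).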
    ins_distances = loss.cdist(ins, gallery_embs, metric=metric)
    return tf.reduce_mean(ins_distances, axis=0)
Example #16
0
def main():
    args = parser.parse_args()

    # Data augmentation
    global seq_geo
    global seq_img
    seq_geo = iaa.SomeOf(
        (0, 5),
        [
            iaa.Fliplr(0.5),  # horizontally flip 50% of the images
            iaa.PerspectiveTransform(scale=(0, 0.075)),
            iaa.Affine(
                scale={
                    "x": (0.8, 1.0),
                    "y": (0.8, 1.0)
                },
                rotate=(-5, 5),
                translate_percent={
                    "x": (-0.1, 0.1),
                    "y": (-0.1, 0.1)
                },
            ),  # rotate by -5 to +5 degrees
            iaa.Crop(percent=(
                0, 0.125
            )),  # crop images from each side by 0 to 12.5% (randomly chosen)
            iaa.CoarsePepper(p=0.01, size_percent=0.1)
        ],
        random_order=False)
    # Content transformation
    seq_img = iaa.SomeOf(
        (0, 3),
        [
            iaa.GaussianBlur(
                sigma=(0, 1.0)),  # blur images with a sigma of 0 to 1.0
            iaa.ContrastNormalization(alpha=(0.9, 1.1)),
            iaa.Grayscale(alpha=(0, 0.2)),
            iaa.Multiply((0.9, 1.1))
        ])
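    # Both sequences are applied to each image later in the input pipeline,
    # via `augment_images` wrapped in `tf.py_func` (see the `args.augment` branch).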

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Load feature embeddings
    if args.hard_pool_size > 0:
        with h5py.File(args.train_embeddings, 'r') as f_train:
            train_embs = np.array(f_train['emb'])
            f_dists = scipy.spatial.distance.cdist(train_embs, train_embs)
            hard_ids = get_hard_id_pool(pids, f_dists, args.hard_pool_size)
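            # `hard_ids` presumably holds, for every PID, a pool of the most
            # similar other PIDs (smallest embedding distance), used below to
            # build harder batches.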

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    if args.hard_pool_size == 0:
        dataset = dataset.take(
            (len(unique_pids) // args.batch_p) * args.batch_p)
        dataset = dataset.repeat(
            None)  # Repeat forever. Funny way of stating it.

    else:
        dataset = dataset.repeat(
            None)  # Repeat forever. Funny way of stating it.
        dataset = dataset.map(lambda pid: sample_batch_ids_for_pid(
            pid, all_pids=pids, batch_p=args.batch_p, all_hard_pids=hard_ids))
        # Unbatch the P PIDs
        dataset = dataset.apply(tf.contrib.data.unbatch())

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Sizes used when converting filenames to image tensors below: images are
    # loaded at `pre_crop_size` when crop augmentation is enabled, otherwise
    # directly at the network input size.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    # Convert filenames to actual image tensors and augment the data if
    # specified by the arguments.
    if not args.augment:
        dataset = dataset.map(
            lambda fid, pid: common.fid_to_image(fid,
                                                 pid,
                                                 image_root=args.image_root,
                                                 image_size=pre_crop_size
                                                 if args.crop_augment else
                                                 net_input_size),
            num_parallel_calls=args.loading_threads)

        if args.flip_augment:
            dataset = dataset.map(lambda im, fid, pid: (
                tf.image.random_flip_left_right(im), fid, pid))
        if args.crop_augment:
            dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
                im, net_input_size + (3, )), fid, pid))
    else:
        # With imgaug augmentation, load the images at the network input size;
        # `augment_images` (seq_geo/seq_img) is applied through `tf.py_func` below.
        dataset = dataset.map(lambda fid, pid: common.fid_to_image(
            fid, pid, image_root=args.image_root, image_size=net_input_size),
                              num_parallel_calls=args.loading_threads)

        dataset = dataset.map(lambda im, fid, pid: (tf.py_func(
            augment_images, [im], [tf.float32]), fid, pid))
        dataset = dataset.map(lambda im, fid, pid: (tf.reshape(
            im[0],
            (args.net_input_height, args.net_input_width, 3)), fid, pid))

    # Group it back into PK batches.
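    # (P identities with K images each, as produced by the PID/FID sampling above.)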
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(batch_size * 2)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            args.train_iterations - args.decay_start_iteration, 0.001)
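        # With a decay rate of 0.001, the learning rate falls smoothly from
        # `args.learning_rate` to roughly 0.001x that value by the last iteration.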
    else:
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                        step, float(np.min(b_loss)), float(np.mean(b_loss)),
                        float(np.max(b_loss)), args.batch_k - 1,
                        float(b_prec_at_k),
                        timedelta(seconds=int(seconds_todo)), elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
Example #17
0
def main():
    # args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.

    train_config = cfg.TrainConfig()

    args_file = os.path.join(train_config.experiment_root, 'args.json')
    if train_config.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in train_config.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    train_config.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(train_config.experiment_root):
            if os.listdir(train_config.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(train_config.experiment_root))
                exit(1)
        else:
            os.makedirs(train_config.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(train_config), f, ensure_ascii=False, indent=2, sort_keys=True)

    log_file = os.path.join(train_config.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(train_config).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not train_config.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not train_config.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(train_config.train_set, train_config.image_root, is_train=True)

    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.

    unique_pids = np.unique(pids)

    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)

    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // train_config.batch_p) * train_config.batch_p)
    # take(count)  Creates a Dataset with at most count elements from this dataset.

    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.
    # Repeats this dataset count times.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=train_config.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())
    # apply(transformation_func) Apply a transformation function to this dataset.
    # apply enables chaining of custom Dataset transformations, which are represented as functions that take one Dataset argument and return a transformed Dataset.


    # Convert filenames to actual image tensors.
    net_input_size = (train_config.net_input_height, train_config.net_input_width)
    # 256,128
    pre_crop_size = (train_config.pre_crop_height, train_config.pre_crop_width)
    # 288,144
    
    dataset = dataset.map(
        lambda fid, pid: common.fid_to_image_label(
            fid, pid, image_root=train_config.image_root,
            image_size=pre_crop_size if train_config.crop_augment else net_input_size),
        num_parallel_calls=train_config.loading_threads)

###########################################################################################
    dataset = dataset.map(
        lambda im, keypt, mask, fid, pid: (tf.concat([im, keypt, mask], 2), fid, pid))
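    # Image, keypoint maps and body mask are stacked along the channel axis
    # (presumably 3 + 17 + 1 = 21 channels); `common.split` below separates
    # them again after the crop augmentation.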

###########################################################################################
    
    # Augment the data if specified by the arguments.
    if train_config.flip_augment:
        dataset = dataset.map(
            lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid))


    # net_input_size_aug = net_input_size + (4,)
    if train_config.crop_augment:
        dataset = dataset.map(
            lambda im, fid, pid: (tf.random_crop(im, net_input_size + (21,)), fid, pid))
    # net_input_size + (21,) = (256, 128, 21)
    # split

#############################################################################################
    dataset = dataset.map(
        lambda im, fid, pid: (common.split(im, fid, pid)))

#############################################################################################

    # Group it back into PK batches.
    batch_size = train_config.batch_p * train_config.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)
    # prefetch(buffer_size)   Creates a Dataset that prefetches elements from this dataset.

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, keypts, masks, fids, pids = dataset.make_one_shot_iterator().get_next()
    # tf.summary.image('image',images,10)

    # Create the model and an embedding head.
    model = import_module('nets.' + train_config.model_name)
    head = import_module('heads.' + train_config.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.

    endpoints, body_prefix = model.endpoints(images, is_training=True)
    heatmap_in = endpoints[train_config.model_name + '/block4']
    # resnet_block_4_out = heatmap.resnet_block_4(heatmap_in)
    # resnet_block_3_4_out = heatmap.resnet_block_3_4(heatmap_in)
    # resnet_block_2_3_4_out = heatmap.resnet_block_2_3_4(heatmap_in)
    # head for heatmap
    with tf.name_scope('heatmap'):
        # heatmap_in = endpoints['model_output']
        # heatmap_out_layer_0 = heatmap.hmnet_layer_0(resnet_block_4_out, 1)
        # heatmap_out_layer_0 = heatmap.hmnet_layer_0(resnet_block_3_4_out, 1)
        # heatmap_out_layer_0 = heatmap.hmnet_layer_0(resnet_block_2_3_4_out, 1)
        heatmap_out_layer_0 = VAC.hmnet_layer_0(heatmap_in[:, :, :, 1020:2048], 1)
        heatmap_out_layer_1 = VAC.hmnet_layer_1(heatmap_out_layer_0, 1)
        heatmap_out_layer_2 = VAC.hmnet_layer_2(heatmap_out_layer_1, 1)
        heatmap_out_layer_3 = VAC.hmnet_layer_3(heatmap_out_layer_2, 1)
        heatmap_out_layer_4 = VAC.hmnet_layer_4(heatmap_out_layer_3, 1)
        heatmap_out = heatmap_out_layer_4
        heatmap_loss = VAC.loss_mutilayer(heatmap_out_layer_0, heatmap_out_layer_1, heatmap_out_layer_2,
                                              heatmap_out_layer_3, heatmap_out_layer_4, masks, net_input_size)
        # heatmap_loss = heatmap.loss(heatmap_out, labels, net_input_size)
        # heatmap_loss_mean = heatmap_loss

    with tf.name_scope('head'):
        # heatmap_sum = tf.reduce_sum(heatmap_out, axis=3)
        # heatmap_resize = tf.image.resize_images(tf.expand_dims(heatmap_sum, axis=3), [8, 4])
        # featuremap_tmp = tf.multiply(heatmap_resize, endpoints[args.model_name + '/block4'])
        # endpoints[args.model_name + '/block4'] = featuremap_tmp
        endpoints = head.head(endpoints, train_config.embedding_dim, is_training=True)

        tf.summary.image('feature_map', tf.expand_dims(endpoints[train_config.model_name + '/block4'][:, :, :, 0], axis=3), 4)


    with tf.name_scope('keypoints_pre'):
        keypoints_pre_in = endpoints[train_config.model_name + '/block4']
        # keypoints_pre_in_0 = keypoints_pre_in[:, :, :, 0:256]
        # keypoints_pre_in_1 = keypoints_pre_in[:, :, :, 256:512]
        # keypoints_pre_in_2 = keypoints_pre_in[:, :, :, 512:768]
        # keypoints_pre_in_3 = keypoints_pre_in[:, :, :, 768:1024]
        keypoints_pre_in_0 = keypoints_pre_in[:, :, :, 0:170]
        keypoints_pre_in_1 = keypoints_pre_in[:, :, :, 170:340]
        keypoints_pre_in_2 = keypoints_pre_in[:, :, :, 340:510]
        keypoints_pre_in_3 = keypoints_pre_in[:, :, :, 510:680]
        keypoints_pre_in_4 = keypoints_pre_in[:, :, :, 680:850]
        keypoints_pre_in_5 = keypoints_pre_in[:, :, :, 850:1020]
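        # block4 channels 0-1019 are sliced into six 170-channel groups (one per
        # keypoint group); channels 1020-2047 feed the heatmap branch above.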

        labels = tf.image.resize_images(keypts, [128, 64])
        # keypoints_gt_0 = tf.concat([labels[:, :, :, 0:5], labels[:, :, :, 14:15], labels[:, :, :, 15:16], labels[:, :, :, 16:17], labels[:, :, :, 17:18]], 3)
        # keypoints_gt_1 = tf.concat([labels[:, :, :, 1:2], labels[:, :, :, 2:3], labels[:, :, :, 3:4], labels[:, :, :, 5:6]], 3)
        # keypoints_gt_2 = tf.concat([labels[:, :, :, 4:5], labels[:, :, :, 7:8], labels[:, :, :, 8:9], labels[:, :, :, 11:12]], 3)
        # keypoints_gt_3 = tf.concat([labels[:, :, :, 9:10], labels[:, :, :, 10:11], labels[:, :, :, 12:13], labels[:, :, :, 13:14]], 3)

        keypoints_gt_0 = labels[:, :, :, 0:5]
        keypoints_gt_1 = labels[:, :, :, 5:7]
        keypoints_gt_2 = labels[:, :, :, 7:9]
        keypoints_gt_3 = labels[:, :, :, 9:13]
        keypoints_gt_4 = labels[:, :, :, 13:15]
        keypoints_gt_5 = labels[:, :, :, 15:17]

        keypoints_pre_0 = PAC.tran_conv_0(keypoints_pre_in, kp_num=5)
        keypoints_pre_1 = PAC.tran_conv_1(keypoints_pre_in, kp_num=2)
        keypoints_pre_2 = PAC.tran_conv_2(keypoints_pre_in, kp_num=2)
        keypoints_pre_3 = PAC.tran_conv_3(keypoints_pre_in, kp_num=4)
        keypoints_pre_4 = PAC.tran_conv_4(keypoints_pre_in, kp_num=2)
        keypoints_pre_5 = PAC.tran_conv_5(keypoints_pre_in, kp_num=2)
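        # Note: each prediction branch currently takes the full `keypoints_pre_in`
        # tensor; the per-group slices defined above are left unused.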

        keypoints_loss_0 = PAC.keypoints_loss(keypoints_pre_0, keypoints_gt_0)
        keypoints_loss_1 = PAC.keypoints_loss(keypoints_pre_1, keypoints_gt_1)
        keypoints_loss_2 = PAC.keypoints_loss(keypoints_pre_2, keypoints_gt_2)
        keypoints_loss_3 = PAC.keypoints_loss(keypoints_pre_3, keypoints_gt_3)
        keypoints_loss_4 = PAC.keypoints_loss(keypoints_pre_4, keypoints_gt_4)
        keypoints_loss_5 = PAC.keypoints_loss(keypoints_pre_5, keypoints_gt_5)

        keypoints_loss = 5/17*keypoints_loss_0 + 2/17*keypoints_loss_1 + 2/17*keypoints_loss_2 + 4/17*keypoints_loss_3 + 2/17*keypoints_loss_4 + 2/17*keypoints_loss_5
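        # Each group is weighted by its share of the 17 keypoints
        # (5 + 2 + 2 + 4 + 2 + 2 = 17), so the weights sum to one.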


    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=train_config.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[train_config.loss](
        dists, pids, train_config.margin, batch_precision_at_k=train_config.batch_k-1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)
    
    scale_rate_0 = 1E-7
    scale_rate_1 = 6E-8
    total_loss = loss_mean + keypoints_loss*scale_rate_0 + heatmap_loss*scale_rate_1
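    # The scale rates keep the auxiliary keypoint and heatmap losses small
    # relative to the embedding loss, so the triplet term dominates the gradient.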
    # total_loss = loss_mean + keypoints_loss * scale_rate_0
    # total_loss = loss_mean

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
############################################################################################
    # tf.summary.histogram('hm_loss_distribution', heatmap_loss)
    tf.summary.scalar('keypt_loss_0', keypoints_loss_0)
    tf.summary.scalar('keypt_loss_1', keypoints_loss_1)
    tf.summary.scalar('keypt_loss_2', keypoints_loss_2)
    tf.summary.scalar('keypt_loss_3', keypoints_loss_3)
    tf.summary.scalar('keypt_loss_all', keypoints_loss)
############################################################################################
    tf.summary.scalar('total_loss', total_loss)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(train_config.batch_k-1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if train_config.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(train_config.experiment_root, 'embeddings'),
            dtype=np.float32, shape=(train_config.train_iterations, batch_size, train_config.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(train_config.experiment_root, 'losses'),
            dtype=np.float32, shape=(train_config.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(train_config.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len), shape=(train_config.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= train_config.decay_start_iteration < train_config.train_iterations:
        learning_rate = tf.train.exponential_decay(
            train_config.learning_rate,
            tf.maximum(0, global_step - train_config.decay_start_iteration),
            train_config.train_iterations - train_config.decay_start_iteration, 0.001)
    else:
        learning_rate = train_config.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    #    train_op = optimizer.minimize(loss_mean, global_step=global_step)
        train_op = optimizer.minimize(total_loss, global_step=global_step)
    #


    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)


    gpu_options = tf.GPUOptions(allow_growth=True)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        if train_config.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(train_config.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if train_config.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables, write_version=tf.train.SaverDef.V1)
                
                saver.restore(sess, train_config.initial_checkpoint)

                # name_11 = 'resnet_v1_50/block4'
                # name_12 = 'resnet_v1_50/block3'
                # name_13 = 'resnet_v1_50/block2'
                # name_21 = 'Resnet_block_2_3_4/block4'
                # name_22 = 'Resnet_block_2_3_4/block3'
                # name_23 = 'Resnet_block_2_3_4/block2'
                # for var in tf.trainable_variables():
                #     var_name = var.name
                #     if re.match(name_11, var_name):
                #         dst_name = var_name.replace(name_11, name_21)
                #         tensor = tf.get_default_graph().get_tensor_by_name(var_name)
                #         dst_tensor = tf.get_default_graph().get_tensor_by_name(dst_name)
                #         tf.assign(dst_tensor, tensor)
                #     if re.match(name_12, var_name):
                #         dst_name = var_name.replace(name_12, name_22)
                #         tensor = tf.get_default_graph().get_tensor_by_name(var_name)
                #         dst_tensor = tf.get_default_graph().get_tensor_by_name(dst_name)
                #         tf.assign(dst_tensor, tensor)
                #     if re.match(name_13, var_name):
                #         dst_name = var_name.replace(name_13, name_23)
                #         tensor = tf.get_default_graph().get_tensor_by_name(var_name)
                #         dst_tensor = tf.get_default_graph().get_tensor_by_name(dst_name)
                #         tf.assign(dst_tensor, tensor)
            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproducible experiments.
            checkpoint_saver.save(sess, os.path.join(
                train_config.experiment_root, 'checkpoint'), global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(train_config.experiment_root, sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, train_config.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter', simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if train_config.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (train_config.train_iterations - step) * elapsed_time
                log.info('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                         'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                             step,
                             float(np.min(b_loss)),
                             float(np.mean(b_loss)),
                             float(np.max(b_loss)),
                             train_config.batch_k-1, float(b_prec_at_k),
                             timedelta(seconds=int(seconds_todo)),
                             elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (train_config.checkpoint_frequency > 0 and
                        step % train_config.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess, os.path.join(
                        train_config.experiment_root, 'checkpoint'), global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess, os.path.join(
            train_config.experiment_root, 'checkpoint'), global_step=step)