Example #1
def main():
    # Parse the arguments (the empty list means only the defaults are used
    # and the real command line is ignored).
    args = parser.parse_args([])

    # Possibly auto-generate the output filename.
    if args.filename is None:
        basename = os.path.basename(args.dataset)
        args.filename = os.path.splitext(basename)[0] + '_embeddings.h5'
    args.filename = os.path.join(args.experiment_root, args.filename)

    _, data_fids = common.load_dataset(args.dataset, args.image_root)

    net_input_size = (args.net_input_height, args.net_input_width)
    # pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    # Setup a tf Dataset containing all images.
    dataset = tf.data.Dataset.from_tensor_slices(data_fids)

    # Convert filenames to actual image tensors.
    dataset = dataset.map(
        lambda fid: common.fid_to_image(fid,
                                        tf.constant('dummy'),
                                        image_root=args.image_root,
                                        image_size=net_input_size),
        num_parallel_calls=args.loading_threads)

    dataset = dataset.batch(args.batch_size)

    # Overlap producing and consuming.
    dataset = dataset.prefetch(1)

    model = Trinet(args.embedding_dim)
    # lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(args.learning_rate,args.train_iterations - args.decay_start_iteration, 0.001)
    optimizer = tf.keras.optimizers.Adam()
    ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                               optimizer=optimizer,
                               net=model)
    manager = tf.train.CheckpointManager(ckpt,
                                         args.experiment_root,
                                         max_to_keep=10)

    ckpt.restore(manager.latest_checkpoint)

    with h5py.File(args.filename, 'w') as f_out:
        emb_storage = np.zeros((len(data_fids), args.embedding_dim),
                               np.float32)
        start_idx = 0
        for images, fids, pids in dataset:
            emb = model(images, training=False)
            emb_storage[start_idx:start_idx + len(emb)] = emb
            start_idx += len(emb)  # The final batch may be smaller than batch_size.
        emb_dataset = f_out.create_dataset('emb', data=emb_storage)
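
To use the stored embeddings later, they can be read back from the HDF5 file. A minimal sketch (the file name is only an illustration of the auto-generated pattern above):

import h5py
import numpy as np

# Open the file written by main() and read back the 'emb' dataset.
with h5py.File('experiment_root/my_dataset_embeddings.h5', 'r') as f_in:
    embs = np.array(f_in['emb'])  # shape: (num_images, embedding_dim)
print(embs.shape, embs.dtype)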
Example #2
def main():
    args = parser.parse_args()

    # Data augmentation
    global seq_geo
    global seq_img
    seq_geo = iaa.SomeOf(
        (0, 5),
        [
            iaa.Fliplr(0.5),  # horizontally flip 50% of the images
            iaa.PerspectiveTransform(scale=(0, 0.075)),
            iaa.Affine(
                scale={
                    "x": (0.8, 1.0),
                    "y": (0.8, 1.0)
                },
                rotate=(-5, 5),
                translate_percent={
                    "x": (-0.1, 0.1),
                    "y": (-0.1, 0.1)
                },
            ),  # rotate by -5 to +5 degrees
            iaa.Crop(percent=(
                0, 0.125
            )),  # crop images from each side by 0 to 12.5% (randomly chosen)
            iaa.CoarsePepper(p=0.01, size_percent=0.1)
        ],
        random_order=False)
    # Content transformation
    seq_img = iaa.SomeOf(
        (0, 3),
        [
            iaa.GaussianBlur(
                sigma=(0, 1.0)),  # blur images with a sigma of 0 to 1.0
            iaa.ContrastNormalization(alpha=(0.9, 1.1)),
            iaa.Grayscale(alpha=(0, 0.2)),
            iaa.Multiply((0.9, 1.1))
        ])
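    # (Sketch, not in the original file: `augment_images`, invoked through
    # tf.py_func further below, presumably applies both sequences to a batch,
    # roughly:
    #     def augment_images(images):
    #         images = seq_geo.augment_images(images.astype(np.uint8))
    #         images = seq_img.augment_images(images)
    #         return images.astype(np.float32)
    # imgaug's Augmenter.augment_images expects a list/array of HxWxC images.)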

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Load feature embeddings
    if args.hard_pool_size > 0:
        with h5py.File(args.train_embeddings, 'r') as f_train:
            train_embs = np.array(f_train['emb'])
            f_dists = scipy.spatial.distance.cdist(train_embs, train_embs)
            hard_ids = get_hard_id_pool(pids, f_dists, args.hard_pool_size)
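            # (`get_hard_id_pool` is defined elsewhere in this module; judging
            # by its inputs, it builds a pool of hard PIDs per identity from
            # the pairwise embedding distances computed above.)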

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    if args.hard_pool_size == 0:
        dataset = dataset.take(
            (len(unique_pids) // args.batch_p) * args.batch_p)
        dataset = dataset.repeat(
            None)  # Repeat forever. Funny way of stating it.

    else:
        dataset = dataset.repeat(
            None)  # Repeat forever. Funny way of stating it.
        dataset = dataset.map(lambda pid: sample_batch_ids_for_pid(
            pid, all_pids=pids, batch_p=args.batch_p, all_hard_pids=hard_ids))
        # Unbatch the P PIDs
        dataset = dataset.apply(tf.contrib.data.unbatch())

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(lambda im, fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    if not args.augment:
        # Re-load the images (note the element is already an (im, fid, pid)
        # triple at this point, so the lambda must accept all three).
        dataset = dataset.map(
            lambda im, fid, pid: common.fid_to_image(fid,
                                                     pid,
                                                     image_root=args.image_root,
                                                     image_size=pre_crop_size
                                                     if args.crop_augment else
                                                     net_input_size),  # Ergys
            num_parallel_calls=args.loading_threads)

        if args.flip_augment:
            dataset = dataset.map(lambda im, fid, pid: (
                tf.image.random_flip_left_right(im), fid, pid))
        if args.crop_augment:
            dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
                im, net_input_size + (3, )), fid, pid))
    else:
        dataset = dataset.map(lambda im, fid, pid: common.fid_to_image(
            fid, pid, image_root=args.image_root, image_size=net_input_size),
                              num_parallel_calls=args.loading_threads)

        dataset = dataset.map(lambda im, fid, pid: (tf.py_func(
            augment_images, [im], [tf.float32]), fid, pid))
        dataset = dataset.map(lambda im, fid, pid: (tf.reshape(
            im[0],
            (args.net_input_height, args.net_input_width, 3)), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(batch_size * 2)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            args.train_iterations - args.decay_start_iteration, 0.001)
    else:
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we can run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                        step, float(np.min(b_loss)), float(np.mean(b_loss)),
                        float(np.max(b_loss)), args.batch_k - 1,
                        float(b_prec_at_k),
                        timedelta(seconds=int(seconds_todo)), elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
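
For orientation, here is a minimal sketch of the batch-hard triplet loss that `loss.LOSS_CHOICES[args.loss]` computes when the batch-hard loss is selected (a paraphrase of the Hermans et al. formulation, not the repository's exact code; `dists` is the pairwise distance matrix and `pids` the string identity labels):

import numpy as np
import tensorflow as tf

def batch_hard_sketch(dists, pids, margin):
    # same_identity[i, j] is True when rows i and j share a PID.
    same_identity = tf.equal(pids[:, tf.newaxis], pids[tf.newaxis, :])
    # Hardest positive: the farthest same-PID sample per anchor.
    furthest_positive = tf.reduce_max(
        tf.where(same_identity, dists, tf.zeros_like(dists)), axis=1)
    # Hardest negative: the closest different-PID sample per anchor.
    inf_mask = tf.fill(tf.shape(dists), np.inf)
    closest_negative = tf.reduce_min(
        tf.where(same_identity, inf_mask, dists), axis=1)
    # Hinge on the margin; entries near zero count as "inactive" above.
    return tf.maximum(furthest_positive - closest_negative + margin, 0.0)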
Example #3
def main():
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    command = input('Would you like to restore it? (yes/no) ')
                    if command == 'yes':
                        args.__dict__[key] = resumed_value
                        print(
                            'For the argument `{}` we are using the loaded value `{}`.'
                            .format(key, args.__dict__[key]))
                    else:
                        print(
                            'For the argument `{}` we are using the provided value `{}`.'
                            .format(key, args.__dict__[key]))
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))
        os.remove(args_file)  # Redundant: opening with 'w' below truncates anyway.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    ###################################################################################################################
    # Prepare the training dataset.
    # Load the data from the text file; see common.load_dataset for details.
    pids_train, fids_train = common.load_dataset(args.train_set,
                                                 args.image_root)
    max_fid_len = max(map(len,
                          fids_train))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids_train)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids_train, all_pids=pids_train, batch_k=args.batch_k
    ))  # The dataset now yields (selected_fids, pid) tuples, as returned
    # by sample_k_fids_for_pid.

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(
        lambda fid, pid: common.fid_to_image(fid,
                                             pid,
                                             image_root=args.image_root,
                                             image_size=pre_crop_size if args.
                                             crop_augment else net_input_size),
        num_parallel_calls=args.loading_threads
    )  # The dataset now yields (image, fid, pid) tuples, as returned
    # by fid_to_image.

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images_train, fids_train, pids_train = dataset.make_one_shot_iterator(
    ).get_next()
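    # Note: these names shadow the NumPy arrays `fids_train`/`pids_train`
    # loaded above with the iterator's output tensors; the dataset maps have
    # already captured the arrays, so this works, but it is easy to misread.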
    ########################################################################################################################
    # Prepare the validation set.
    pids_val, fids_val = common.load_dataset(args.validation_set,
                                             args.validation_image_root)
    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids_val = np.unique(pids_val)
    dataset_val = tf.data.Dataset.from_tensor_slices(unique_pids_val)
    dataset_val = dataset_val.shuffle(len(unique_pids_val))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset_val = dataset_val.take(
        (len(unique_pids_val) // args.batch_p) * args.batch_p)
    dataset_val = dataset_val.repeat(
        None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset_val = dataset_val.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids_val, all_pids=pids_val, batch_k=args.batch_k
    ))  # The dataset now yields (selected_fids, pid) tuples, as returned
    # by sample_k_fids_for_pid.

    # Ungroup/flatten the batches for easy loading of the files.
    dataset_val = dataset_val.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset_val = dataset_val.map(
        lambda fid, pid: common.fid_to_image(
            fid,
            pid,
            image_root=args.validation_image_root,
            image_size=pre_crop_size if args.crop_augment else net_input_size),
        num_parallel_calls=args.loading_threads
    )  # The dataset now yields (image, fid, pid) tuples, as returned
    # by fid_to_image.

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset_val = dataset_val.map(lambda im, fid, pid: (
            tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset_val = dataset_val.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    dataset_val = dataset_val.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset_val = dataset_val.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images_val, fids_val, pids_val = dataset_val.make_one_shot_iterator(
    ).get_next()
    ####################################################################################################################
    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    input_images = tf.placeholder(
        dtype=tf.float32,
        shape=[None, args.net_input_height, args.net_input_width, 3],
        name='input')
    pids = tf.placeholder(dtype=tf.string, shape=[
        None,
    ], name='pids')
    fids = tf.placeholder(dtype=tf.string, shape=[
        None,
    ], name='fids')

    endpoints, body_prefix = model.endpoints(input_images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    # dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    # losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
    #     dists, pids, args.margin, batch_precision_at_k=args.batch_k-1)
    # # `_` is a boolean matrix marking where in the top K the correct
    # # identity match occurs; shape = (batch_size, K).


    # Changed: compute a separate batch-hard loss for each feature branch.
    # Loss 1:
    dists1 = loss.cdist(endpoints['feature1'],
                        endpoints['feature1'],
                        metric=args.metric)
    losses1, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists1, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists2 = loss.cdist(endpoints['feature2'],
                        endpoints['feature2'],
                        metric=args.metric)
    losses2, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists2, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists3 = loss.cdist(endpoints['feature3'],
                        endpoints['feature3'],
                        metric=args.metric)
    losses3, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists3, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists4 = loss.cdist(endpoints['feature4'],
                        endpoints['feature4'],
                        metric=args.metric)
    losses4, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists4, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists_fu = loss.cdist(endpoints['fusion_layer'],
                          endpoints['fusion_layer'],
                          metric=args.metric)
    losses_fu, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists_fu,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)

    losses = losses1 + losses2 + losses3 + losses4 + losses_fu
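    # The total loss sums the batch-hard losses of the four branch embeddings
    # and of the fused embedding, so each branch receives a direct gradient.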

    # Changed: the single combined loss below was replaced by the per-branch
    # losses above; kept for reference.
    # losses_fu, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
    #     endpoints,pids, model_type=args.model_name, metric=args.metric, batch_precision_at_k=args.batch_k - 1
    # )

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))

    # Here `losses` is the (hinged) amount by which the positive-pair distance
    # plus the margin exceeds the negative-pair distance.
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    #tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(
        0, name='global_step',
        trainable=False)  # `global_step` counts the batches seen by the graph.
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            # Decay every `lr_decay_steps` iterations after
            # `decay_start_iteration`. The previous schedule was:
            # args.train_iterations - args.decay_start_iteration, args.weight_decay_factor)
            args.lr_decay_steps,
            args.lr_decay_factor,
            staircase=True)
    else:
        learning_rate = args.learning_rate  # Used when `decay_start_iteration` is -1.
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=1e-3)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    # Note: `config` is not defined in this snippet; presumably a
    # tf.ConfigProto created elsewhere in the module.
    with tf.Session(config=config) as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            if args.checkpoint is None:
                last_checkpoint = tf.train.latest_checkpoint(
                    args.experiment_root)
                log.info(
                    'Restoring from checkpoint: {}'.format(last_checkpoint))
                checkpoint_saver.restore(sess, last_checkpoint)
            else:
                ckpt_path = os.path.join(args.experiment_root, args.checkpoint)
                log.info('Restoring from checkpoint: {}'.format(
                    args.checkpoint))
                checkpoint_saver.restore(sess, ckpt_path)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(
                    sess, args.initial_checkpoint
                )  # Restore the pre-trained parameters from the downloaded model.

            # In any case, we also store this initialization as a checkpoint,
            # such that we can run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                # Fetch one aligned batch in a single run; separate .eval()
                # calls would each advance the iterator and mismatch the feeds.
                b_images, b_pids, b_fids_in = sess.run(
                    [images_train, pids_train, fids_train])
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids],
                             feed_dict={input_images: b_images,
                                        pids: b_pids,
                                        fids: b_fids_in})
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                        step, float(np.min(b_loss)), float(np.mean(b_loss)),
                        float(np.max(b_loss)), args.batch_k - 1,
                        float(b_prec_at_k),
                        timedelta(seconds=int(seconds_todo)), elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Get validation results.
                if (args.validation_frequency > 0
                        and step % args.validation_frequency == 0):
                    # Same pattern as above: fetch one aligned validation
                    # batch in a single run before feeding it.
                    b_images_v, b_pids_v, b_fids_v = sess.run(
                        [images_val, pids_val, fids_val])
                    b_prec_at_k_val, b_loss, b_fids = \
                        sess.run([prec_at_k, losses, fids],
                                 feed_dict={input_images: b_images_v,
                                            pids: b_pids_v,
                                            fids: b_fids_v})
                    log.info(
                        'Validation @:{:6d} iteration, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                        'batch-p@{}: {:.2%}'.format(step,
                                                    float(np.min(b_loss)),
                                                    float(np.mean(b_loss)),
                                                    float(np.max(b_loss)),
                                                    args.batch_k - 1,
                                                    float(b_prec_at_k_val)))
                    sys.stdout.flush()
                    sys.stderr.flush()
                    summary3 = tf.Summary()
                    # The value is the mean validation loss, so tag it
                    # accordingly (it was mislabelled 'secs_per_iter').
                    summary3.value.add(tag='validation_loss',
                                       simple_value=float(np.mean(b_loss)))
                    summary_writer.add_summary(summary3, step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
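
As a reference point, the staircase schedule configured above can be reproduced in plain NumPy. A sketch using the same argument names (the numbers in the usage line are illustrative only):

import numpy as np

def staircase_lr(step, base_lr, decay_start_iteration, lr_decay_steps,
                 lr_decay_factor):
    # Mirrors tf.train.exponential_decay(..., staircase=True) applied to the
    # shifted step max(0, step - decay_start_iteration) used above.
    effective_step = max(0, step - decay_start_iteration)
    return base_lr * lr_decay_factor ** np.floor(effective_step / lr_decay_steps)

# E.g. base lr 3e-4, halved every 10000 iterations after iteration 20000:
print(staircase_lr(35000, 3e-4, 20000, 10000, 0.5))  # 3e-4 * 0.5 = 1.5e-4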
Example #4
def calculate_emb_for_fids(args, data_fids):
    '''
    Calculate embeddings.

    :param args: input arguments
    :param data_fids: relative paths to the images
    :return: matrix of shape len(data_fids) x embedding_dim
             (one embedding vector per image, one row each)
    '''
    ###################################################################################################################
    # LOAD DATA
    ###################################################################################################################
    # Values taken from the original experiment (hard-coded here).
    net_input_height = 256
    net_input_width = 128
    pre_crop_height = 288
    pre_crop_width = 144
    net_input_size = (net_input_height, net_input_width)
    pre_crop_size = (pre_crop_height, pre_crop_width)

    ###################################################################################################################
    # PREPARE DATA
    ###################################################################################################################
    # Setup a tf Dataset containing all images.
    dataset = tf.data.Dataset.from_tensor_slices(data_fids)

    # Convert filenames to actual image tensors.
    # dataset tensor: [image_resized, fid, pid]
    dataset = dataset.map(
        lambda fid: common.fid_to_image(fid,
                                        tf.constant("dummy", dtype=tf.string),
                                        image_root=args.image_root,
                                        image_size=pre_crop_size),
        num_parallel_calls=8)

    # Augment the data if specified by the arguments.
    # `modifiers` is a list of strings that keeps track of which augmentations
    # have been applied, so that a human can understand it later on.
    modifiers = ['original']
    dataset = dataset.map(flip_augment)
    dataset = dataset.apply(tf.contrib.data.unbatch())
    modifiers = [o + m for m in ['', '_flip'] for o in modifiers]

    dataset = dataset.map(lambda im, fid, pid:
                          (tf.stack(five_crops(im, net_input_size)),
                           tf.stack([fid] * 5), tf.stack([pid] * 5)))
    dataset = dataset.apply(tf.contrib.data.unbatch())
    modifiers = [
        o + m for o in modifiers for m in [
            '_center', '_top_left', '_top_right', '_bottom_left',
            '_bottom_right'
        ]
    ]

    # Group it back into PK batches.
    dataset = dataset.batch(256)

    # Overlap producing and consuming.
    dataset = dataset.prefetch(1)

    images, _, _ = dataset.make_one_shot_iterator().get_next()

    ###################################################################################################################
    # CREATE MODEL
    ###################################################################################################################

    # Get the weights
    model = import_module('nets.resnet_v1_50')
    embedding_dim = 128
    block4_units = 1
    endpoints = model.endpoints(images,
                                block4_units=block4_units,
                                is_training=False,
                                embedding_dim=embedding_dim)

    with tf.Session() as sess:
        # Initialize the network/load the checkpoint.
        print('Restoring from checkpoint: {}'.format(args.checkpoint))
        tf.train.Saver().restore(sess, args.checkpoint)

        # Go ahead and embed the whole dataset, with all augmented versions too.
        emb_storage = np.zeros(
            (len(data_fids) * len(modifiers), embedding_dim), np.float32)
        for start_idx in count(step=256):
            try:
                emb = sess.run(endpoints['emb'])
                print('\rEmbedded batch {}-{}/{}'.format(
                    start_idx, start_idx + len(emb), len(emb_storage)),
                      flush=True,
                      end='')
                emb_storage[start_idx:start_idx + len(emb)] = emb
            except tf.errors.OutOfRangeError:
                break  # This just indicates the end of the dataset.

        print()
        print("Done with embedding, aggregating augmentations...", flush=True)

        # Pull out the augmentations into a separate first dimension.
        emb_storage = emb_storage.reshape(len(data_fids), len(modifiers), -1)
        emb_storage = emb_storage.transpose((1, 0, 2))  # (Aug,FID,128D)

        # Aggregate according to the specified parameter.
        emb_storage = np.mean(emb_storage, axis=0)

    tf.reset_default_graph()
    return emb_storage
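
The final reshape/transpose/mean is easy to verify in isolation. A tiny NumPy demo with made-up sizes (2 flips x 5 crops = 10 modifiers, as above):

import numpy as np

n_fids, n_mods, dim = 4, 10, 128
emb = np.random.rand(n_fids * n_mods, dim).astype(np.float32)

# The pipeline emits rows fid-major, modifier-minor, so a straight reshape
# recovers (FID, Aug, D) before the transpose pulls augmentations up front.
per_aug = emb.reshape(n_fids, n_mods, dim).transpose(1, 0, 2)  # (Aug, FID, D)
aggregated = per_aug.mean(axis=0)                              # (FID, D)
assert aggregated.shape == (n_fids, dim)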
Example #5
def main(argv):
    args = parser.parse_args(argv)

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    if len(unique_pids) < args.batch_p:
        unique_pids = np.tile(unique_pids,
                              int(np.ceil(args.batch_p / len(unique_pids))))
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    # (Note: the following fid_to_image map re-loads the images and discards
    # the ones loaded above; redundant but functional.)
    dataset = dataset.map(
        lambda im, fid, pid: common.fid_to_image(fid,
                                                 pid,
                                                 image_root=args.image_root,
                                                 image_size=pre_crop_size
                                                 if args.crop_augment else
                                                 net_input_size),  # Ergys
        num_parallel_calls=args.loading_threads)

    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.

    weight_decay = 10e-4  # Note: 10e-4 is 1e-3.
    weights_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints,
                              args.embedding_dim,
                              is_training=True,
                              weights_regularizer=weights_regularizer)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    # batch_embedding = endpoints['emb']
    batch_embedding = endpoints['emb']
    if args.loss == 'semi_hard_triplet':
        triplet_loss = triplet_semihard_loss(batch_embedding, pids,
                                             args.margin)
    elif args.loss == 'hard_triplet':
        triplet_loss = batch_hard(batch_embedding, pids, args.margin,
                                  args.metric)
    elif args.loss == 'lifted_loss':
        triplet_loss = lifted_loss(pids, batch_embedding, margin=args.margin)
    elif args.loss == 'contrastive_loss':
        assert batch_size % 2 == 0
        assert args.batch_k == 4  ## Can work with other numbers, but needs tuning

        contrastive_idx = np.tile([0, 1, 4, 3, 2, 5, 6, 7], args.batch_p // 2)
        for i in range(args.batch_p // 2):
            contrastive_idx[i * 8:i * 8 + 8] += i * 8

        contrastive_idx = np.expand_dims(contrastive_idx, 1)
        batch_embedding_ordered = tf.gather_nd(batch_embedding,
                                               contrastive_idx)
        pids_ordered = tf.gather_nd(pids, contrastive_idx)
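        # Each group of 8 rows (two PIDs x batch_k=4) was reordered by
        # [0, 1, 4, 3, 2, 5, 6, 7] so that consecutive rows form pairs; with
        # fixed_labels = [1, 0, 0, 1] below, pairs 0 and 3 share a PID
        # (positives) while pairs 1 and 2 do not (negatives).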
        # batch_embedding_ordered = tf.Print(batch_embedding_ordered,[pids_ordered],'pids_ordered :: ',summarize=1000)
        embeddings_anchor, embeddings_positive = tf.unstack(
            tf.reshape(batch_embedding_ordered, [-1, 2, args.embedding_dim]),
            2, 1)
        # embeddings_anchor = tf.Print(embeddings_anchor,[pids_ordered,embeddings_anchor,embeddings_positive,batch_embedding,batch_embedding_ordered],"Tensors ", summarize=1000)

        fixed_labels = np.tile([1, 0, 0, 1], args.batch_p // 2)
        # fixed_labels = np.reshape(fixed_labels,(len(fixed_labels),1))
        # print(fixed_labels)
        labels = tf.constant(fixed_labels)
        # labels = tf.Print(labels,[labels],'labels ',summarize=1000)
        triplet_loss = contrastive_loss(labels,
                                        embeddings_anchor,
                                        embeddings_positive,
                                        margin=args.margin)
    elif args.loss == 'angular_loss':
        embeddings_anchor, embeddings_positive = tf.unstack(
            tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]), 2, 1)
        # pids = tf.Print(pids, [pids], 'pids:: ', summarize=100)
        pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
        # pids = tf.Print(pids,[pids],'pids:: ',summarize=100)
        triplet_loss = angular_loss(pids,
                                    embeddings_anchor,
                                    embeddings_positive,
                                    batch_size=args.batch_p,
                                    with_l2reg=True)
    elif args.loss == 'npairs_loss':
        assert args.batch_k == 2  ## Single positive pair per class
        embeddings_anchor, embeddings_positive = tf.unstack(
            tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]), 2, 1)
        pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
        pids = tf.reshape(pids, [-1])
        triplet_loss = npairs_loss(pids, embeddings_anchor,
                                   embeddings_positive)
    else:
        raise NotImplementedError('loss function {} not implemented'.format(
            args.loss))

    loss_mean = tf.reduce_mean(triplet_loss)

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)

    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.polynomial_decay(args.learning_rate,
                                                  global_step,
                                                  args.train_iterations,
                                                  end_learning_rate=1e-7,
                                                  power=1)
    else:
        learning_rate = args.learning_rate

    if args.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate)
    elif args.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
    else:
        raise NotImplementedError('Invalid optimizer {}'.format(
            args.optimizer))

    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=2)
    gpu_options = tf.GPUOptions(allow_growth=True)
    gpu_config = tf.ConfigProto(gpu_options=gpu_options)
    with tf.Session(config=gpu_config) as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            if last_checkpoint is None:
                print('Resume requested, but no previous checkpoint was found;'
                      ' initializing from scratch instead.')

                # But if we're starting from scratch, we may need to load some
                # variables from the pre-trained weights, and random init others.
                sess.run(tf.global_variables_initializer())
                if args.initial_checkpoint is not None:
                    saver = tf.train.Saver(model_variables)
                    saver.restore(sess, args.initial_checkpoint)

                # In any case, we also store this initialization as a checkpoint,
                # such that we could run exactly reproducible experiments.
                checkpoint_saver.save(sess,
                                      os.path.join(args.experiment_root,
                                                   'checkpoint'),
                                      global_step=0)

            else:
                log.info(
                    'Restoring from checkpoint: {}'.format(last_checkpoint))
                checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, step, b_embs, b_loss, b_fids = \
                    sess.run([train_op, global_step, endpoints['emb'], triplet_loss, fids])
                elapsed_time = time.time() - start_time

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ETA: {} ({:.2f}s/it)'
                    .format(
                        step,
                        float(np.min(b_loss)),
                        float(np.mean(b_loss)),
                        float(np.max(b_loss)),
                        timedelta(seconds=int(seconds_todo)),
                        elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
Ejemplo n.º 6
0
def main():
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)
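    # Each batch thus holds P identities with K images each: the "PK" layout
    # that the in-batch mining of the metric losses below relies on.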

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)
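
    # Decoder head (auxiliary reconstruction loss): a dense layer maps the
    # embedding to 5120 = 32*16*10 units, which are reshaped to a 32x16x10
    # feature map and upsampled three times (2x nearest-neighbor resize + 5x5
    # conv each) back to a 3-channel image at the network input resolution.
    # Its MSE against `images`, weighted by 0.01, is added to the training
    # objective further below.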

    decDense = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDense')  # optionally: activation=tf.nn.relu
    unflat = tf.reshape(decDense, shape=[tf.shape(decDense)[0], 32, 16, 10])
    unp3shape = tf.TensorShape(
        [2 * di for di in unflat.get_shape().as_list()[1:-1]])
    unPool3 = tf.image.resize_nearest_neighbor(unflat,
                                               unp3shape,
                                               name='unpool3')

    deConv3 = tf.layers.conv2d(inputs=unPool3,
                               filters=64,
                               kernel_size=[5, 5],
                               strides=(1, 1),
                               padding='same',
                               activation=tf.nn.relu,
                               name='deConv3')
    unp2shape = tf.TensorShape(
        [2 * di for di in deConv3.get_shape().as_list()[1:-1]])
    unPool2 = tf.image.resize_nearest_neighbor(deConv3,
                                               unp2shape,
                                               name='unpool2')

    deConv2 = tf.layers.conv2d(inputs=unPool2,
                               filters=32,
                               kernel_size=[5, 5],
                               strides=(1, 1),
                               padding='same',
                               activation=tf.nn.relu,
                               name='deConv2')
    unp1shape = tf.TensorShape(
        [2 * di for di in deConv2.get_shape().as_list()[1:-1]])
    unPool1 = tf.image.resize_nearest_neighbor(deConv2,
                                               unp1shape,
                                               name='unpool1')

    deConv1 = tf.layers.conv2d(inputs=unPool1,
                               filters=3,
                               kernel_size=[5, 5],
                               strides=(1, 1),
                               padding='same',
                               activation=None,
                               name='deConv1')
    imClip = deConv1  # optionally: tf.clip_by_value(deConv1, -1.0, 1.0, name='clipRelu')
    print('Reconstructed image: ', imClip.name)

    recLoss = tf.multiply(
        0.01, tf.losses.mean_squared_error(
            labels=images,
            predictions=imClip,
        ))
    print('recLoss: ', recLoss.name)
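
    # Three more decoder branches follow: decDense1 reconstructs at half
    # resolution (against `images` resized to 128x64), decDense2 at quarter
    # resolution (64x32), and decDensel repeats the full-resolution decoder.
    # All four reconstruction losses are summed into the training objective.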

    decDense1 = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDense1')  # optionally: activation=tf.nn.relu
    unflat1 = tf.reshape(decDense1, shape=[tf.shape(decDense1)[0], 32, 16, 10])
    unp3shape1 = tf.TensorShape(
        [2 * di for di in unflat1.get_shape().as_list()[1:-1]])
    unPool3_new = tf.image.resize_nearest_neighbor(unflat1,
                                                   unp3shape1,
                                                   name='unpool3_new')

    deConv3_new = tf.layers.conv2d(inputs=unPool3_new,
                                   filters=64,
                                   kernel_size=[5, 5],
                                   strides=(1, 1),
                                   padding='same',
                                   activation=tf.nn.relu,
                                   name='deConv3_new')
    unp2shape_new = tf.TensorShape(
        [2 * di for di in deConv3_new.get_shape().as_list()[1:-1]])
    unPool2_new = tf.image.resize_nearest_neighbor(deConv3_new,
                                                   unp2shape_new,
                                                   name='unpool2_new')

    deConv2_new = tf.layers.conv2d(inputs=unPool2_new,
                                   filters=3,
                                   kernel_size=[5, 5],
                                   strides=(1, 1),
                                   padding='same',
                                   activation=tf.nn.relu,
                                   name='deConv2_new')
    unp1shape_new = tf.TensorShape(
        [2 * di for di in deConv2_new.get_shape().as_list()[1:-1]])
    unPool1_new = tf.image.resize_nearest_neighbor(deConv2_new,
                                                   unp1shape_new,
                                                   name='unpool1_new')

    deConv1_new = tf.layers.conv2d(inputs=unPool1_new,
                                   filters=3,
                                   kernel_size=[5, 5],
                                   strides=(1, 1),
                                   padding='same',
                                   activation=None,
                                   name='deConv1_new')
    # Note: deConv1_new above is computed but never used; the 128x64 output of
    # deConv2_new already matches the resized target images2 below.
    imClip1 = deConv2_new
    print('Reconstructed image: ', imClip1.name)
    print(imClip1.shape)

    images2 = tf.image.resize_images(images, [128, 64])
    print(images2.shape)

    recLoss1 = tf.multiply(
        0.01,
        tf.losses.mean_squared_error(
            labels=images2,
            predictions=imClip1,
        ))
    print('recLoss1: ', recLoss1.name)

    decDense2 = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDense2')  # optionally: activation=tf.nn.relu
    unflat12 = tf.reshape(decDense2,
                          shape=[tf.shape(decDense2)[0], 32, 16, 10])
    unp3shape12 = tf.TensorShape(
        [2 * di for di in unflat12.get_shape().as_list()[1:-1]])
    unPool3_new2 = tf.image.resize_nearest_neighbor(unflat12,
                                                    unp3shape12,
                                                    name='unpool3_new2')

    deConv3_new2 = tf.layers.conv2d(inputs=unPool3_new2,
                                    filters=3,
                                    kernel_size=[5, 5],
                                    strides=(1, 1),
                                    padding='same',
                                    activation=tf.nn.relu,
                                    name='deConv3_new2')
    unp2shape_new2 = tf.TensorShape(
        [2 * di for di in deConv3_new2.get_shape().as_list()[1:-1]])
    unPool2_new2 = tf.image.resize_nearest_neighbor(deConv3_new2,
                                                    unp2shape_new2,
                                                    name='unpool2_new2')

    imClip11 = deConv3_new2
    images21 = tf.image.resize_images(images, [64, 32])

    recLoss2 = tf.multiply(
        0.01,
        tf.losses.mean_squared_error(
            labels=images21,
            predictions=imClip11,
        ))
    print('recLoss2: ', recLoss2.name)

    decDensel = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDensel')  # optionally: activation=tf.nn.relu
    unflatl = tf.reshape(decDensel, shape=[tf.shape(decDensel)[0], 32, 16, 10])
    unp3shapel = tf.TensorShape(
        [2 * di for di in unflatl.get_shape().as_list()[1:-1]])
    unPool3l = tf.image.resize_nearest_neighbor(unflatl,
                                                unp3shapel,
                                                name='unpool3l')

    deConv3l = tf.layers.conv2d(inputs=unPool3l,
                                filters=64,
                                kernel_size=[5, 5],
                                strides=(1, 1),
                                padding='same',
                                activation=tf.nn.relu,
                                name='deConv3l')
    unp2shapel = tf.TensorShape(
        [2 * di for di in deConv3l.get_shape().as_list()[1:-1]])
    unPool2l = tf.image.resize_nearest_neighbor(deConv3l,
                                                unp2shapel,
                                                name='unpool2l')

    deConv2l = tf.layers.conv2d(inputs=unPool2l,
                                filters=32,
                                kernel_size=[5, 5],
                                strides=(1, 1),
                                padding='same',
                                activation=tf.nn.relu,
                                name='deConv2l')
    unp1shapel = tf.TensorShape(
        [2 * di for di in deConv2l.get_shape().as_list()[1:-1]])
    unPool1l = tf.image.resize_nearest_neighbor(deConv2l,
                                                unp1shapel,
                                                name='unpool1l')

    deConv1l = tf.layers.conv2d(inputs=unPool1l,
                                filters=3,
                                kernel_size=[5, 5],
                                strides=(1, 1),
                                padding='same',
                                activation=None,
                                name='deConv1l')
    imClipl = deConv1l  # optionally: tf.clip_by_value(deConv1l, -1.0, 1.0, name='clipRelu')
    print('Reconstructed image: ', imClipl.name)

    recLossl = tf.multiply(
        0.01, tf.losses.mean_squared_error(
            labels=images,
            predictions=imClipl,
        ))
    print('recLossl: ', recLossl.name)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= args.decay_start_iteration < args.train_iterations:
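        # Exponential decay to 0.001x the base rate over the remaining
        # iterations; tf.maximum pins the decay exponent at zero (i.e. the
        # base rate) until decay_start_iteration is reached.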
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            args.train_iterations - args.decay_start_iteration, 0.001)
    else:
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        # Total objective: the metric loss plus all four reconstruction losses.
        total_loss = tf.add_n(
            [loss_mean, recLoss, recLoss1, recLoss2, recLossl])
        train_op = optimizer.minimize(total_loss, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids, b_rec, b_rec1 = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids, recLoss, recLoss1])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'recLoss: {:.3f}, batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.
                    format(step, float(np.min(b_loss)), float(np.mean(b_loss)),
                           float(np.max(b_loss)), b_rec, args.batch_k - 1,
                           float(b_prec_at_k),
                           timedelta(seconds=int(seconds_todo)), elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
Ejemplo n.º 7
0
def main():
    # my_devices = tf.config.experimental.list_physical_devices(device_type='CPU')
    # tf.config.experimental.set_visible_devices(devices= my_devices, device_type='CPU')

    # # To find out which devices your operations and tensors are assigned to
    # tf.debugging.set_log_device_placement(True)
    args = parser.parse_args(args=[])

    show_all_parameters(args)

    if not args.train_set:
        parser.print_help()
        print("You didn't specify the 'train_set' argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        print("You didn't specify the 'image_root' argument!")
        sys.exit(1)

    pids, fids = common.load_dataset(args.train_set, args.image_root)

    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat indefinitely.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches
    dataset = dataset.unbatch()

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size))

    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.image.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group the data into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    dataset = dataset.prefetch(1)
    dataiter = iter(dataset)

    model = Trinet(args.embedding_dim)
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        args.learning_rate, args.train_iterations - args.decay_start_iteration,
        0.001)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    writer = tf.summary.create_file_writer(args.experiment_root)
    ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                               optimizer=optimizer,
                               net=model)
    manager = tf.train.CheckpointManager(ckpt,
                                         args.experiment_root,
                                         max_to_keep=10)

    if args.resume:
        ckpt.restore(manager.latest_checkpoint)

    for epoch in range(args.train_iterations):

        # for images,fids,pids in dataset:
        images, fids, pids = next(dataiter)
        with tf.GradientTape() as tape:
            # training=True so that layers like batch-norm run in training mode.
            emb = model(images, training=True)
            dists = loss.cdist(emb, emb)
            losses, top1, prec, topksame, negdist, posdist = loss.batch_hard(
                dists, pids, args.margin, args.batch_k)
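            # batch_hard mines, for each anchor, the hardest positive and the
            # hardest negative within the PK batch (batch-hard triplet loss).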
            lossavg = tf.reduce_mean(losses)
            lossnp = losses.numpy()
        with writer.as_default():
            tf.summary.scalar("loss", lossavg, step=epoch)
            tf.summary.scalar('batch_top1', top1, step=epoch)
            tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1),
                              prec,
                              step=epoch)
            tf.summary.histogram('losses', losses, step=epoch)
            tf.summary.histogram('embedding_dists', dists, step=epoch)
            tf.summary.histogram('embedding_pos_dists', posdist, step=epoch)
            tf.summary.histogram('embedding_neg_dists', negdist, step=epoch)

        print('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
              ' batch-p@{}: {:.2%}'.format(epoch, float(np.min(lossnp)),
                                           float(np.mean(lossnp)),
                                           float(np.max(lossnp)),
                                           args.batch_k - 1, float(prec)))
        grad = tape.gradient(lossavg, model.trainable_variables)
        optimizer.apply_gradients(zip(grad, model.trainable_variables))
        ckpt.step.assign_add(1)
        if args.checkpoint_frequency > 0 and epoch % args.checkpoint_frequency == 0:
            manager.save()
Ejemplo n.º 8
0
def main():
    # Verify that parameters are set correctly.
    args = parser.parse_args()

    # Possibly auto-generate the output filename.
    if args.filename is None:
        basename = os.path.basename(args.dataset)
        args.filename = os.path.splitext(basename)[0] + '_embeddings.h5'

    # Load the args from the original experiment.
    args_file = os.path.join(args.experiment_root, 'args.json')

    if os.path.isfile(args_file):
        if not args.quiet:
            print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)

        # Add arguments from training.
        for key, value in args_resumed.items():
            args.__dict__.setdefault(key, value)
        # Prefer an explicitly given image_root; fall back to the training-time value.
        args.image_root = args.image_root or args_resumed['image_root']
    else:
        raise IOError(
            '`args.json` could not be found in: {}'.format(args_file))

    if not args.quiet:
        print('Evaluating using the following parameters:')
        for key, value in sorted(vars(args).items()):
            print('{}: {}'.format(key, value))

    # Load the data from the CSV file.
    _, data_fids = common.load_dataset(args.dataset, args.image_root)

    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    # Setup a tf Dataset containing all images.
    dataset = tf.data.Dataset.from_tensor_slices(data_fids)

    # Convert filenames to actual image tensors.
    dataset = dataset.map(lambda fid: common.fid_to_image(
        fid,
        tf.constant('dummy'),
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    modifiers = ['original']
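    # No test-time augmentation in this variant, so `modifiers` stays
    # ['original'] and each image yields exactly one embedding row below.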

    # Group it back into PK batches.
    dataset = dataset.batch(args.batch_size)

    # Overlap producing and consuming.
    dataset = dataset.prefetch(1)

    images, _, _ = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    endpoints, body_prefix = model.endpoints(images, is_training=False)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=False)

    with h5py.File(args.filename, 'w') as f_out, tf.Session() as sess:
        # Initialize the network/load the checkpoint.
        if args.checkpoint is None:
            checkpoint = tf.train.latest_checkpoint(args.experiment_root)
        else:
            checkpoint = os.path.join(args.experiment_root, args.checkpoint)
        if not args.quiet:
            print('Restoring from checkpoint: {}'.format(checkpoint))
        tf.train.Saver().restore(sess, checkpoint)

        # Go ahead and embed the whole dataset, with all augmented versions too.
        emb_storage = np.zeros(
            (len(data_fids) * len(modifiers), args.embedding_dim), np.float32)
        for start_idx in count(step=args.batch_size):
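            # `count` (presumably itertools.count) yields 0, batch_size,
            # 2*batch_size, ... indefinitely; the loop ends when the one-shot
            # iterator raises OutOfRangeError.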
            try:
                emb = sess.run(endpoints['emb'])
                print('\rEmbedded batch {}-{}/{}'.format(
                    start_idx, start_idx + len(emb), len(emb_storage)),
                      flush=True,
                      end='')
                emb_storage[start_idx:start_idx + len(emb)] = emb
            except tf.errors.OutOfRangeError:
                break  # This just indicates the end of the dataset.

        print()

        # Store the final embeddings.
        emb_dataset = f_out.create_dataset('emb', data=emb_storage)

        # Store information about the produced augmentations and, in case no
        # crop augmentation was used, whether the images were resized or
        # average-pooled.
        f_out.create_dataset('augmentation_types',
                             data=np.asarray(modifiers, dtype='|S'))
Ejemplo n.º 9
0
if args.filename is None:
    basename = os.path.basename(args.dataset)
    args.filename = os.path.splitext(basename)[0] + '_embeddings.h5'
args.filename = os.path.join(args.experiment_root, args.filename)

_, data_fids = common.load_dataset(args.dataset, args.image_root)

net_input_size = (args.net_input_height, args.net_input_width)
# pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

dataset = tf.data.Dataset.from_tensor_slices(data_fids)

dataset = dataset.map(
    lambda fid: common.fid_to_image(fid,
                                    tf.constant('dummy'),
                                    image_root=args.image_root,
                                    image_size=net_input_size),
    num_parallel_calls=args.loading_threads)

dataset = dataset.batch(args.batch_size)
dataset = dataset.prefetch(1)

model = Trinet(args.embedding_dim)
optimizer = tf.keras.optimizers.Adam()
ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=model)
manager = tf.train.CheckpointManager(ckpt,
                                     args.experiment_root,
                                     max_to_keep=10)

ckpt.restore(manager.latest_checkpoint).expect_partial()
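# expect_partial() suppresses warnings about checkpoint values (such as
# optimizer slots) that are deliberately left unrestored for inference.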
Ejemplo n.º 10
0
def main():
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.

    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            args.train_iterations - args.decay_start_iteration, 0.001)
    else:
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                        step, float(np.min(b_loss)), float(np.mean(b_loss)),
                        float(np.max(b_loss)), args.batch_k - 1,
                        float(b_prec_at_k),
                        timedelta(seconds=int(seconds_todo)), elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
Ejemplo n.º 11
0
        print('{}: {}'.format(key, value))

# Load the data from the CSV file.
_, data_fids = common.load_dataset(args.dataset, args.image_root)

net_input_size = (args.net_input_height, args.net_input_width)
pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

# Setup a tf Dataset containing all images.
dataset = tf.data.Dataset.from_tensor_slices(data_fids)

# Convert filenames to actual image tensors.
dataset = dataset.map(
    lambda fid: common.fid_to_image(fid,
                                    tf.convert_to_tensor('dummy'),
                                    image_root=args.image_root,
                                    image_size=pre_crop_size
                                    if args.crop_augment else net_input_size),
    num_parallel_calls=args.loading_threads)

# Augment the data if specified by the arguments.
# `modifiers` is a list of strings that keeps track of which augmentations
# have been applied, so that a human can understand it later on.
modifiers = ['original']

#if args.flip_augment:
#    dataset = dataset.map(flip_augment)
#    dataset = dataset.apply(tf.contrib.data.unbatch())
#    modifiers = [o + m for m in ['', '_flip'] for o in modifiers]
#if args.crop_augment == 'center':
#    dataset = dataset.map(lambda im, fid, pid:
Ejemplo n.º 12
0
def main(argv):
    # Verify that parameters are set correctly.
    args = parser.parse_args(argv)

    if not os.path.exists(args.dataset):
        return

    # Possibly auto-generate the output filename.
    if args.filename is None:
        basename = os.path.basename(args.dataset)
        args.filename = os.path.splitext(basename)[0] + '_embeddings.h5'

    os_utils.touch_dir(os.path.join(args.experiment_root, args.foldername))

    log_file = os.path.join(args.experiment_root, args.foldername, "embed")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('embed')

    args.filename = os.path.join(args.experiment_root, args.foldername,
                                 args.filename)
    var_filepath = os.path.join(args.experiment_root, args.foldername,
                                args.filename[:-3] + '_var.txt')
    # Load the args from the original experiment.
    args_file = os.path.join(args.experiment_root, 'args.json')

    if os.path.isfile(args_file):
        if not args.quiet:
            print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)

        # Add arguments from training.
        for key, value in args_resumed.items():
            args.__dict__.setdefault(key, value)

        # A couple special-cases and sanity checks
        if (args_resumed['crop_augment']) == (args.crop_augment is None):
            print('WARNING: crop augmentation differs between training and '
                  'evaluation.')
        args.image_root = args.image_root or args_resumed['image_root']
    else:
        raise IOError(
            '`args.json` could not be found in: {}'.format(args_file))

    # Check a proper aggregator is provided if augmentation is used.
    if args.flip_augment or args.crop_augment == 'five':
        if args.aggregator is None:
            print(
                'ERROR: Test time augmentation is performed but no aggregator'
                ' was specified.')
            exit(1)
    else:
        if args.aggregator is not None:
            print('ERROR: No test time augmentation that needs aggregating is '
                  'performed but an aggregator was specified.')
            exit(1)

    if not args.quiet:
        print('Evaluating using the following parameters:')
        for key, value in sorted(vars(args).items()):
            print('{}: {}'.format(key, value))

    # Load the data from the CSV file.
    _, data_fids = common.load_dataset(args.dataset, args.image_root)

    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    # Setup a tf Dataset containing all images.
    dataset = tf.data.Dataset.from_tensor_slices(data_fids)

    # Convert filenames to actual image tensors.
    dataset = dataset.map(lambda fid: common.fid_to_image(
        fid,
        tf.constant('dummy'),
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    # `modifiers` is a list of strings that keeps track of which augmentations
    # have been applied, so that a human can understand it later on.
    modifiers = ['original']
    if args.flip_augment:
        dataset = dataset.map(flip_augment)
        dataset = dataset.apply(tf.contrib.data.unbatch())
        modifiers = [o + m for m in ['', '_flip'] for o in modifiers]

    if args.crop_augment == 'center':
        dataset = dataset.map(lambda im, fid, pid:
                              (five_crops(im, net_input_size)[0], fid, pid))
        modifiers = [o + '_center' for o in modifiers]
    elif args.crop_augment == 'five':
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.stack(five_crops(im, net_input_size)),
                               tf.stack([fid] * 5), tf.stack([pid] * 5)))
        dataset = dataset.apply(tf.contrib.data.unbatch())
        modifiers = [
            o + m for o in modifiers for m in [
                '_center', '_top_left', '_top_right', '_bottom_left',
                '_bottom_right'
            ]
        ]
    elif args.crop_augment == 'avgpool':
        modifiers = [o + '_avgpool' for o in modifiers]
    else:
        modifiers = [o + '_resize' for o in modifiers]

    emb_model = EmbeddingModel(args)

    # Group it back into PK batches.
    dataset = dataset.batch(args.batch_size)
    dataset = dataset.map(lambda im, fid, pid:
                          (emb_model.preprocess_input(im), fid, pid))
    # Overlap producing and consuming.
    dataset = dataset.prefetch(1)
    tf.keras.backend.set_learning_phase(0)

    with h5py.File(args.filename, 'w') as f_out:

        ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=emb_model)
        manager = tf.train.CheckpointManager(ckpt,
                                             osp.join(args.experiment_root,
                                                      'tf_ckpts'),
                                             max_to_keep=1)
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            print("Restored from {}".format(manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")

        emb_storage = np.zeros(
            (len(data_fids) * len(modifiers), args.embedding_dim), np.float32)

        # for batch_idx,batch in enumerate(dataset):
        dataset_iter = iter(dataset)
        for start_idx in count(step=args.batch_size):

            try:
                images, _, _ = next(dataset_iter)
                emb = emb_model(images)
                emb_storage[start_idx:start_idx + len(emb)] = emb
                print('\rEmbedded batch {}-{}/{}'.format(
                    start_idx, start_idx + len(emb), len(emb_storage)),
                      flush=True,
                      end='')
            except StopIteration:
                break  # This just indicates the end of the dataset.

        if not args.quiet:
            print("Done with embedding, aggregating augmentations...",
                  flush=True)

        if len(modifiers) > 1:
            # Pull out the augmentations into a separate first dimension.
            emb_storage = emb_storage.reshape(len(data_fids), len(modifiers),
                                              -1)
            emb_storage = emb_storage.transpose((1, 0, 2))  # (Aug,FID,128D)

            # Store the embedding of all individual variants too.
            emb_dataset = f_out.create_dataset('emb_aug', data=emb_storage)

            # Aggregate according to the specified parameter.
            emb_storage = AGGREGATORS[args.aggregator](emb_storage)

        # Store the final embeddings.
        emb_dataset = f_out.create_dataset('emb', data=emb_storage)

        # Store information about the produced augmentation and in case no crop
        # augmentation was used, if the images are resized or avg pooled.
        f_out.create_dataset('augmentation_types',
                             data=np.asarray(modifiers, dtype='|S'))
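The AGGREGATORS lookup used at the end of this example is not shown in the snippet. A minimal sketch, assuming each aggregator collapses the augmentation axis of the (Aug, FID, D) array back to (FID, D); the 'normalized_mean' entry is a hypothetical addition:

import numpy as np

# Sketch of the assumed AGGREGATORS mapping: each entry reduces the
# leading augmentation axis of an (Aug, FID, D) array to (FID, D).
AGGREGATORS = {
    'mean': lambda emb: np.mean(emb, axis=0),
    'max': lambda emb: np.max(emb, axis=0),
    # Hypothetical: average first, then L2-normalize each embedding.
    'normalized_mean': lambda emb: (
        np.mean(emb, axis=0) /
        np.linalg.norm(np.mean(emb, axis=0), axis=-1, keepdims=True)),
}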
Ejemplo n.º 13
0
def prepare_data(args):
    '''
    Data preparation for training.

    :param args: all stored arguments
    :return: images: batch of prepared training images
             fids: file ids, i.e. the images' paths relative to image_root
             pids: person id (or car id) of each image
             max_fid_len: length of the longest fid (needed for logfiles)
    '''
    # Load the data from the CSV file.
    # pids - person id (array corresponding to the images)
    # fids - array of the paths to the images ({str_})
    pids, fids = common.load_dataset(args.train_dataset, args.image_root,
                                     False)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    # dataset.output_types = float32, dataset.output_shape = len(unique_pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    # default is args.batch_k = 4 (so it takes 4 images per each person)
    # dataset has len(unique_pids) tensors
    # create each tensor = (tensor_of_k_fids, tensor_of_pid)
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()
    return [images, fids, pids, max_fid_len]
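The sample_k_fids_for_pid helper mapped over the pid dataset above is not defined in the snippet. A sketch of what it plausibly does, matching the comment that each record becomes (tensor_of_k_fids, tensor_of_pid) and repeating files when a person has fewer than batch_k images:

import tensorflow as tf

# Sketch only: for one pid, draw batch_k file ids uniformly, repeating
# files if this pid has fewer than batch_k images available.
def sample_k_fids_for_pid(pid, all_fids, all_pids, batch_k):
    possible_fids = tf.boolean_mask(all_fids, tf.equal(all_pids, pid))
    count = tf.shape(possible_fids)[0]
    # Pad the index range to a multiple of count so batch_k indices can
    # always be drawn, then shuffle and keep the first batch_k.
    padded_count = tf.cast(
        tf.math.ceil(batch_k / tf.cast(count, tf.float32)), tf.int32) * count
    full_range = tf.math.mod(tf.range(padded_count), count)
    shuffled = tf.random.shuffle(full_range)
    selected_fids = tf.gather(possible_fids, shuffled[:batch_k])
    return selected_fids, tf.fill([batch_k], pid)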
Ejemplo n.º 14
0
def main():
    # Verify that parameters are set correctly.
    args = parser.parse_args()

    # Possibly auto-generate the output filename.
    if args.filename is None:
        basename = os.path.basename(args.dataset)
        args.filename = os.path.splitext(basename)[0] + '_embeddings.h5'
    args.filename = os.path.join(args.experiment_root, args.filename)

    # Load the args from the original experiment.
    args_file = os.path.join(args.experiment_root, 'args.json')

    if os.path.isfile(args_file):
        if not args.quiet:
            print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)

        # Add arguments from training.
        for key, value in args_resumed.items():
            args.__dict__.setdefault(key, value)

        # A couple special-cases and sanity checks
        if (args_resumed['crop_augment']) == (args.crop_augment is None):
            print('WARNING: crop augmentation differs between training and '
                  'evaluation.')
        args.image_root = args.image_root or args_resumed['image_root']
    else:
        raise IOError('`args.json` could not be found in: {}'.format(args_file))

    # Check a proper aggregator is provided if augmentation is used.
    if args.flip_augment or args.crop_augment == 'five':
        if args.aggregator is None:
            print('ERROR: Test time augmentation is performed but no aggregator'
                  'was specified.')
            exit(1)
    else:
        if args.aggregator is not None:
            print('ERROR: No test time augmentation that needs aggregating is '
                  'performed but an aggregator was specified.')
            exit(1)

    if not args.quiet:
        print('Evaluating using the following parameters:')
        for key, value in sorted(vars(args).items()):
            print('{}: {}'.format(key, value))

    # Load the data from the CSV file.
    _, data_fids = common.load_dataset(args.dataset, args.image_root)

    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    data_fid = data_fids[10]

    # Setup a tf Dataset containing all images.
    dataset = tf.data.Dataset.from_tensor_slices(data_fids)
    image = common.fid_to_image(data_fid, 'dummy',
                                image_root=args.image_root,
                                image_size=pre_crop_size
                                if args.crop_augment else net_input_size)

    # Convert filenames to actual image tensors.
    # dataset = dataset.map(
    #     lambda fid: common.fid_to_image(
    #         fid, 'dummy', image_root=args.image_root,
    #         image_size=pre_crop_size if args.crop_augment else net_input_size),
    #     num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    # `modifiers` is a list of strings that keeps track of which augmentations
    # have been applied, so that a human can understand it later on.
    # modifiers = ['original']
    # if args.flip_augment:
    #     dataset = dataset.map(flip_augment)
    #     dataset = dataset.apply(tf.contrib.data.unbatch())
    #     modifiers = [o + m for m in ['', '_flip'] for o in modifiers]
    #
    # if args.crop_augment == 'center':
    #     dataset = dataset.map(lambda im, fid, pid:
    #         (five_crops(im, net_input_size)[0], fid, pid))
    #     modifiers = [o + '_center' for o in modifiers]
    # elif args.crop_augment == 'five':
    #     dataset = dataset.map(lambda im, fid, pid:
    #         (tf.stack(five_crops(im, net_input_size)), [fid]*5, [pid]*5))
    #     dataset = dataset.apply(tf.contrib.data.unbatch())
    #     modifiers = [o + m for o in modifiers for m in [
    #         '_center', '_top_left', '_top_right', '_bottom_left', '_bottom_right']]
    # elif args.crop_augment == 'avgpool':
    #     modifiers = [o + '_avgpool' for o in modifiers]
    # else:
    #     modifiers = [o + '_resize' for o in modifiers]

    # Group it back into PK batches.
    # dataset = dataset.batch(args.batch_size)
    #
    # # Overlap producing and consuming.
    # dataset = dataset.prefetch(1)
    #
    # images, _, _ = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)
    # fid_to_image returns (image, fid, pid); keep the image and add a
    # batch axis. Note this assumes a 224x224 network input size.
    image = tf.reshape(image[0], [1, 224, 224, 3])
    endpoints, body_prefix = model.endpoints(image, is_training=False)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=False)

    with tf.Session() as sess:
        checkpoint = os.path.join(args.experiment_root, args.checkpoint)
        tf.train.Saver().restore(sess, checkpoint)
        layer_name = ['Conv2d_1_pointwise', 'Conv2d_3_pointwise',
                      'Conv2d_5_pointwise', 'Conv2d_11_pointwise',
                      'Conv2d_13_pointwise']
        # Fetch all five intermediate feature maps in a single run call.
        features = sess.run([endpoints[name] for name in layer_name])
        cols = 5
        rows = 1
        for layer, feature in zip(layer_name, features):
            h = feature.shape[1]
            w = feature.shape[2]
            filter_show = cols
            img_grid = np.zeros((h * rows, w * cols))

            # Tile the first `filter_show` channels into a rows x cols grid.
            for c in range(filter_show):
                f_r = math.ceil((c + 1) / cols)
                f_c = (c + 1) if f_r == 1 else (c + 1 - (f_r - 1) * cols)
                img_grid[(f_r - 1) * h:f_r * h,
                         (f_c - 1) * w:f_c * w] = feature[0, :, :, c]
            plt.figure()
            plt.imshow(img_grid, aspect='equal', cmap='viridis')
            plt.grid(False)
            plt.title(layer, fontsize=16)
        plt.show()
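All of these examples funnel filenames through common.fid_to_image, whose body is likewise not shown. A minimal sketch under the assumption that fids are JPEG paths relative to image_root and that the fid/pid values are passed through unchanged:

import tensorflow as tf

# Sketch only: read the file relative to image_root, decode it to a
# 3-channel tensor, resize to image_size, and forward fid and pid.
def fid_to_image(fid, pid, image_root, image_size):
    image_encoded = tf.io.read_file(tf.strings.join([image_root, '/', fid]))
    image_decoded = tf.image.decode_jpeg(image_encoded, channels=3)
    image_resized = tf.image.resize(image_decoded, image_size)
    return image_resized, fid, pid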
Ejemplo n.º 15
0
    def embed_csv_file(self,
                       image_root,
                       csv_file,
                       emb_file,
                       loading_threads=8,
                       flip=False,
                       crop=False,
                       aggregator='mean'):
        data_ids, data_fids, data_fols = common.load_dataset(
            csv_file, image_root)
        data_ids = data_ids.astype(np.int32)
        data_fols = data_fols.astype(np.int32)
        dataset = tf.data.Dataset.from_tensor_slices(data_fids)
        dataset = dataset.map(lambda fid: common.fid_to_image(
            fid,
            tf.constant('dummy'),
            image_root=image_root,
            image_size=self.pre_crop_size
            if self.crop_augment else self.net_input_size),
                              num_parallel_calls=loading_threads)

        # Augment the data if specified by the arguments.
        # `modifiers` is a list of strings that keeps track of which augmentations
        # have been applied, so that a human can understand it later on.
        modifiers = ['original']
        if flip:
            dataset = dataset.map(Embedder.flip_augment)
            dataset = dataset.apply(tf.contrib.data.unbatch())
            modifiers = [o + m for m in ['', '_flip'] for o in modifiers]
        print(flip, crop)  # Show which augmentation flags are in effect.
        if crop == 'center':
            dataset = dataset.map(lambda im, fid, pid: (Embedder.five_crops(
                im, self.net_input_size)[0], fid, pid))
            modifiers = [o + '_center' for o in modifiers]
        elif crop == 'five':
            dataset = dataset.map(lambda im, fid, pid: (
                tf.stack(Embedder.five_crops(im, self.net_input_size)),
                tf.stack([fid] * 5), tf.stack([pid] * 5)))
            dataset = dataset.apply(tf.contrib.data.unbatch())
            modifiers = [
                o + m for o in modifiers for m in [
                    '_center', '_top_left', '_top_right', '_bottom_left',
                    '_bottom_right'
                ]
            ]
        elif crop == 'avgpool':
            modifiers = [o + '_avgpool' for o in modifiers]
        else:
            modifiers = [o + '_resize' for o in modifiers]

        # Group it back into PK batches.
        dataset = dataset.batch(self.batch_size)

        # Overlap producing and consuming.
        dataset = dataset.prefetch(1)

        images, _, _ = dataset.make_one_shot_iterator().get_next()

        endpoints, body_prefix = self.model.endpoints(images,
                                                      is_training=False)
        with tf.name_scope('head'):
            endpoints = self.head.head(endpoints,
                                       self.embedding_dim,
                                       is_training=False)

#         emb_file = os.path.join(self.exp_root, emb_file)
        print("Save h5 file to: ", emb_file)
        with h5py.File(emb_file, 'w') as f_out, tf.Session() as sess:
            # Initialize the network/load the checkpoint.
            checkpoint = tf.train.latest_checkpoint(self.exp_root)
            print('Restoring from checkpoint: {}'.format(checkpoint))
            tf.train.Saver().restore(sess, checkpoint)

            # Go ahead and embed the whole dataset, with all augmented versions too.
            emb_storage = np.zeros(
                (len(data_fids) * len(modifiers), self.embedding_dim),
                np.float32)
            for start_idx in count(step=self.batch_size):
                try:
                    emb = sess.run(endpoints['emb'])
                    print('\rEmbedded batch {}-{}/{}'.format(
                        start_idx, start_idx + len(emb), len(emb_storage)),
                          flush=True,
                          end='')
                    emb_storage[start_idx:start_idx + len(emb)] = emb
                except tf.errors.OutOfRangeError:
                    break  # This just indicates the end of the dataset.

            print()
            print("Done with embedding, aggregating augmentations...",
                  flush=True)
            print(emb_storage.shape)
            if len(modifiers) > 1:
                # Pull out the augmentations into a separate first dimension.
                emb_storage = emb_storage.reshape(len(data_fids),
                                                  len(modifiers), -1)
                emb_storage = emb_storage.transpose(
                    (1, 0, 2))  # (Aug,FID,128D)

                # Store the embedding of all individual variants too.
                emb_dataset = f_out.create_dataset('emb_aug', data=emb_storage)

                # Aggregate according to the specified parameter.
                emb_storage = AGGREGATORS[aggregator](emb_storage)
            print(emb_storage.shape)
            # Store the final embeddings.
            f_out.create_dataset('emb', data=emb_storage)
            f_out.create_dataset('id', data=data_ids)
            f_out.create_dataset('fol_id', data=data_fols)

            # Store information about the produced augmentation and in case no crop
            # augmentation was used, if the images are resized or avg pooled.
            f_out.create_dataset('augmentation_types',
                                 data=np.asarray(modifiers, dtype='|S'))

        tf.reset_default_graph()
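Ejemplos n.º 12, 15 and 16 all call a five_crops helper whose definition is omitted. A sketch, assuming crop_size is a (height, width) tuple and that the return order matches the '_center', '_top_left', '_top_right', '_bottom_left', '_bottom_right' modifier suffixes used when aggregating:

import tensorflow as tf

# Sketch only: return the center crop plus the four corner crops of
# image, each of size crop_size = (height, width).
def five_crops(image, crop_size):
    h, w = crop_size
    top_left_corner = (tf.shape(image)[:2] - crop_size) // 2
    center = image[top_left_corner[0]:top_left_corner[0] + h,
                   top_left_corner[1]:top_left_corner[1] + w]
    top_left = image[:h, :w]
    top_right = image[:h, -w:]
    bottom_left = image[-h:, :w]
    bottom_right = image[-h:, -w:]
    return center, top_left, top_right, bottom_left, bottom_right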
Ejemplo n.º 16
0
def main():
    # Verify that parameters are set correctly.
    args = parser.parse_args()

    # Possibly auto-generate the output filename.
    if args.filename is None:
        basename = os.path.basename(args.dataset)
        args.filename = os.path.splitext(basename)[0] + '_embeddings.h5'
    args.filename = os.path.join(args.experiment_root, args.filename)

    # Load the args from the original experiment.
    args_file = os.path.join(args.experiment_root, 'args.json')

    if os.path.isfile(args_file):
        if not args.quiet:
            print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)

        # Add arguments from training.
        for key, value in args_resumed.items():
            args.__dict__.setdefault(key, value)

        # A couple special-cases and sanity checks
        if (args_resumed['crop_augment']) == (args.crop_augment is None):
            print('WARNING: crop augmentation differs between training and '
                  'evaluation.')
        args.image_root = args.image_root or args_resumed['image_root']
    else:
        raise IOError(
            '`args.json` could not be found in: {}'.format(args_file))

    # Check a proper aggregator is provided if augmentation is used.
    if args.flip_augment or args.crop_augment == 'five':
        if args.aggregator is None:
            print(
                'ERROR: Test time augmentation is performed but no aggregator '
                'was specified.')
            exit(1)
    else:
        if args.aggregator is not None:
            print('ERROR: No test time augmentation that needs aggregating is '
                  'performed but an aggregator was specified.')
            exit(1)

    if not args.quiet:
        print('Evaluating using the following parameters:')
        for key, value in sorted(vars(args).items()):
            print('{}: {}'.format(key, value))

    # Load the data from the txt file.
    data_fids = common.load_from_txt(args.dataset, args.image_root)

    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    # Setup a tf Dataset containing all images.
    dataset = tf.data.Dataset.from_tensor_slices(data_fids)

    # Convert filenames to actual image tensors.
    dataset = dataset.map(lambda fid: common.fid_to_image(
        fid,
        'dummy',
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    # `modifiers` is a list of strings that keeps track of which augmentations
    # have been applied, so that a human can understand it later on.
    modifiers = ['original']
    if args.flip_augment:
        dataset = dataset.map(flip_augment)
        dataset = dataset.apply(tf.contrib.data.unbatch())
        modifiers = [o + m for m in ['', '_flip'] for o in modifiers]

    if args.crop_augment == 'center':
        dataset = dataset.map(lambda im, fid, pid:
                              (five_crops(im, net_input_size)[0], fid, pid))
        modifiers = [o + '_center' for o in modifiers]
    elif args.crop_augment == 'five':
        dataset = dataset.map(lambda im, fid, pid: (tf.stack(
            five_crops(im, net_input_size)), [fid] * 5, [pid] * 5))
        dataset = dataset.apply(tf.contrib.data.unbatch())
        modifiers = [
            o + m for o in modifiers for m in [
                '_center', '_top_left', '_top_right', '_bottom_left',
                '_bottom_right'
            ]
        ]
    elif args.crop_augment == 'avgpool':
        modifiers = [o + '_avgpool' for o in modifiers]
    else:
        modifiers = [o + '_resize' for o in modifiers]

    # Group it back into PK batches.
    dataset = dataset.batch(args.batch_size)

    # Overlap producing and consuming.
    dataset = dataset.prefetch(1)

    images, _, _ = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    endpoints, body_prefix = model.endpoints(images, is_training=False)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=False)

    # `config` is assumed to be a tf.ConfigProto defined at module level.
    with h5py.File(args.filename,
                   'w') as f_out, tf.Session(config=config) as sess:
        # Initialize the network/load the checkpoint.
        if args.checkpoint is None:
            checkpoint = tf.train.latest_checkpoint(args.experiment_root)
        else:
            checkpoint = os.path.join(args.experiment_root, args.checkpoint)
        if not args.quiet:
            print('Restoring from checkpoint: {}'.format(checkpoint))
        tf.train.Saver().restore(sess, checkpoint)

        # Go ahead and embed the whole dataset, with all augmented versions too.
        emb_storage = np.zeros(
            (len(data_fids) * len(modifiers), args.embedding_dim), np.float32)
        for start_idx in count(step=args.batch_size):
            try:
                emb = sess.run(endpoints['emb'])
                print('\rEmbedded batch {}-{}/{}'.format(
                    start_idx, start_idx + len(emb), len(emb_storage)),
                      flush=True,
                      end='')
                emb_storage[start_idx:start_idx + len(emb)] = emb
            except tf.errors.OutOfRangeError:
                break  # This just indicates the end of the dataset.

        print()
        if not args.quiet:
            print("Done with embedding, aggregating augmentations...",
                  flush=True)

        if len(modifiers) > 1:
            # Pull out the augmentations into a separate first dimension.
            emb_storage = emb_storage.reshape(len(data_fids), len(modifiers),
                                              -1)
            emb_storage = emb_storage.transpose((1, 0, 2))  # (Aug,FID,128D)

            # Store the embedding of all individual variants too.
            emb_dataset = f_out.create_dataset('emb_aug', data=emb_storage)

            # Aggregate according to the specified parameter.
            emb_storage = AGGREGATORS[args.aggregator](emb_storage)

        # Store the final embeddings.
        emb_dataset = f_out.create_dataset('emb', data=emb_storage)

        # Store information about the produced augmentation and in case no crop
        # augmentation was used, if the images are resized or avg pooled.
        f_out.create_dataset('augmentation_types',
                             data=np.asarray(modifiers, dtype='|S'))
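The flip_augment map applied before tf.contrib.data.unbatch() in the examples above is also not shown. A sketch: it turns each (image, fid, pid) record into a length-2 batch holding the original and the horizontally flipped image, which unbatch() then flattens, exactly the doubling the ['', '_flip'] modifiers track:

import tensorflow as tf

# Sketch only: stack the original and horizontally flipped image so a
# following unbatch() yields both variants, fid/pid duplicated to match.
def flip_augment(image, fid, pid):
    images = tf.stack([image, tf.reverse(image, axis=[1])])
    return images, tf.stack([fid] * 2), tf.stack([pid] * 2)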
Ejemplo n.º 17
0
def main(argv):

    args = parser.parse_args(argv)

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # tf.compat.v1.disable_eager_execution()

    # physical_devices = tf.config.experimental.list_physical_devices('GPU')
    # tf.config.experimental.set_memory_growth(physical_devices[0], True)

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    if len(unique_pids) < args.batch_p:
        unique_pids = np.tile(unique_pids,
                              int(np.ceil(args.batch_p / len(unique_pids))))
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.unbatch()

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments. (The images were
    # already decoded and resized by the fid_to_image map above.)

    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.image.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Create the model and an embedding head.
    tf.keras.backend.set_learning_phase(1)
    emb_model = EmbeddingModel(args)

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.map(lambda im, fid, pid:
                          (emb_model.preprocess_input(im), fid, pid))
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.

    # all_trainable_variables = embedding_head.trainable_variables+base_model.trainable_variables

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.optimizers.schedules.PolynomialDecay(
            args.learning_rate, args.train_iterations, end_learning_rate=1e-7)
    else:
        learning_rate = args.learning_rate

    if args.optimizer == 'adam':
        optimizer = tf.keras.optimizers.Adam(learning_rate)
    elif args.optimizer == 'momentum':
        optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=0.9)
    else:
        raise NotImplementedError('Invalid optimizer {}'.format(
            args.optimizer))

    @tf.function
    def train_step(images, pids):

        with tf.GradientTape() as tape:
            batch_embedding = emb_model(images)
            if args.loss == 'semi_hard_triplet':
                embedding_loss = triplet_semihard_loss(batch_embedding, pids,
                                                       args.margin)
            elif args.loss == 'hard_triplet':
                embedding_loss = batch_hard(batch_embedding, pids, args.margin,
                                            args.metric)
            elif args.loss == 'lifted_loss':
                embedding_loss = lifted_loss(pids,
                                             batch_embedding,
                                             margin=args.margin)
            elif args.loss == 'contrastive_loss':
                assert batch_size % 2 == 0
                assert args.batch_k == 4  # Works with other values too, but would need tuning.

                # Reorder each group of eight embeddings (two pids with
                # batch_k = 4) so that consecutive pairs form two positive
                # and two negative pairs, matching fixed_labels below.
                contrastive_idx = np.tile([0, 1, 4, 3, 2, 5, 6, 7],
                                          args.batch_p // 2)
                for i in range(args.batch_p // 2):
                    contrastive_idx[i * 8:i * 8 + 8] += i * 8

                contrastive_idx = np.expand_dims(contrastive_idx, 1)
                batch_embedding_ordered = tf.gather_nd(batch_embedding,
                                                       contrastive_idx)
                pids_ordered = tf.gather_nd(pids, contrastive_idx)
                # batch_embedding_ordered = tf.Print(batch_embedding_ordered,[pids_ordered],'pids_ordered :: ',summarize=1000)
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding_ordered,
                               [-1, 2, args.embedding_dim]), 2, 1)
                # embeddings_anchor = tf.Print(embeddings_anchor,[pids_ordered,embeddings_anchor,embeddings_positive,batch_embedding,batch_embedding_ordered],"Tensors ", summarize=1000)

                fixed_labels = np.tile([1, 0, 0, 1], args.batch_p // 2)
                # fixed_labels = np.reshape(fixed_labels,(len(fixed_labels),1))
                # print(fixed_labels)
                labels = tf.constant(fixed_labels)
                # labels = tf.Print(labels,[labels],'labels ',summarize=1000)
                embedding_loss = contrastive_loss(labels,
                                                  embeddings_anchor,
                                                  embeddings_positive,
                                                  margin=args.margin)
            elif args.loss == 'angular_loss':
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]),
                    2, 1)
                # pids = tf.Print(pids, [pids], 'pids:: ', summarize=100)
                pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
                # pids = tf.Print(pids,[pids],'pids:: ',summarize=100)
                embedding_loss = angular_loss(pids,
                                              embeddings_anchor,
                                              embeddings_positive,
                                              batch_size=args.batch_p,
                                              with_l2reg=True)

            elif args.loss == 'npairs_loss':
                assert args.batch_k == 2  # A single positive pair per class.
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]),
                    2, 1)
                pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
                pids = tf.reshape(pids, [-1])
                embedding_loss = npairs_loss(pids, embeddings_anchor,
                                             embeddings_positive)

            else:
                raise NotImplementedError('Invalid Loss {}'.format(args.loss))
            loss_mean = tf.reduce_mean(embedding_loss)

        gradients = tape.gradient(loss_mean, emb_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients,
                                      emb_model.trainable_variables))

        return embedding_loss

    # sess = tf.compat.v1.Session()
    # start_step = sess.run(global_step)
    # checkpoint_saver = tf.train.Saver(max_to_keep=2)
    start_step = 0
    log.info('Starting training from iteration {}.'.format(start_step))
    dataset_iter = iter(dataset)

    ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                               optimizer=optimizer,
                               net=emb_model)
    manager = tf.train.CheckpointManager(ckpt,
                                         osp.join(args.experiment_root,
                                                  'tf_ckpts'),
                                         max_to_keep=3)

    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
        for i in range(ckpt.step.numpy(), args.train_iterations):
            # for batch_idx, batch in enumerate():
            start_time = time.time()
            images, fids, pids = next(dataset_iter)
            batch_loss = train_step(images, pids)
            elapsed_time = time.time() - start_time
            seconds_todo = (args.train_iterations - i) * elapsed_time
            # print(tf.reduce_min(batch_loss).numpy(),tf.reduce_mean(batch_loss).numpy(),tf.reduce_max(batch_loss).numpy())
            log.info(
                'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ETA: {} ({:.2f}s/it)'
                .format(
                    i,
                    tf.reduce_min(batch_loss).numpy(),
                    tf.reduce_mean(batch_loss).numpy(),
                    tf.reduce_max(batch_loss).numpy(),
                    # args.batch_k - 1, float(b_prec_at_k),
                    timedelta(seconds=int(seconds_todo)),
                    elapsed_time))

            ckpt.step.assign_add(1)
            if (args.checkpoint_frequency > 0
                    and i % args.checkpoint_frequency == 0):

                # uncomment if you want to save the model weight separately
                # emb_model.save_weights(os.path.join(args.experiment_root, 'model_weights_{0:04d}.w'.format(i)))

                manager.save()

            # Stop the main-loop at the end of the step, if requested.
            if u.interrupted:
                log.info("Interrupted on request!")
                break
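The batch_hard loss selected via args.loss above is not defined in this example either. A compact sketch of the usual batch-hard formulation (hardest positive and hardest negative per anchor), assuming a Euclidean metric and a margin that is either a number or the string 'soft':

import tensorflow as tf

# Sketch only: batch-hard triplet loss. For each anchor, take the
# farthest same-pid embedding and the closest different-pid embedding,
# then apply a hard or soft margin.
def batch_hard(embeddings, pids, margin, metric='euclidean'):
    # Pairwise Euclidean distances, clipped for numerical stability.
    dots = tf.matmul(embeddings, embeddings, transpose_b=True)
    sq_norms = tf.linalg.diag_part(dots)
    dists = tf.sqrt(tf.maximum(
        sq_norms[:, None] - 2.0 * dots + sq_norms[None, :], 1e-12))

    n = tf.shape(pids)[0]
    same_pid = tf.equal(pids[:, None], pids[None, :])
    diagonal = tf.equal(tf.range(n)[:, None], tf.range(n)[None, :])
    positive_mask = tf.logical_and(same_pid, tf.logical_not(diagonal))
    negative_mask = tf.logical_not(same_pid)

    hardest_positive = tf.reduce_max(
        tf.where(positive_mask, dists, tf.zeros_like(dists)), axis=1)
    hardest_negative = tf.reduce_min(
        tf.where(negative_mask, dists,
                 tf.fill(tf.shape(dists), float('inf'))), axis=1)

    diff = hardest_positive - hardest_negative
    if margin == 'soft':
        return tf.math.softplus(diff)  # log(1 + exp(x)), a smooth hinge.
    return tf.maximum(diff + float(margin), 0.0)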