def main():
    """Train an embedding network with triplet losses and reconstruction decoders.

    Parses the command-line arguments (optionally restoring them from
    ``<experiment_root>/args.json`` when ``--resume`` is given), builds a
    PK-sampled ``tf.data`` input pipeline, feeds the images through the model
    and embedding head, computes batch-soft and batch-hard triplet losses on the
    same distance matrix, adds three decoder branches that reconstruct the
    input images from the embedding at three resolutions, and runs the
    training loop with TensorBoard summaries and optional mem-mapped
    detailed logs.

    NOTE(review): this excerpt is not syntactically complete. Several
    statements appear to have lost lines (e.g. the body of
    ``with open(args_file, 'w')``, the remaining arguments of many
    ``tf.layers.conv2d`` / ``tf.image.resize_nearest_neighbor`` /
    ``tf.losses.mean_squared_error`` calls, and the ``else:`` branches of
    several conditionals). Restore them from version control before running.
    Also note that a later ``def main`` in this file shadows this one.
    """
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        # NOTE(review): the "new argument" warning below is indented inside
        # `if key in args_resumed:`, so it fires for every key that IS in
        # args.json; an `else:` for keys absent from the file seems to have
        # been lost.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

        # If the experiment directory exists already, we bail in fear.
        # NOTE(review): despite the comment, this only prints a message, and it
        # currently sits inside the `if args.resume:` branch; the enclosing
        # `else:` and the exit after the message appear to be missing.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        # NOTE(review): the body of this `with` statement is missing.
        with open(args_file, 'w') as f:

    # NOTE(review): `log_file` is unused in the visible code; whatever logging
    # setup consumed it appears to be missing.
    log_file = os.path.join(args.experiment_root, "train")
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    # NOTE(review): these checks only log an error; execution still continues
    # to `common.load_dataset` below.
    if not args.train_set:
        log.error("You did not specify the `train_set` argument!")
    if not args.image_root:
        log.error("You did not specify the required `image_root` argument!")

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    # NOTE(review): the `map(...)` below is left unclosed; the leading arguments
    # of `common.fid_to_image` (presumably `fid, pid, image_root=...`) and the
    # rest of the call appear to be missing.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        image_size=pre_crop_size if args.crop_augment else net_input_size),

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    # (This rebinds `fids`/`pids` from the loaded arrays to the batch tensors.)
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    ln = loss.LOSS_CHOICES
    # Both the batch-soft and the batch-hard loss use the same distance matrix.
    # NOTE(review): the second call rebinds `train_top1`, `prec_at_k`,
    # `neg_dists` and `pos_dists`, so the summaries and logs below report the
    # batch-hard values; only `losses` vs. `losses1` keep both variants.
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = ln['batch_soft'](
        dists, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    losses1, train_top1, prec_at_k, _, neg_dists, pos_dists = ln['batch_hard'](
        dists, pids, args.margin, batch_precision_at_k=args.batch_k - 1)

    # NOTE(review): leftover debug print.
    print('kaju', ln['batch_soft'])

    # Reconstruction decoder 1: a dense layer maps the embedding to
    # 5120 = 32 * 16 * 10 units, reshaped into a (batch, 32, 16, 10) map, then
    # three rounds of nearest-neighbour upsampling (presumably to the doubled
    # spatial sizes `unp3shape` / `unp2shape` / `unp1shape`), each followed by
    # a 5x5 convolution.
    # NOTE(review): the remaining arguments of the resize/conv calls below
    # (target size, filters, padding, activation, ...) appear to be missing.
    decDense = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDense')  #  ,activation = tf.nn.relu  ################
    unflat = tf.reshape(decDense, shape=[tf.shape(decDense)[0], 32, 16, 10])
    unp3shape = tf.TensorShape(
        [2 * di for di in unflat.get_shape().as_list()[1:-1]])
    unPool3 = tf.image.resize_nearest_neighbor(unflat,

    deConv3 = tf.layers.conv2d(inputs=unPool3,
                               kernel_size=[5, 5],
                               strides=(1, 1),
    unp2shape = tf.TensorShape(
        [2 * di for di in deConv3.get_shape().as_list()[1:-1]])
    unPool2 = tf.image.resize_nearest_neighbor(deConv3,
    deConv2 = tf.layers.conv2d(inputs=unPool2,
                               kernel_size=[5, 5],
                               strides=(1, 1),
    unp1shape = tf.TensorShape(
        [2 * di for di in deConv2.get_shape().as_list()[1:-1]])
    unPool1 = tf.image.resize_nearest_neighbor(deConv2,
    deConv1 = tf.layers.conv2d(inputs=unPool1,
                               kernel_size=[5, 5],
                               strides=(1, 1),
    imClip = deConv1  #tf.clip_by_value(t = deConv1,clip_value_min = -1.0,clip_value_max = 1.0,name='clipRelu')
    print('RconstructeddImage :  ', imClip.name)
    # NOTE(review): the arguments of this MSE loss (presumably the input
    # `images` and `imClip`) are missing.
    recLoss = tf.losses.mean_squared_error(
    print('recLoss :  ', recLoss.name)
    # Display-only copies of the reconstruction and of the input, used by the
    # image summaries below.
    # NOTE(review): the clip bounds of `toshow` and the target dtype / closing
    # parenthesis of the `tf.cast` appear to be missing.
    toshow = tf.clip_by_value(t=imClip,
                              name='clipRelu')  #tf.cast(,tf.uint8)
    toshowOrg = tf.cast(
        tf.clip_by_value(t=images, clip_value_min=0, clip_value_max=255.0),

    # Reconstruction decoder 2: same structure as decoder 1, starting from
    # 1280 = 16 * 8 * 10 units, i.e. a (batch, 16, 8, 10) map. Assuming each
    # resize doubles the size and the convs keep it ('same' padding), the
    # output is 128x64, the size `images2` is resized to below.
    decDense2 = tf.layers.dense(
        inputs=endpoints['emb'], units=1280,
        name='decDense2')  #  ,activation = tf.nn.relu  ################
    unflat2 = tf.reshape(decDense2, shape=[tf.shape(decDense2)[0], 16, 8, 10])
    unp3shape2 = tf.TensorShape(
        [2 * di for di in unflat2.get_shape().as_list()[1:-1]])
    unPool32 = tf.image.resize_nearest_neighbor(unflat2,

    deConv32 = tf.layers.conv2d(inputs=unPool32,
                                kernel_size=[5, 5],
                                strides=(1, 1),
    unp2shape2 = tf.TensorShape(
        [2 * di for di in deConv32.get_shape().as_list()[1:-1]])
    unPool22 = tf.image.resize_nearest_neighbor(deConv32,
    print('a', deConv32.shape)
    deConv22 = tf.layers.conv2d(inputs=unPool22,
                                kernel_size=[5, 5],
                                strides=(1, 1),
    unp1shape2 = tf.TensorShape(
        [2 * di for di in deConv22.get_shape().as_list()[1:-1]])
    unPool12 = tf.image.resize_nearest_neighbor(deConv22,
    print('k', deConv22.shape)
    deConv12 = tf.layers.conv2d(inputs=unPool12,
                                kernel_size=[5, 5],
                                strides=(1, 1),
    imClip1 = deConv12
    print('RconstructeddImage :  ', imClip1.name)

    # Input downscaled to 128x64, presumably the target for decoder 2.
    images2 = tf.image.resize_images(images, [128, 64])

    # NOTE(review): the arguments of this MSE loss are missing.
    recLoss2 = tf.losses.mean_squared_error(

    # Reconstruction decoder 3: same structure again, starting from
    # 320 = 8 * 4 * 10 units, i.e. a (batch, 8, 4, 10) map; under the same
    # assumptions its output is 64x32, matching `images3` below.
    decDense3 = tf.layers.dense(
        inputs=endpoints['emb'], units=320,
        name='decDense3')  #  ,activation = tf.nn.relu  ################
    unflat23 = tf.reshape(decDense3, shape=[tf.shape(decDense3)[0], 8, 4, 10])
    unp3shape23 = tf.TensorShape(
        [2 * di for di in unflat23.get_shape().as_list()[1:-1]])
    unPool323 = tf.image.resize_nearest_neighbor(unflat23,

    deConv323 = tf.layers.conv2d(inputs=unPool323,
                                 kernel_size=[5, 5],
                                 strides=(1, 1),
    unp2shape23 = tf.TensorShape(
        [2 * di for di in deConv323.get_shape().as_list()[1:-1]])
    unPool223 = tf.image.resize_nearest_neighbor(deConv323,
    deConv223 = tf.layers.conv2d(inputs=unPool223,
                                 kernel_size=[5, 5],
                                 strides=(1, 1),
    unp1shape23 = tf.TensorShape(
        [2 * di for di in deConv223.get_shape().as_list()[1:-1]])
    unPool123 = tf.image.resize_nearest_neighbor(deConv223,
    deConv123 = tf.layers.conv2d(inputs=unPool123,
                                 kernel_size=[5, 5],
                                 strides=(1, 1),
    imClip2 = deConv123
    # Input downscaled to 64x32, presumably the target for decoder 3.
    images3 = tf.image.resize_images(images, [64, 32])

    # NOTE(review): the arguments of this MSE loss are missing.
    recLoss3 = tf.losses.mean_squared_error(

    # Count the number of active entries, and compute the total batch loss.
    # (`num_active` and `loss_mean` use the batch-soft `losses`; `loss_mean1`
    # is the batch-hard counterpart.)
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)
    loss_mean1 = tf.reduce_mean(losses1)
    # Mean value of decoder 1's output, logged as 'recImagemean'.
    recMean = tf.reduce_mean(imClip)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('recLoss', recLoss)
    tf.summary.scalar('recImagemean', recMean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    # NOTE(review): the opening line of the next statement (presumably a
    # histogram of the raw embedding norms) is missing.
                         tf.norm(endpoints['emb_raw'], axis=1))
    # NOTE(review): both image summaries use the name 'Batchimage'; TF1 will
    # uniquify the second one, so distinct names (reconstruction vs. original
    # input) would make TensorBoard easier to read.
    tf.summary.image('Batchimage', toshow, max_outputs=4)
    tf.summary.image('Batchimage', toshowOrg, max_outputs=4)

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    # NOTE(review): the scope argument (presumably `body_prefix`) and the
    # closing parenthesis of this call are missing.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    # NOTE(review): the `else:` before the constant-rate assignment is missing.
    # As written, the decayed schedule is immediately overwritten, and
    # `learning_rate` is undefined when the condition is false. The
    # `exponential_decay` call also appears to lack its initial-learning-rate
    # argument.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            tf.maximum(0, global_step - args.decay_start_iteration),
            args.train_iterations - args.decay_start_iteration, 0.001)
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    # The reconstruction MSEs are divided by 255**2, presumably to rescale
    # errors measured on a 0-255 pixel range.
    # NOTE(review): in the visible lines `train_op` and `train_op1` are built
    # from identical expressions and the parentheses do not balance. The terms
    # that should distinguish them (presumably `loss_mean` vs. `loss_mean1`)
    # and the `global_step` argument appear to be missing.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(tf.add(
                recLoss / (255.0 * 255.0),
                tf.add(recLoss2 / (255.0 * 255.0),
                       recLoss3 / (255.0 * 255.0)))),
        train_op1 = optimizer.minimize(tf.add(
                recLoss / (255.0 * 255.0),
                tf.add(recLoss2 / (255.0 * 255.0),
                       recLoss3 / (255.0 * 255.0)))),

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session() as sess:
        # NOTE(review): the `else:` separating the resume path from the
        # fresh-start path below is missing. The fresh-start path covers
        # initial-checkpoint loading and the initial save described in the
        # comments, which is itself absent.
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproduceable experiments.

        # NOTE(review): the remaining FileWriter arguments are missing.
        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                # Use the batch-hard loss/op for the first 3000 iterations.
                # NOTE(review): the `else:` before the two self-assignments is
                # missing, and they are no-ops either way. Because `losses` and
                # `train_op` are rebound here, once any i < 3000 is seen they
                # keep pointing at the batch-hard tensors for every later
                # iteration. Use separate local names to actually switch back.
                if i < 3000:
                    losses = losses1
                    train_op = train_op1
                    losses = losses
                    train_op = train_op

                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids ,b_rec= \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids,recLoss])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                # NOTE(review): `summary2` is written empty; the line(s) adding
                # the timing value to it appear to be missing.
                summary2 = tf.Summary()
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                # NOTE(review): the opening `log.info(` of the message below is
                # missing. The format string also has 9 placeholders but only 8
                # arguments: `b_prec_at_k` (fetched above) appears to be missing
                # after `args.batch_k - 1`, which shifts the timedelta onto the
                # `{:.2%}` placeholder.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'recLoss: {:.3f} batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.
                    format(step, float(np.min(b_loss)), float(np.mean(b_loss)),
                           float(np.max(b_loss)), b_rec, args.batch_k - 1,
                           timedelta(seconds=int(seconds_todo)), elapsed_time))

                # Save a checkpoint of training every so often.
                # NOTE(review): the save call in this branch is missing.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):

                # Stop the main-loop at the end of the step, if requested.
                # NOTE(review): the `break` after this log message is missing.
                if u.interrupted:
                    log.info("Interrupted on request!")

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        # NOTE(review): only a fragment of the final save call remains.
                              os.path.join(args.experiment_root, 'checkpoint'),
def main(argv):
    """Embed every image of ``args.dataset`` and store the result in HDF5.

    Arguments from the training run are merged in from
    ``<experiment_root>/args.json``. Because ``setdefault`` is used, only keys
    that this script's parser does not already define are taken from the
    file. Optional test-time flip/crop augmentation produces several
    embeddings per image. When more than one variant exists, the per-variant
    embeddings are stored as ``emb_aug`` and reduced with
    ``AGGREGATORS[args.aggregator]`` into the final ``emb`` dataset.

    NOTE(review): this excerpt is not syntactically complete; several
    statement bodies, call arguments and ``else:``/``try:`` lines appear to be
    missing. Restore them from version control before running. A later
    ``def main`` in this file shadows this one.

    Args:
        argv: argument list passed to ``parser.parse_args``.
    """
    # Verify that parameters are set correctly.
    args = parser.parse_args(argv)

    # NOTE(review): the body of this check (presumably an error/exit when the
    # dataset file does not exist) is missing.
    if not os.path.exists(args.dataset):

    # Possibly auto-generate the output filename.
    if args.filename is None:
        basename = os.path.basename(args.dataset)
        args.filename = os.path.splitext(basename)[0] + '_embeddings.h5'

    os_utils.touch_dir(os.path.join(args.experiment_root, args.foldername))

    # NOTE(review): `log_file`, `log` and `var_filepath` are not used in the
    # visible code.
    log_file = os.path.join(args.experiment_root, args.foldername, "embed")
    log = logging.getLogger('embed')

    # NOTE(review): the remaining arguments of the first join (presumably
    # `args.filename`) are missing. `[:-3]` assumes the filename ends in '.h5'.
    args.filename = os.path.join(args.experiment_root, args.foldername,
    var_filepath = os.path.join(args.experiment_root, args.foldername,
                                args.filename[:-3] + '_var.txt')
    # Load the args from the original experiment.
    args_file = os.path.join(args.experiment_root, 'args.json')

    if os.path.isfile(args_file):
        if not args.quiet:
            print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)

        # Add arguments from training.
        # `setdefault` keeps every attribute this script's parser already
        # defined (even if its value is None); only training-only keys are added.
        for key, value in args_resumed.items():
            args.__dict__.setdefault(key, value)

        # A couple special-cases and sanity checks
        # The crop check warns when exactly one of "training used crop
        # augmentation" / "evaluation sets `crop_augment`" holds (assuming the
        # stored training value is a bool).
        # NOTE(review): the rest of the warning message is missing, and so is
        # the `else:` before the `raise`; as written the IOError is raised even
        # when args.json was found.
        if (args_resumed['crop_augment']) == (args.crop_augment is None):
            print('WARNING: crop augmentation differs between training and '
        args.image_root = args.image_root or args_resumed['image_root']
        raise IOError(
            '`args.json` could not be found in: {}'.format(args_file))

    # Check a proper aggregator is provided if augmentation is used.
    # NOTE(review): the opening `print(` of the first error message and the
    # `else:` before the second check are missing. The first message's two
    # literals also concatenate without a space ("aggregatorwas").
    if args.flip_augment or args.crop_augment == 'five':
        if args.aggregator is None:
                'ERROR: Test time augmentation is performed but no aggregator'
                'was specified.')
        if args.aggregator is not None:
            print('ERROR: No test time augmentation that needs aggregating is '
                  'performed but an aggregator was specified.')

    if not args.quiet:
        print('Evaluating using the following parameters:')
        for key, value in sorted(vars(args).items()):
            print('{}: {}'.format(key, value))

    # Load the data from the CSV file.
    _, data_fids = common.load_dataset(args.dataset, args.image_root)

    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    # Setup a tf Dataset containing all images.
    dataset = tf.data.Dataset.from_tensor_slices(data_fids)

    # Convert filenames to actual image tensors.
    # NOTE(review): this `map(...)` is left unclosed and the leading arguments
    # of `common.fid_to_image` are missing.
    dataset = dataset.map(lambda fid: common.fid_to_image(
        image_size=pre_crop_size if args.crop_augment else net_input_size),

    # Augment the data if specified by the arguments.
    # `modifiers` is a list of strings that keeps track of which augmentations
    # have been applied, so that a human can understand it later on.
    modifiers = ['original']
    if args.flip_augment:
        dataset = dataset.map(flip_augment)
        dataset = dataset.apply(tf.contrib.data.unbatch())
        modifiers = [o + m for m in ['', '_flip'] for o in modifiers]

    if args.crop_augment == 'center':
        dataset = dataset.map(lambda im, fid, pid:
                              (five_crops(im, net_input_size)[0], fid, pid))
        modifiers = [o + '_center' for o in modifiers]
    elif args.crop_augment == 'five':
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.stack(five_crops(im, net_input_size)),
                               tf.stack([fid] * 5), tf.stack([pid] * 5)))
        dataset = dataset.apply(tf.contrib.data.unbatch())
        modifiers = [
            o + m for o in modifiers for m in [
                '_center', '_top_left', '_top_right', '_bottom_left',
    # NOTE(review): the end of the crop-name list above (presumably
    # '_bottom_right' and the closing brackets) is missing. So is the `else:`
    # before the '_resize' assignment below, which means the avgpool case
    # currently appends both suffixes.
    elif args.crop_augment == 'avgpool':
        modifiers = [o + '_avgpool' for o in modifiers]
        modifiers = [o + '_resize' for o in modifiers]

    # Group it back into batches of `args.batch_size`.
    dataset = dataset.batch(args.batch_size)

    # Overlap producing and consuming.
    dataset = dataset.prefetch(1)

    #images, _, _ = dataset.make_one_shot_iterator().get_next()
    #init_iter = dataset.make_initializable_iterator()
    # A reinitializable iterator yields the image batches, which are then fed
    # back into the network through the `images_ph` placeholder.
    # NOTE(review): the remaining arguments of `from_structure` and of
    # `tf.placeholder` are missing.
    init_iter = tf.data.Iterator.from_structure(dataset.output_types,
    images, _, _ = init_iter.get_next()
    iter_init_op = init_iter.make_initializer(dataset)
    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    images_ph = tf.placeholder(dataset.output_types[0],
    endpoints, body_prefix = model.endpoints(images_ph, is_training=False)

    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=False)

    gpu_options = tf.GPUOptions(allow_growth=True)
    gpu_config = tf.ConfigProto(gpu_options=gpu_options)
    with h5py.File(args.filename,
                   'w') as f_out, tf.Session(config=gpu_config) as sess:
        # Initialize the network/load the checkpoint.
        # NOTE(review): the `else:` between these two assignments is missing.
        # As written, a None `args.checkpoint` reaches `os.path.join` and fails
        # there. A non-None one leaves `checkpoint` undefined.
        if args.checkpoint is None:
            checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            checkpoint = os.path.join(args.experiment_root, args.checkpoint)
        if not args.quiet:
            print('Restoring from checkpoint: {}'.format(checkpoint))
        tf.train.Saver().restore(sess, checkpoint)

        # Go ahead and embed the whole dataset, with all augmented versions too.
        emb_storage = np.zeros(
            (len(data_fids) * len(modifiers), args.embedding_dim), np.float32)

        # NOTE(review): `iter_init_op` is never run in the visible code; the
        # reinitializable iterator must be initialized before `sess.run(images)`.

        # Per batch: fetch images from the pipeline, then feed them through
        # `images_ph` to compute their embeddings. Each slice of the zeroed
        # `emb_storage` is written once, so `+=` acts as assignment here.
        # NOTE(review): the `try:` matching the `except` below is missing, as
        # are the trailing arguments of the progress `print`.
        for start_idx in count(step=args.batch_size):
                current_imgs = sess.run(images)
                batch_embedding = endpoints['emb']
                emb = sess.run(batch_embedding,
                               feed_dict={images_ph: current_imgs})
                emb_storage[start_idx:start_idx + len(emb)] += emb
                print('\rEmbedded batch {}-{}/{}'.format(
                    start_idx, start_idx + len(emb), len(emb_storage)),
            except tf.errors.OutOfRangeError:
                break  # This just indicates the end of the dataset.

        # NOTE(review): the closing arguments of this print are missing.
        if not args.quiet:
            print("Done with embedding, aggregating augmentations...",

        if len(modifiers) > 1:
            # Pull out the augmentations into a separate first dimension.
            # NOTE(review): the remaining reshape arguments are missing.
            emb_storage = emb_storage.reshape(len(data_fids), len(modifiers),
            emb_storage = emb_storage.transpose((1, 0, 2))  # (Aug, FID, embedding_dim)

            # Store the embedding of all individual variants too.
            emb_dataset = f_out.create_dataset('emb_aug', data=emb_storage)

            # Aggregate according to the specified parameter.
            emb_storage = AGGREGATORS[args.aggregator](emb_storage)

        # Store the final embeddings.
        emb_dataset = f_out.create_dataset('emb', data=emb_storage)

        # Store information about the produced augmentation and in case no crop
        # augmentation was used, if the images are resized or avg pooled.
        # NOTE(review): the opening of this `create_dataset` call (including
        # its dataset name) is missing.
                             data=np.asarray(modifiers, dtype='|S'))
def main():
    """Train an embedding network with a triplet-style loss.

    This variant can additionally (a) apply geometric and photometric
    augmentation through the global ``seq_geo`` / ``seq_img`` pipelines
    (presumably imgaug, via ``iaa``) and (b) sample identities from a pool of
    hard identities computed from precomputed training embeddings when
    ``args.hard_pool_size > 0``.

    NOTE(review): this excerpt is not syntactically complete (missing list
    brackets, call arguments, ``else:`` lines and statement bodies). Restore
    them from version control before running.
    """
    args = parser.parse_args()

    # Data augmentation
    # Module-level pipelines, presumably read by `augment_images` (used via
    # `tf.py_func` below).
    # NOTE(review): the augmenter-list brackets and several constructor lines
    # (presumably the affine, crop and blur augmenters) are missing here.
    global seq_geo
    global seq_img
    seq_geo = iaa.SomeOf(
        (0, 5),
            iaa.Fliplr(0.5),  # horizontally flip 50% of the images
            iaa.PerspectiveTransform(scale=(0, 0.075)),
                    "x": (0.8, 1.0),
                    "y": (0.8, 1.0)
                rotate=(-5, 5),
                    "x": (-0.1, 0.1),
                    "y": (-0.1, 0.1)
            ),  # rotate by -5 to +5 degrees
                0, 0.125
            )),  # crop images from each side by 0 to 12.5% (randomly chosen)
            iaa.CoarsePepper(p=0.01, size_percent=0.1)
    # Content transformation
    seq_img = iaa.SomeOf(
        (0, 3),
                sigma=(0, 1.0)),  # blur images with a sigma of 0 to 1.0
            iaa.ContrastNormalization(alpha=(0.9, 1.1)),
            iaa.Grayscale(alpha=(0, 0.2)),
            iaa.Multiply((0.9, 1.1))

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        # NOTE(review): the "new argument" warning below is indented inside
        # `if key in args_resumed:`, so it fires for every key that IS in
        # args.json; an `else:` for keys absent from the file seems to have
        # been lost.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

        # If the experiment directory exists already, we bail in fear.
        # NOTE(review): despite the comment, this only prints a message, and it
        # currently sits inside the `if args.resume:` branch; the enclosing
        # `else:` and the exit after the message appear to be missing.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        # NOTE(review): the body of this `with` statement is missing.
        with open(args_file, 'w') as f:

    # NOTE(review): `log_file` is unused in the visible code; whatever logging
    # setup consumed it appears to be missing.
    log_file = os.path.join(args.experiment_root, "train")
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    # NOTE(review): these checks only log an error; execution still continues
    # to `common.load_dataset` below.
    if not args.train_set:
        log.error("You did not specify the `train_set` argument!")
    if not args.image_root:
        log.error("You did not specify the required `image_root` argument!")

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Load feature embeddings
    # Precomputed training-set embeddings are compared pairwise and passed to
    # `get_hard_id_pool` to build the pool of hard identities.
    # NOTE(review): `hard_ids` is only defined when `hard_pool_size > 0`.
    if args.hard_pool_size > 0:
        with h5py.File(args.train_embeddings, 'r') as f_train:
            train_embs = np.array(f_train['emb'])
            f_dists = scipy.spatial.distance.cdist(train_embs, train_embs)
            hard_ids = get_hard_id_pool(pids, f_dists, args.hard_pool_size)

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    # NOTE(review): the `else:` for the hard-pool branch (the second `repeat`
    # and the `sample_batch_ids_for_pid` map) is missing. As written, those
    # lines run when `hard_pool_size == 0`, where `hard_ids` is undefined.
    if args.hard_pool_size == 0:
        dataset = dataset.take(
            (len(unique_pids) // args.batch_p) * args.batch_p)
        dataset = dataset.repeat(
            None)  # Repeat forever. Funny way of stating it.

        dataset = dataset.repeat(
            None)  # Repeat forever. Funny way of stating it.
        dataset = dataset.map(lambda pid: sample_batch_ids_for_pid(
            pid, all_pids=pids, batch_p=args.batch_p, all_hard_pids=hard_ids))
        # Unbatch the P PIDs
        dataset = dataset.apply(tf.contrib.data.unbatch())

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    # NOTE(review): the `map(...)` below is left unclosed; the leading arguments
    # of `common.fid_to_image` and the rest of the call appear to be missing.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        image_size=pre_crop_size if args.crop_augment else net_input_size),

    # Augment the data if specified by the arguments.
    # NOTE(review): the branch structure here is damaged. The first
    # `fid_to_image` call below has lost its leading arguments. An `else:`
    # (for `args.augment` true) presumably belongs before the
    # `fid_to_image(..., image_size=net_input_size)` map, but as written those
    # maps run in the same branch as the flip/crop maps. Images are also
    # reloaded with `fid_to_image` after already being loaded above, and the
    # `tf.reshape` call is missing its tensor argument. Prefer
    # `if not args.augment:` over `== False` once restored.
    if args.augment == False:
        dataset = dataset.map(
            lambda im, fid, pid: common.fid_to_image(
                if args.crop_augment else net_input_size),  #Ergys

        if args.flip_augment:
            dataset = dataset.map(lambda im, fid, pid: (
                tf.image.random_flip_left_right(im), fid, pid))
        if args.crop_augment:
            dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
                im, net_input_size + (3, )), fid, pid))
        dataset = dataset.map(lambda im, fid, pid: common.fid_to_image(
            fid, pid, image_root=args.image_root, image_size=net_input_size),

        dataset = dataset.map(lambda im, fid, pid: (tf.py_func(
            augment_images, [im], [tf.float32]), fid, pid))
        dataset = dataset.map(lambda im, fid, pid: (tf.reshape(
            (args.net_input_height, args.net_input_width, 3)), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    # NOTE(review): `prefetch` comes after `batch`, so this buffers
    # `batch_size * 2` *batches*, not elements; confirm this is intended.
    dataset = dataset.prefetch(batch_size * 2)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    # NOTE(review): the loss-name key and the leading call arguments of the
    # selected loss are missing.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
                   batch_precision_at_k=args.batch_k - 1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    # NOTE(review): the opening line of the next statement (presumably a
    # histogram of the raw embedding norms) is missing.
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    # NOTE(review): the scope argument (presumably `body_prefix`) and the
    # closing parenthesis of this call are missing.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            tf.maximum(0, global_step - args.decay_start_iteration),
            args.train_iterations - args.decay_start_iteration, 0.001)
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproduceable experiments.

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                        step, float(np.min(b_loss)), float(np.mean(b_loss)),
                        float(np.max(b_loss)), args.batch_k - 1,
                        timedelta(seconds=int(seconds_todo)), elapsed_time))

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
                              os.path.join(args.experiment_root, 'checkpoint'),
def main(argv):
    # Embed every image listed in `args.dataset` with the model restored from
    # `args.experiment_root`, and write the result to the HDF5 file
    # `args.filename`. That file holds dataset 'emb', plus 'emb_aug' holding
    # every individual variant when test-time augmentation produces more than
    # one variant per image.
    #
    # Args:
    #     argv: command-line argument list handed to `parser.parse_args`.
    #
    # NOTE(review): this copy of the function appears to have lost lines.
    # Several statements below are syntactically incomplete: an `if` with no
    # body, calls whose trailing argument lines are missing, and dropped
    # `else:`/`try:` lines. The comments describe only the visible code.
    # TODO restore the missing lines from version control.
    # Verify that parameters are set correctly.
    args = parser.parse_args(argv)

    # NOTE(review): the body of the following check is missing in this copy,
    # so as written it is a syntax error. Presumably it should report the
    # missing dataset file and stop. TODO restore.
    if not os.path.exists(args.dataset):

    # Possibly auto-generate the output filename.
    # Default: '<dataset basename without extension>_embeddings.h5'.
    if args.filename is None:
        basename = os.path.basename(args.dataset)
        args.filename = os.path.splitext(basename)[0] + '_embeddings.h5'

    # Make sure the output folder <experiment_root>/<foldername> is there
    # (presumably `touch_dir` creates it when missing).
    os_utils.touch_dir(os.path.join(args.experiment_root, args.foldername))

    # NOTE(review): neither `log_file` nor `log` is used in the visible code.
    log_file = os.path.join(args.experiment_root, args.foldername, "embed")
    log = logging.getLogger('embed')

    # Move the output file into the output folder and derive a companion
    # '<filename without the last 3 chars, i.e. ".h5">_var.txt' path.
    # NOTE(review): the continuation of the `args.filename = ...` statement
    # is missing in this copy. TODO restore.
    args.filename = os.path.join(args.experiment_root, args.foldername,
    var_filepath = os.path.join(args.experiment_root, args.foldername,
                                args.filename[:-3] + '_var.txt')
    # Load the args from the original experiment.
    args_file = os.path.join(args.experiment_root, 'args.json')

    if os.path.isfile(args_file):
        if not args.quiet:
            print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)

        # Add arguments from training.
        # `setdefault` only fills keys that are not already set, so the
        # command-line values take precedence over the stored ones.
        for key, value in args_resumed.items():
            args.__dict__.setdefault(key, value)

        # A couple special-cases and sanity checks
        # Warn when crop augmentation was enabled in training but not here, or
        # the other way round. This assumes the stored training value is a
        # bool. TODO confirm.
        # NOTE(review): the continuation of the warning message is missing.
        # The `raise IOError` below also sits inside this
        # `if os.path.isfile(...)` branch, so as written it fires whenever
        # args.json *does* exist. An `else:` before it appears to have been
        # lost. TODO restore.
        if (args_resumed['crop_augment']) == (args.crop_augment is None):
            print('WARNING: crop augmentation differs between training and '
        args.image_root = args.image_root or args_resumed['image_root']
        raise IOError(
            '`args.json` could not be found in: {}'.format(args_file))

    # Check a proper aggregator is provided if augmentation is used.
    # NOTE(review): two pieces appear to be missing here. The first is the
    # `print(` that opens the first error message (and possibly an exit). The
    # second is the `else:` that should put the second check under the
    # no-augmentation case. As written, both checks sit under the
    # augmentation branch. TODO restore.
    if args.flip_augment or args.crop_augment == 'five':
        if args.aggregator is None:
                'ERROR: Test time augmentation is performed but no aggregator'
                'was specified.')
        if args.aggregator is not None:
            print('ERROR: No test time augmentation that needs aggregating is '
                  'performed but an aggregator was specified.')

    if not args.quiet:
        print('Evaluating using the following parameters:')
        for key, value in sorted(vars(args).items()):
            print('{}: {}'.format(key, value))

    # Load the data from the CSV file.
    # Only the file ids are needed for embedding; the pids are discarded.
    _, data_fids = common.load_dataset(args.dataset, args.image_root)

    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    # Setup a tf Dataset containing all images.
    dataset = tf.data.Dataset.from_tensor_slices(data_fids)

    # NOTE(review): the leading arguments of `fid_to_image` and the closing
    # of this `dataset.map(` call appear to be missing in this copy. Later
    # maps expect (im, fid, pid) elements. Images are loaded at the
    # pre-crop size when cropping is requested.
    # Convert filenames to actual image tensors.
    dataset = dataset.map(lambda fid: common.fid_to_image(
        image_size=pre_crop_size if args.crop_augment else net_input_size),

    # Augment the data if specified by the arguments.
    # `modifiers` is a list of strings that keeps track of which augmentations
    # have been applied, so that a human can understand it later on.
    modifiers = ['original']
    # `flip_augment` presumably yields each image plus its mirrored copy.
    # `unbatch` then flattens them into consecutive elements, labelled ''
    # and '_flip'.
    if args.flip_augment:
        dataset = dataset.map(flip_augment)
        dataset = dataset.apply(tf.contrib.data.unbatch())
        modifiers = [o + m for m in ['', '_flip'] for o in modifiers]

    # 'center' keeps only element 0 of `five_crops`, labelled '_center'.
    # 'five' keeps all five crops, duplicates fid/pid to match, and unbatches.
    # NOTE(review): the closing of the 'five' `modifiers` list is missing.
    # The `else:` between the 'avgpool' and '_resize' assignments is also
    # missing, so as written avgpool gets both suffixes. TODO restore.
    if args.crop_augment == 'center':
        dataset = dataset.map(lambda im, fid, pid:
                              (five_crops(im, net_input_size)[0], fid, pid))
        modifiers = [o + '_center' for o in modifiers]
    elif args.crop_augment == 'five':
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.stack(five_crops(im, net_input_size)),
                               tf.stack([fid] * 5), tf.stack([pid] * 5)))
        dataset = dataset.apply(tf.contrib.data.unbatch())
        modifiers = [
            o + m for o in modifiers for m in [
                '_center', '_top_left', '_top_right', '_bottom_left',
    elif args.crop_augment == 'avgpool':
        modifiers = [o + '_avgpool' for o in modifiers]
        modifiers = [o + '_resize' for o in modifiers]

    # The embedding network (defined elsewhere). It also provides the input
    # preprocessing applied to each batch below.
    emb_model = EmbeddingModel(args)

    # Group it back into PK batches.
    dataset = dataset.batch(args.batch_size)
    dataset = dataset.map(lambda im, fid, pid:
                          (emb_model.preprocess_input(im), fid, pid))
    # Overlap producing and consuming.
    dataset = dataset.prefetch(1)

    with h5py.File(args.filename, 'w') as f_out:

        # Restore the latest model checkpoint, if any.
        # NOTE(review): the remaining `CheckpointManager(...)` arguments are
        # missing. So are the `ckpt.restore(...)` call and the `else:` before
        # "Initializing from scratch.". As written, both messages print when a
        # checkpoint exists and no restore is visible. TODO restore.
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=emb_model)
        manager = tf.train.CheckpointManager(ckpt,
        if manager.latest_checkpoint:
            print("Restored from {}".format(manager.latest_checkpoint))
            print("Initializing from scratch.")

        # One row per (image, augmentation variant), filled batch by batch.
        emb_storage = np.zeros(
            (len(data_fids) * len(modifiers), args.embedding_dim), np.float32)

        # Step through the rows in strides of batch_size until the dataset
        # iterator is exhausted. `len(emb)` accounts for a short final batch.
        # Adding into the zero-initialized buffer amounts to assignment.
        # NOTE(review): the `try:` matching `except StopIteration` is missing
        # in this copy. So is the tail of the progress `print(` call.
        # TODO restore.
        # for batch_idx,batch in enumerate(dataset):
        dataset_iter = iter(dataset)
        for start_idx in count(step=args.batch_size):

                images, _, _ = next(dataset_iter)
                emb = emb_model(images)
                emb_storage[start_idx:start_idx + len(emb)] += emb
                print('\rEmbedded batch {}-{}/{}'.format(
                    start_idx, start_idx + len(emb), len(emb_storage)),
            except StopIteration:
                break  # This just indicates the end of the dataset.

        # NOTE(review): the continuation of the print call below is missing.
        if not args.quiet:
            print("Done with embedding, aggregating augmentations...",

        # Rows are laid out image-major (all variants of one image are
        # consecutive), so reshape to (FID, Aug, D) and swap to (Aug, FID, D).
        # Then store the variants and reduce them with the chosen aggregator.
        # NOTE(review): the last argument of the `reshape` call is missing.
        if len(modifiers) > 1:
            # Pull out the augmentations into a separate first dimension.
            emb_storage = emb_storage.reshape(len(data_fids), len(modifiers),
            emb_storage = emb_storage.transpose((1, 0, 2))  # (Aug,FID,128D)

            # Store the embedding of all individual variants too.
            emb_dataset = f_out.create_dataset('emb_aug', data=emb_storage)

            # Aggregate according to the specified parameter.
            emb_storage = AGGREGATORS[args.aggregator](emb_storage)

        # Store the final embeddings.
        emb_dataset = f_out.create_dataset('emb', data=emb_storage)

        # NOTE(review): the `f_out.create_dataset(<name>,` line preceding the
        # `data=` line below is missing in this copy. TODO restore.
        # Store information about the produced augmentation and in case no crop
        # augmentation was used, if the images are resized or avg pooled.
                             data=np.asarray(modifiers, dtype='|S'))
def main():
    # Training entry point. It parses the command line and, with --resume,
    # merges in the arguments saved in `<experiment_root>/args.json`. It then
    # logs the final configuration and hands off to `prepare_data` and
    # `train` (presumably defined elsewhere in this project).
    #
    # NOTE(review): this copy appears to have lost lines, e.g. the empty body
    # of `with open(args_file, 'w')` below and some `else:` branches. The
    # comments describe only the visible code. TODO restore.
    args = parser.parse_args()

    # Restrict CUDA to the device(s) given by `args.gpu`, numbered in PCI
    # bus order.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly
    # that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
                # NOTE(review): as written, this "new argument" warning runs
                # for every key that IS present in args.json. It reads like
                # the body of a lost `else:` (key absent from args.json).
                # TODO restore.
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

        # NOTE(review): the check below only prints, and it sits inside the
        # `if args.resume:` branch. So it never guards fresh experiments and
        # it warns on every resume into a non-empty directory. A preceding
        # `else:` and an exit appear to be missing. TODO restore.
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))

        # NOTE(review): the body of the `with` below (writing `args` to
        # args_file) is missing in this copy. As written it is a syntax error.
        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:

    # NOTE(review): no log-handler setup is visible, so `log_file` is unused
    # in this copy.
    log_file = os.path.join(args.experiment_root, "train")
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # NOTE(review): the first check reads `args.train_dataset`, but its
    # message names `train_set`. Neither check stops execution after logging
    # the error. TODO confirm the intended flag name and add an exit.
    # Check them here, so they are not required when --resume-ing.
    if not args.train_dataset:
        log.error("You did not specify the `train_set` argument!")
    if not args.image_root:
        log.error("You did not specify the required `image_root` argument!")

    # Build the input pipeline, then run training with the shared logger.
    images, fids, pids, max_fid_len = prepare_data(args)
    train(args, images, fids, pids, max_fid_len, log)
def main(argv):
    # Eager-mode (tf.GradientTape) training entry point. It:
    #   - parses `argv`;
    #   - optionally resumes the arguments saved in
    #     `<experiment_root>/args.json`;
    #   - builds a PK-batched tf.data pipeline (batch_p identities with
    #     batch_k images each);
    #   - trains `EmbeddingModel` with the loss selected by `args.loss`,
    #     checkpointing through `tf.train.CheckpointManager`.
    #
    # NOTE(review): this copy appears to have lost many lines: unclosed calls,
    # `if` statements with no body, missing `else:` branches, and no visible
    # `optimizer.apply_gradients`. The comments describe only the visible code
    # and flag the gaps. TODO restore from version control.

    args = parser.parse_args(argv)

    # Restrict the visible GPUs when --gpu is given.
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # tf.compat.v1.disable_eager_execution()

    # physical_devices = tf.config.experimental.list_physical_devices('GPU')
    # tf.config.experimental.set_memory_growth(physical_devices[0], True)

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
                # NOTE(review): as written, this "new argument" warning runs
                # for every key that IS present in args.json. It reads like
                # the body of a lost `else:`. TODO restore.
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

        # NOTE(review): the check below only prints, and it sits inside the
        # `if args.resume:` branch, so it never guards fresh experiments. A
        # preceding `else:` and an exit appear to be missing. TODO restore.
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))

        # NOTE(review): the body of the `with` below (writing `args` to
        # args_file) is missing in this copy. As written it is a syntax error.
        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:

    # NOTE(review): no log-handler setup is visible, so `log_file` is unused.
    log_file = os.path.join(args.experiment_root, "train")
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # NOTE(review): these checks only log; execution continues afterwards.
    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        log.error("You did not specify the `train_set` argument!")
    if not args.image_root:
        log.error("You did not specify the required `image_root` argument!")

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    # With fewer identities than batch_p, repeat the PID list so that at
    # least one full group of batch_p PIDs can be taken below.
    if len(unique_pids) < args.batch_p:
        unique_pids = np.tile(unique_pids,
                              int(np.ceil(args.batch_p / len(unique_pids))))
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.unbatch()

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    # NOTE(review): the leading arguments of `fid_to_image` and the closing
    # of this `dataset.map(` call appear to be missing in this copy.
    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        image_size=pre_crop_size if args.crop_augment else net_input_size),

    # Augment the data if specified by the arguments.

    # NOTE(review): this second map (marked "Ergys") calls `fid_to_image`
    # again on the (im, fid, pid) elements, apparently re-reading each image
    # from its fid. Its middle argument line and closing are missing, so the
    # intent cannot be confirmed from here. TODO check.
    dataset = dataset.map(
        lambda im, fid, pid: common.fid_to_image(fid,
                                                 if args.crop_augment else
                                                 net_input_size),  # Ergys

    # Random horizontal flip and/or random crop down to the network input
    # size (3 channels).
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.image.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Create the model and an embedding head.
    emb_model = EmbeddingModel(args)

    # Group it back into PK batches.
    # Model-specific preprocessing is applied per image, before batching.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.map(lambda im, fid, pid:
                          (emb_model.preprocess_input(im), fid, pid))
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.

    # all_trainable_variables = embedding_head.trainable_variables+base_model.trainable_variables

    # NOTE(review): there are two problems with the schedule below.
    # 1. The constant `args.learning_rate` assignment unconditionally
    #    overwrites the PolynomialDecay schedule; an `else:` seems to be
    #    missing.
    # 2. `decay_start_iteration` only gates whether a schedule is built. It
    #    is not passed to PolynomialDecay, which receives `train_iterations`
    #    as its decay-steps argument.
    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.optimizers.schedules.PolynomialDecay(
            args.learning_rate, args.train_iterations, end_learning_rate=1e-7)
        learning_rate = args.learning_rate

    # NOTE(review): the `else:` before `raise NotImplementedError` and the
    # rest of that call are missing. As written, choosing 'momentum' always
    # raises. TODO restore.
    if args.optimizer == 'adam':
        optimizer = tf.keras.optimizers.Adam(learning_rate)
    elif args.optimizer == 'momentum':
        optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=0.9)
        raise NotImplementedError('Invalid optimizer {}'.format(

    def train_step(images, pids):
        # Embed one PK batch and compute the loss selected by `args.loss`.
        # Returns the loss tensor `embedding_loss`.
        #
        # NOTE(review): most loss calls below are missing their trailing
        # argument lines, and the `else:` before the final `raise` is
        # missing. The computed `gradients` are never applied (no
        # `optimizer.apply_gradients` is visible), so as written no weights
        # are updated. TODO restore.

        with tf.GradientTape() as tape:
            batch_embedding = emb_model(images)
            if args.loss == 'semi_hard_triplet':
                embedding_loss = triplet_semihard_loss(batch_embedding, pids,
            elif args.loss == 'hard_triplet':
                embedding_loss = batch_hard(batch_embedding, pids, args.margin,
            elif args.loss == 'lifted_loss':
                embedding_loss = lifted_loss(pids,
            elif args.loss == 'contrastive_loss':
                assert batch_size % 2 == 0
                assert args.batch_k == 4  ## Can work with other number but will need tuning

                # Reorder every group of 8 rows into consecutive pairs.
                # Each group holds two identities, A then B, with 4 images
                # each; this assumes each PID's images are contiguous, as the
                # per-PID sampling plus unbatch above suggests.
                # Index pattern [0, 1, 4, 3, 2, 5, 6, 7] gives the pairs
                # (A,A), (B,A), (A,B), (B,B), matching `fixed_labels`
                # [1, 0, 0, 1] below.
                # NOTE(review): this assumes batch_p is even. The assert above
                # only checks batch_size, which is always even when
                # batch_k == 4.
                contrastive_idx = np.tile([0, 1, 4, 3, 2, 5, 6, 7],
                                          args.batch_p // 2)
                for i in range(args.batch_p // 2):
                    contrastive_idx[i * 8:i * 8 + 8] += i * 8

                contrastive_idx = np.expand_dims(contrastive_idx, 1)
                batch_embedding_ordered = tf.gather_nd(batch_embedding,
                pids_ordered = tf.gather_nd(pids, contrastive_idx)
                # batch_embedding_ordered = tf.Print(batch_embedding_ordered,[pids_ordered],'pids_ordered :: ',summarize=1000)
                embeddings_anchor, embeddings_positive = tf.unstack(
                               [-1, 2, args.embedding_dim]), 2, 1)
                # embeddings_anchor = tf.Print(embeddings_anchor,[pids_ordered,embeddings_anchor,embeddings_positive,batch_embedding,batch_embedding_ordered],"Tensors ", summarize=1000)

                fixed_labels = np.tile([1, 0, 0, 1], args.batch_p // 2)
                # fixed_labels = np.reshape(fixed_labels,(len(fixed_labels),1))
                # print(fixed_labels)
                labels = tf.constant(fixed_labels)
                # labels = tf.Print(labels,[labels],'labels ',summarize=1000)
                embedding_loss = contrastive_loss(labels,
            elif args.loss == 'angular_loss':
                # Split the batch into (anchor, positive) pairs of
                # consecutive rows; keep the PID of each pair's first row.
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]),
                    2, 1)
                # pids = tf.Print(pids, [pids], 'pids:: ', summarize=100)
                pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
                # pids = tf.Print(pids,[pids],'pids:: ',summarize=100)
                embedding_loss = angular_loss(pids,

            elif args.loss == 'npairs_loss':
                # Same pairing as above; PIDs are flattened to a 1-D vector.
                assert args.batch_k == 2  ## Single positive pair per class
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]),
                    2, 1)
                pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
                pids = tf.reshape(pids, [-1])
                embedding_loss = npairs_loss(pids, embeddings_anchor,

                raise NotImplementedError('Invalid Loss {}'.format(args.loss))
            loss_mean = tf.reduce_mean(embedding_loss)

        # NOTE(review): `gradients` is computed but not used in this copy.
        gradients = tape.gradient(loss_mean, emb_model.trainable_variables)

        return embedding_loss

    # sess = tf.compat.v1.Session()
    # start_step = sess.run(global_step)
    # checkpoint_saver = tf.train.Saver(max_to_keep=2)
    # NOTE(review): `start_step` is only used in the log message. The loop
    # below actually starts from `ckpt.step`, which may differ after a
    # restore.
    start_step = 0
    log.info('Starting training from iteration {}.'.format(start_step))
    dataset_iter = iter(dataset)

    # NOTE(review): several pieces around the checkpoint setup appear to be
    # missing: the remaining `tf.train.Checkpoint(...)` and
    # `CheckpointManager(...)` arguments, the `ckpt.restore(...)` call, and
    # the `else:` before "Initializing from scratch.". As written, both
    # messages print when a checkpoint exists. TODO restore.
    ckpt = tf.train.Checkpoint(step=tf.Variable(1),
    manager = tf.train.CheckpointManager(ckpt,

    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
        print("Initializing from scratch.")

    # Main loop. `lb.Uninterrupt` lets SIGINT/SIGTERM set `u.interrupted`,
    # which is checked at the end of each iteration.
    with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
        for i in range(ckpt.step.numpy(), args.train_iterations):
            # for batch_idx, batch in enumerate():
            start_time = time.time()
            images, fids, pids = next(dataset_iter)
            batch_loss = train_step(images, pids)
            elapsed_time = time.time() - start_time
            # Naive ETA: remaining iterations times this iteration's duration.
            seconds_todo = (args.train_iterations - i) * elapsed_time
            # NOTE(review): the `log.info(` opening the progress message below
            # and its `.format(...)` arguments are missing in this copy.
            # print(tf.reduce_min(batch_loss).numpy(),tf.reduce_mean(batch_loss).numpy(),tf.reduce_max(batch_loss).numpy())
                'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ETA: {} ({:.2f}s/it)'
                    # args.batch_k - 1, float(b_prec_at_k),

            # NOTE(review): the body of this `if` (presumably a
            # `manager.save()` call) is missing; only comments remain. No
            # increment of `ckpt.step` is visible either, so a resumed run
            # would restart from the saved step. TODO restore.
            if (args.checkpoint_frequency > 0
                    and i % args.checkpoint_frequency == 0):

                # uncomment if you want to save the model weight separately
                # emb_model.save_weights(os.path.join(args.experiment_root, 'model_weights_{0:04d}.w'.format(i)))


            # NOTE(review): a `break` after the log line appears to be
            # missing. As written, training continues after an interrupt
            # request.
            # Stop the main-loop at the end of the step, if requested.
            if u.interrupted:
                log.info("Interrupted on request!")
def main():
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    comand = input('Would you like to restore it?(yes/no)')
                    if comand == 'yes':
                        args.__dict__[key] = resumed_value
                        print('For the argument `{}` we are using the loaded value `{}`.'.format(key,
                        print('For the argument `{}` we are using the provided value `{}`.'.format(key,
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))
        with open(args_file, 'w') as f:
            json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True)

        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        log.error("You did not specify the `train_set` argument!")
    if not args.image_root:
        log.error("You did not specify the required `image_root` argument!")

#prepare the training dataset
    # Load the data from the TxT file. see Common.load_dataset function for details
    pids_train, fids_train = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids_train))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids_train)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids_train, all_pids=pids_train, batch_k=args.batch_k))  # now the dataset has been modified as [selected_fids
    # , pid] due to the return of the function 'sample_k_fids_for_pid'

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(
        lambda fid, pid: common.fid_to_image(
            fid, pid, image_root=args.image_root,
            image_size=pre_crop_size if args.crop_augment else net_input_size),
        num_parallel_calls=args.loading_threads)  # now the dataset has been modified as [selected_images
    # , fid, pid] due to the return of the function 'fid_to_image'

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(
            lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(
            lambda im, fid, pid: (tf.random_crop(im, net_input_size + (3,)), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images_train, fids_train, pids_train = dataset.make_one_shot_iterator().get_next()
    #prepare the validation set
    pids_val, fids_val = common.load_dataset(args.validation_set, args.validation_image_root)
    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids_val = np.unique(pids_val)
    dataset_val = tf.data.Dataset.from_tensor_slices(unique_pids_val)
    dataset_val = dataset_val.shuffle(len(unique_pids_val))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset_val = dataset_val.take((len(unique_pids_val) // args.batch_p) * args.batch_p)
    dataset_val = dataset_val.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset_val = dataset_val.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids_val, all_pids=pids_val, batch_k=args.batch_k))  # now the dataset has been modified as [selected_fids
    # , pid] due to the return of the function 'sample_k_fids_for_pid'

    # Ungroup/flatten the batches for easy loading of the files.
    dataset_val = dataset_val.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset_val = dataset_val.map(
        lambda fid, pid: common.fid_to_image(
            fid, pid, image_root=args.validation_image_root,
            image_size=pre_crop_size if args.crop_augment else net_input_size),
        num_parallel_calls=args.loading_threads)  # now the dataset has been modified as [selected_images
    # , fid, pid] due to the return of the function 'fid_to_image'

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset_val = dataset_val.map(
            lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset_val = dataset_val.map(
            lambda im, fid, pid: (tf.random_crop(im, net_input_size + (3,)), fid, pid))

    # Group it back into PK batches.
    dataset_val = dataset_val.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset_val = dataset_val.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images_val, fids_val, pids_val = dataset_val.make_one_shot_iterator().get_next()
    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    input_images = tf.placeholder(dtype=tf.float32, shape=[None,args.net_input_height,args.net_input_width,3],name='input')
    pids = tf.placeholder(dtype=tf.string, shape=[None,],name='pids')
    fids = tf.placeholder(dtype=tf.string, shape=[None, ], name='fids')

    endpoints, body_prefix = model.endpoints(input_images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    # dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    # losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
    #     dists, pids, args.margin, batch_precision_at_k=args.batch_k-1)
    # # '_' stands for the boolean matrix shows topK where the correct match of the identities occurs
    # shape=(batch_size,K)

# 更改
    # loss1
    dists1 = loss.cdist(endpoints['feature1'], endpoints['feature1'], metric=args.metric)
    losses1,_,_,_,_,_ =loss.LOSS_CHOICES[args.loss](
        dists1, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists2 = loss.cdist(endpoints['feature2'], endpoints['feature2'], metric=args.metric)
    losses2, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists2, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists3 = loss.cdist(endpoints['feature3'], endpoints['feature3'], metric=args.metric)
    losses3, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists3, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists4 = loss.cdist(endpoints['feature4'], endpoints['feature4'], metric=args.metric)
    losses4, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists4, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists_fu = loss.cdist(endpoints['fusion_layer'], endpoints['fusion_layer'], metric=args.metric)
    losses_fu, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
        dists_fu, pids, args.margin, batch_precision_at_k=args.batch_k - 1)

    losses = losses1+losses2+losses3+losses4+losses_fu

# 更改
    # losses_fu, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
    #     endpoints,pids, model_type=args.model_name, metric=args.metric, batch_precision_at_k=args.batch_k - 1
    # )

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))

    # 此处losses即为 pospair 比 negpair+margin 还大的部分
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k-1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    #tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32, shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32, shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len), shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)  # 'global_step' means the number of batches seen
                                                                       #  by graph
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            tf.maximum(0, global_step - args.decay_start_iteration),  # decay every 'lr_decay_steps' after the
                                                                      # 'decay_start_iteration'
            # args.train_iterations - args.decay_start_iteration, args.weight_decay_factor)
            args.lr_decay_steps, args.lr_decay_factor, staircase=True)
        learning_rate = args.learning_rate  # the case when we set 'decay_start_iteration' as -1
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate,epsilon=1e-3)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session(config=config) as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            if args.checkpoint is None:
                last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
                log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
                checkpoint_saver.restore(sess, last_checkpoint)
                ckpt_path = os.path.join(args.experiment_root, args.checkpoint)
                log.info('Restoring from checkpoint: {}'.format(args.checkpoint))
                checkpoint_saver.restore(sess, ckpt_path)
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)  # restore the pre-trained parameter from online model

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproduceable experiments.
            checkpoint_saver.save(sess, os.path.join(
                args.experiment_root, 'checkpoint'), global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root, sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids], feed_dict={input_images:images_train.eval(),
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter', simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                         'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                             args.batch_k-1, float(b_prec_at_k),

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0 and
                        step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess, os.path.join(
                        args.experiment_root, 'checkpoint'), global_step=step)

                #get validation results
                if (args.validation_frequency > 0 and
                        step % args.validation_frequency == 0):
                     b_prec_at_k_val, b_loss, b_fids = \
                        sess.run([prec_at_k, losses, fids], feed_dict={input_images : images_val.eval(),
                     log.info('Validation @:{:6d} iteration, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                         'batch-p@{}: {:.2%}'.format(
                          args.batch_k - 1, float(b_prec_at_k_val)
                     summary3 = tf.Summary()
                     summary3.value.add(tag='secs_per_iter', simple_value=float(np.mean(b_loss)))
                     summary_writer.add_summary(summary3, step)
                     summary_writer.add_summary(summary3, step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess, os.path.join(
            args.experiment_root, 'checkpoint'), global_step=step)
def main():
    # args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.

    train_config = cfg.TrainConfig()

    args_file = os.path.join(train_config.experiment_root, 'args.json')
    if train_config.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in train_config.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    train_config.__dict__[key] = resumed_value
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(train_config.experiment_root):
            if os.listdir(train_config.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(train_config.experiment_root))

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True)

    log_file = os.path.join(train_config.experiment_root, "train")
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not train_config.train_set:
        log.error("You did not specify the `train_set` argument!")
    if not train_config.image_root:
        log.error("You did not specify the required `image_root` argument!")

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(train_config.train_set, train_config.image_root, is_train=True)

    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.

    unique_pids = np.unique(pids)

    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)

    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // train_config.batch_p) * train_config.batch_p)
    # take(count)  Creates a Dataset with at most count elements from this dataset.

    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.
    # Repeats this dataset count times.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=train_config.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())
    # apply(transformation_func) Apply a transformation function to this dataset.
    # apply enables chaining of custom Dataset transformations, which are represented as functions that take one Dataset argument and return a transformed Dataset.

    # Convert filenames to actual image tensors.
    net_input_size = (train_config.net_input_height, train_config.net_input_width)
    # 256,128
    pre_crop_size = (train_config.pre_crop_height, train_config.pre_crop_width)
    # 288,144
    dataset = dataset.map(
        lambda fid, pid: common.fid_to_image_label(
            fid, pid, image_root=train_config.image_root,
            image_size=pre_crop_size if train_config.crop_augment else net_input_size),

    dataset = dataset.map(
        lambda im, keypt, mask, fid, pid: (tf.concat([im, keypt, mask], 2), fid, pid))

    # Augment the data if specified by the arguments.
    if train_config.flip_augment:
        dataset = dataset.map(
            lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid))

    # net_input_size_aug = net_input_size + (4,)
    if train_config.crop_augment:
        dataset = dataset.map(
            lambda im, fid, pid: (tf.random_crop(im, net_input_size + (21,)), fid, pid))
    # net_input_size + (21,) = (256, 128, 21)
    # split

    dataset = dataset.map(
        lambda im, fid, pid: (common.split(im, fid, pid)))


    # Group it back into PK batches.
    batch_size = train_config.batch_p * train_config.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)
    # prefetch(buffer_size)   Creates a Dataset that prefetches elements from this dataset.

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, keypts, masks, fids, pids = dataset.make_one_shot_iterator().get_next()
    # tf.summary.image('image',images,10)

    # Create the model and an embedding head.
    model = import_module('nets.' + train_config.model_name)
    head = import_module('heads.' + train_config.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.

    endpoints, body_prefix = model.endpoints(images, is_training=True)
    heatmap_in = endpoints[train_config.model_name + '/block4']
    # resnet_block_4_out = heatmap.resnet_block_4(heatmap_in)
    # resnet_block_3_4_out = heatmap.resnet_block_3_4(heatmap_in)
    # resnet_block_2_3_4_out = heatmap.resnet_block_2_3_4(heatmap_in)
    # head for heatmap
    with tf.name_scope('heatmap'):
        # heatmap_in = endpoints['model_output']
        # heatmap_out_layer_0 = heatmap.hmnet_layer_0(resnet_block_4_out, 1)
        # heatmap_out_layer_0 = heatmap.hmnet_layer_0(resnet_block_3_4_out, 1)
        # heatmap_out_layer_0 = heatmap.hmnet_layer_0(resnet_block_2_3_4_out, 1)
        heatmap_out_layer_0 = VAC.hmnet_layer_0(heatmap_in[:, :, :, 1020:2048], 1)
        heatmap_out_layer_1 = VAC.hmnet_layer_1(heatmap_out_layer_0, 1)
        heatmap_out_layer_2 = VAC.hmnet_layer_2(heatmap_out_layer_1, 1)
        heatmap_out_layer_3 = VAC.hmnet_layer_3(heatmap_out_layer_2, 1)
        heatmap_out_layer_4 = VAC.hmnet_layer_4(heatmap_out_layer_3, 1)
        heatmap_out = heatmap_out_layer_4
        heatmap_loss = VAC.loss_mutilayer(heatmap_out_layer_0, heatmap_out_layer_1, heatmap_out_layer_2,
                                              heatmap_out_layer_3, heatmap_out_layer_4, masks, net_input_size)
        # heatmap_loss = heatmap.loss(heatmap_out, labels, net_input_size)
        # heatmap_loss_mean = heatmap_loss

    with tf.name_scope('head'):
        # heatmap_sum = tf.reduce_sum(heatmap_out, axis=3)
        # heatmap_resize = tf.image.resize_images(tf.expand_dims(heatmap_sum, axis=3), [8, 4])
        # featuremap_tmp = tf.multiply(heatmap_resize, endpoints[args.model_name + '/block4'])
        # endpoints[args.model_name + '/block4'] = featuremap_tmp
        endpoints = head.head(endpoints, train_config.embedding_dim, is_training=True)

        tf.summary.image('feature_map', tf.expand_dims(endpoints[train_config.model_name + '/block4'][:, :, :, 0], axis=3), 4)

    with tf.name_scope('keypoints_pre'):
        keypoints_pre_in = endpoints[train_config.model_name + '/block4']
        # keypoints_pre_in_0 = keypoints_pre_in[:, :, :, 0:256]
        # keypoints_pre_in_1 = keypoints_pre_in[:, :, :, 256:512]
        # keypoints_pre_in_2 = keypoints_pre_in[:, :, :, 512:768]
        # keypoints_pre_in_3 = keypoints_pre_in[:, :, :, 768:1024]
        keypoints_pre_in_0 = keypoints_pre_in[:, :, :, 0:170]
        keypoints_pre_in_1 = keypoints_pre_in[:, :, :, 170:340]
        keypoints_pre_in_2 = keypoints_pre_in[:, :, :, 340:510]
        keypoints_pre_in_3 = keypoints_pre_in[:, :, :, 510:680]
        keypoints_pre_in_4 = keypoints_pre_in[:, :, :, 680:850]
        keypoints_pre_in_5 = keypoints_pre_in[:, :, :, 850:1020]

        labels = tf.image.resize_images(keypts, [128, 64])
        # keypoints_gt_0 = tf.concat([labels[:, :, :, 0:5], labels[:, :, :, 14:15], labels[:, :, :, 15:16], labels[:, :, :, 16:17], labels[:, :, :, 17:18]], 3)
        # keypoints_gt_1 = tf.concat([labels[:, :, :, 1:2], labels[:, :, :, 2:3], labels[:, :, :, 3:4], labels[:, :, :, 5:6]], 3)
        # keypoints_gt_2 = tf.concat([labels[:, :, :, 4:5], labels[:, :, :, 7:8], labels[:, :, :, 8:9], labels[:, :, :, 11:12]], 3)
        # keypoints_gt_3 = tf.concat([labels[:, :, :, 9:10], labels[:, :, :, 10:11], labels[:, :, :, 12:13], labels[:, :, :, 13:14]], 3)

        keypoints_gt_0 = labels[:, :, :, 0:5]
        keypoints_gt_1 = labels[:, :, :, 5:7]
        keypoints_gt_2 = labels[:, :, :, 7:9]
        keypoints_gt_3 = labels[:, :, :, 9:13]
        keypoints_gt_4 = labels[:, :, :, 13:15]
        keypoints_gt_5 = labels[:, :, :, 15:17]

        keypoints_pre_0 = PAC.tran_conv_0(keypoints_pre_in, kp_num=5)
        keypoints_pre_1 = PAC.tran_conv_1(keypoints_pre_in, kp_num=2)
        keypoints_pre_2 = PAC.tran_conv_2(keypoints_pre_in, kp_num=2)
        keypoints_pre_3 = PAC.tran_conv_3(keypoints_pre_in, kp_num=4)
        keypoints_pre_4 = PAC.tran_conv_4(keypoints_pre_in, kp_num=2)
        keypoints_pre_5 = PAC.tran_conv_5(keypoints_pre_in, kp_num=2)

        keypoints_loss_0 = PAC.keypoints_loss(keypoints_pre_0, keypoints_gt_0)
        keypoints_loss_1 = PAC.keypoints_loss(keypoints_pre_1, keypoints_gt_1)
        keypoints_loss_2 = PAC.keypoints_loss(keypoints_pre_2, keypoints_gt_2)
        keypoints_loss_3 = PAC.keypoints_loss(keypoints_pre_3, keypoints_gt_3)
        keypoints_loss_4 = PAC.keypoints_loss(keypoints_pre_4, keypoints_gt_4)
        keypoints_loss_5 = PAC.keypoints_loss(keypoints_pre_5, keypoints_gt_5)

        keypoints_loss = 5/17*keypoints_loss_0 + 2/17*keypoints_loss_1 + 2/17*keypoints_loss_2 + 4/17*keypoints_loss_3 + 2/17*keypoints_loss_4 + 2/17*keypoints_loss_5

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=train_config.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[train_config.loss](
        dists, pids, train_config.margin, batch_precision_at_k=train_config.batch_k-1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)
    scale_rate_0 = 1E-7
    scale_rate_1 = 6E-8
    total_loss = loss_mean + keypoints_loss*scale_rate_0 + heatmap_loss*scale_rate_1
    # total_loss = loss_mean + keypoints_loss * scale_rate_0
    # total_loss = loss_mean

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    # tf.summary.histogram('hm_loss_distribution', heatmap_loss)
    tf.summary.scalar('keypt_loss_0', keypoints_loss_0)
    tf.summary.scalar('keypt_loss_1', keypoints_loss_1)
    tf.summary.scalar('keypt_loss_2', keypoints_loss_2)
    tf.summary.scalar('keypt_loss_3', keypoints_loss_3)
    tf.summary.scalar('keypt_loss_all', keypoints_loss)
    tf.summary.scalar('total_loss', total_loss)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k-1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(train_config.experiment_root, 'embeddings'),
            dtype=np.float32, shape=(train_config.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(train_config.experiment_root, 'losses'),
            dtype=np.float32, shape=(train_config.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(train_config.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len), shape=(train_config.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= train_config.decay_start_iteration < train_config.train_iterations:
        learning_rate = tf.train.exponential_decay(
            tf.maximum(0, global_step - train_config.decay_start_iteration),
            train_config.train_iterations - train_config.decay_start_iteration, 0.001)
        learning_rate = train_config.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    #    train_op = optimizer.minimize(loss_mean, global_step=global_step)
        train_op = optimizer.minimize(total_loss, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    gpu_options = tf.GPUOptions(allow_growth=True)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        if train_config.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            if train_config.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables, write_version=tf.train.SaverDef.V1)
                saver.restore(sess, train_config.initial_checkpoint)

                # name_11 = 'resnet_v1_50/block4'
                # name_12 = 'resnet_v1_50/block3'
                # name_13 = 'resnet_v1_50/block2'
                # name_21 = 'Resnet_block_2_3_4/block4'
                # name_22 = 'Resnet_block_2_3_4/block3'
                # name_23 = 'Resnet_block_2_3_4/block2'
                # for var in tf.trainable_variables():
                #     var_name = var.name
                #     if re.match(name_11, var_name):
                #         dst_name = var_name.replace(name_11, name_21)
                #         tensor = tf.get_default_graph().get_tensor_by_name(var_name)
                #         dst_tensor = tf.get_default_graph().get_tensor_by_name(dst_name)
                #         tf.assign(dst_tensor, tensor)
                #     if re.match(name_12, var_name):
                #         dst_name = var_name.replace(name_12, name_22)
                #         tensor = tf.get_default_graph().get_tensor_by_name(var_name)
                #         dst_tensor = tf.get_default_graph().get_tensor_by_name(dst_name)
                #         tf.assign(dst_tensor, tensor)
                #     if re.match(name_13, var_name):
                #         dst_name = var_name.replace(name_13, name_23)
                #         tensor = tf.get_default_graph().get_tensor_by_name(var_name)
                #         dst_tensor = tf.get_default_graph().get_tensor_by_name(dst_name)
                #         tf.assign(dst_tensor, tensor)
            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproduceable experiments.
            checkpoint_saver.save(sess, os.path.join(
                train_config.experiment_root, 'checkpoint'), global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(train_config.experiment_root, sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, train_config.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter', simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if train_config.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (train_config.train_iterations - step) * elapsed_time
                log.info('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                         'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                             train_config.batch_k-1, float(b_prec_at_k),

                # Save a checkpoint of training every so often.
                if (train_config.checkpoint_frequency > 0 and
                        step % train_config.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess, os.path.join(
                        train_config.experiment_root, 'checkpoint'), global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess, os.path.join(
            train_config.experiment_root, 'checkpoint'), global_step=step)