Example #1
def main():
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    command = input('Would you like to restore it? (yes/no) ')
                    if command == 'yes':
                        args.__dict__[key] = resumed_value
                        print(
                            'For the argument `{}` we are using the loaded value `{}`.'
                            .format(key, args.__dict__[key]))
                    else:
                        print(
                            'For the argument `{}` we are using the provided value `{}`.'
                            .format(key, args.__dict__[key]))
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))
        os.remove(args_file)
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    ###########################################################################
    # Prepare the training dataset.
    # Load the data from the text file; see common.load_dataset for details.
    pids_train, fids_train = common.load_dataset(args.train_set,
                                                 args.image_root)
    max_fid_len = max(map(len,
                          fids_train))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids_train)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids_train, all_pids=pids_train, batch_k=args.batch_k
    ))  # The dataset now yields (selected_fids, pid) tuples, as returned by
    # sample_k_fids_for_pid.
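    # For reference, `sample_k_fids_for_pid` is defined elsewhere in this
    # script; a minimal sketch of what it is assumed to do (tile the PID's file
    # names up to at least batch_k, shuffle, take K, and return matching PID
    # labels):
    #
    #   def sample_k_fids_for_pid(pid, all_fids, all_pids, batch_k):
    #       possible_fids = tf.boolean_mask(all_fids, tf.equal(all_pids, pid))
    #       count = tf.shape(possible_fids)[0]
    #       padded_count = tf.cast(tf.ceil(batch_k / count), tf.int32)
    #       full_range = tf.tile(tf.range(count), [padded_count])
    #       shuffled = tf.random_shuffle(full_range)
    #       selected_fids = tf.gather(possible_fids, shuffled[:batch_k])
    #       return selected_fids, tf.fill([batch_k], pid)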

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(
        lambda fid, pid: common.fid_to_image(
            fid,
            pid,
            image_root=args.image_root,
            image_size=pre_crop_size if args.crop_augment else net_input_size),
        num_parallel_calls=args.loading_threads
    )  # The dataset now yields (image, fid, pid) tuples, as returned by
    # common.fid_to_image.

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)
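    # E.g. with batch_p=32 and batch_k=4, every batch holds 128 images spanning
    # 32 distinct identities.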

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images_train, fids_train, pids_train = dataset.make_one_shot_iterator(
    ).get_next()
    ###########################################################################
    # Prepare the validation set.
    pids_val, fids_val = common.load_dataset(args.validation_set,
                                             args.validation_image_root)
    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids_val = np.unique(pids_val)
    dataset_val = tf.data.Dataset.from_tensor_slices(unique_pids_val)
    dataset_val = dataset_val.shuffle(len(unique_pids_val))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset_val = dataset_val.take(
        (len(unique_pids_val) // args.batch_p) * args.batch_p)
    dataset_val = dataset_val.repeat(
        None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset_val = dataset_val.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids_val, all_pids=pids_val, batch_k=args.batch_k
    ))  # The dataset now yields (selected_fids, pid) tuples, as returned by
    # sample_k_fids_for_pid.

    # Ungroup/flatten the batches for easy loading of the files.
    dataset_val = dataset_val.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset_val = dataset_val.map(
        lambda fid, pid: common.fid_to_image(
            fid,
            pid,
            image_root=args.validation_image_root,
            image_size=pre_crop_size if args.crop_augment else net_input_size),
        num_parallel_calls=args.loading_threads
    )  # The dataset now yields (image, fid, pid) tuples, as returned by
    # common.fid_to_image.

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset_val = dataset_val.map(lambda im, fid, pid: (
            tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset_val = dataset_val.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    dataset_val = dataset_val.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset_val = dataset_val.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images_val, fids_val, pids_val = dataset_val.make_one_shot_iterator(
    ).get_next()
    ###########################################################################
    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    input_images = tf.placeholder(
        dtype=tf.float32,
        shape=[None, args.net_input_height, args.net_input_width, 3],
        name='input')
    pids = tf.placeholder(dtype=tf.string, shape=[
        None,
    ], name='pids')
    fids = tf.placeholder(dtype=tf.string, shape=[
        None,
    ], name='fids')
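    # Note: unlike the other variants in this file, which feed the iterator
    # output straight into the model, this one routes each batch through
    # placeholders and a feed_dict (the training loop below evaluates the
    # iterator tensors and feeds the resulting arrays back in).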

    endpoints, body_prefix = model.endpoints(input_images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    # dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    # losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
    #     dists, pids, args.margin, batch_precision_at_k=args.batch_k-1)
    # # '_' is the boolean matrix indicating where, within the top K, the
    # # correct identity matches occur; shape=(batch_size, K).


    # Modified: compute a separate triplet loss for each feature branch and for
    # the fusion layer, then sum them.
    # loss1
    dists1 = loss.cdist(endpoints['feature1'],
                        endpoints['feature1'],
                        metric=args.metric)
    losses1, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists1, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists2 = loss.cdist(endpoints['feature2'],
                        endpoints['feature2'],
                        metric=args.metric)
    losses2, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists2, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists3 = loss.cdist(endpoints['feature3'],
                        endpoints['feature3'],
                        metric=args.metric)
    losses3, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists3, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists4 = loss.cdist(endpoints['feature4'],
                        endpoints['feature4'],
                        metric=args.metric)
    losses4, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists4, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists_fu = loss.cdist(endpoints['fusion_layer'],
                          endpoints['fusion_layer'],
                          metric=args.metric)
    losses_fu, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists_fu,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)

    losses = losses1 + losses2 + losses3 + losses4 + losses_fu
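    # Each of the five loss terms above has one entry per anchor, so the sum
    # keeps that per-anchor shape; `loss_mean` and `num_active` below therefore
    # operate on the combined per-anchor loss.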

    # Modified: the original single-loss call is kept below for reference.
    # loss
    # losses_fu, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
    #     endpoints,pids, model_type=args.model_name, metric=args.metric, batch_precision_at_k=args.batch_k - 1
    # )

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))

    # Here `losses` is the amount by which the positive-pair distance (plus the
    # margin) still exceeds the negative-pair distance.
    loss_mean = tf.reduce_mean(losses)
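    # Assuming the usual batch-hard formulation, each entry of `losses` is
    #   max(0, margin + max_p d(a, p) - min_n d(a, n))
    # over the hardest positive p and hardest negative n for anchor a, so
    # `num_active` counts anchors that still violate the margin.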

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    #tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(
        0, name='global_step',
        trainable=False)  # 'global_step' counts the batches processed so far.
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            # Decay every `lr_decay_steps` steps after `decay_start_iteration`.
            # (Previously: args.train_iterations - args.decay_start_iteration, args.weight_decay_factor)
            args.lr_decay_steps,
            args.lr_decay_factor,
            staircase=True)
    else:
        learning_rate = args.learning_rate  # the case when 'decay_start_iteration' is set to -1
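    # With staircase decay, the learning rate stays at `learning_rate` for the
    # first `decay_start_iteration` steps and is then multiplied by
    # `lr_decay_factor` once every `lr_decay_steps` steps; otherwise it is
    # simply kept constant.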
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=1e-3)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session(config=config) as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            if args.checkpoint is None:
                last_checkpoint = tf.train.latest_checkpoint(
                    args.experiment_root)
                log.info(
                    'Restoring from checkpoint: {}'.format(last_checkpoint))
                checkpoint_saver.restore(sess, last_checkpoint)
            else:
                ckpt_path = os.path.join(args.experiment_root, args.checkpoint)
                log.info('Restoring from checkpoint: {}'.format(
                    args.checkpoint))
                checkpoint_saver.restore(sess, ckpt_path)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(
                    sess, args.initial_checkpoint
                )  # Restore the pre-trained backbone weights.

            # In any case, we also store this initialization as a checkpoint,
            # such that we can run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids],
                             feed_dict={input_images: images_train.eval(),
                                        pids: pids_train.eval(),
                                        fids: fids_train.eval()})
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                        step, float(np.min(b_loss)), float(np.mean(b_loss)),
                        float(np.max(b_loss)), args.batch_k - 1,
                        float(b_prec_at_k),
                        timedelta(seconds=int(seconds_todo)), elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Get validation results.
                if (args.validation_frequency > 0
                        and step % args.validation_frequency == 0):
                    b_prec_at_k_val, b_loss, b_fids = \
                        sess.run([prec_at_k, losses, fids],
                                 feed_dict={input_images: images_val.eval(),
                                            pids: pids_val.eval(),
                                            fids: fids_val.eval()})
                    log.info(
                        'Validation @:{:6d} iteration, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                        'batch-p@{}: {:.2%}'.format(step,
                                                    float(np.min(b_loss)),
                                                    float(np.mean(b_loss)),
                                                    float(np.max(b_loss)),
                                                    args.batch_k - 1,
                                                    float(b_prec_at_k_val)))
                    sys.stdout.flush()
                    sys.stderr.flush()
                    summary3 = tf.Summary()
                    summary3.value.add(tag='validation_loss',
                                       simple_value=float(np.mean(b_loss)))
                    summary_writer.add_summary(summary3, step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
Example #2
def main():
    args = parser.parse_args()

    # Data augmentation
    global seq_geo
    global seq_img
    seq_geo = iaa.SomeOf(
        (0, 5),
        [
            iaa.Fliplr(0.5),  # horizontally flip 50% of the images
            iaa.PerspectiveTransform(scale=(0, 0.075)),
            iaa.Affine(
                scale={
                    "x": (0.8, 1.0),
                    "y": (0.8, 1.0)
                },
                rotate=(-5, 5),
                translate_percent={
                    "x": (-0.1, 0.1),
                    "y": (-0.1, 0.1)
                },
            ),  # rotate by -5 to +5 degrees
            iaa.Crop(percent=(
                0, 0.125
            )),  # crop images from each side by 0 to 12.5% (randomly chosen)
            iaa.CoarsePepper(p=0.01, size_percent=0.1)
        ],
        random_order=False)
    # Content transformation
    seq_img = iaa.SomeOf(
        (0, 3),
        [
            iaa.GaussianBlur(
                sigma=(0, 1.0)),  # blur images with a sigma of 0 to 1.0
            iaa.ContrastNormalization(alpha=(0.9, 1.1)),
            iaa.Grayscale(alpha=(0, 0.2)),
            iaa.Multiply((0.9, 1.1))
        ])
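    # For reference, the `augment_images` helper applied via tf.py_func further
    # down is defined elsewhere in this script; it is assumed to apply the two
    # augmenter sequences to a single image, roughly:
    #
    #   def augment_images(im):
    #       im = seq_geo.augment_image(im.astype(np.uint8))
    #       im = seq_img.augment_image(im)
    #       return im.astype(np.float32)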

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Load feature embeddings
    if args.hard_pool_size > 0:
        with h5py.File(args.train_embeddings, 'r') as f_train:
            train_embs = np.array(f_train['emb'])
            f_dists = scipy.spatial.distance.cdist(train_embs, train_embs)
            hard_ids = get_hard_id_pool(pids, f_dists, args.hard_pool_size)
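    # `get_hard_id_pool` and `sample_batch_ids_for_pid` (used below) are defined
    # elsewhere in this script; they are assumed to build, for each PID, a pool
    # of the identities closest in embedding space and to draw each batch's P
    # PIDs from that pool, so batches are biased towards hard negatives.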

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    if args.hard_pool_size == 0:
        dataset = dataset.take(
            (len(unique_pids) // args.batch_p) * args.batch_p)
        dataset = dataset.repeat(
            None)  # Repeat forever. Funny way of stating it.

    else:
        dataset = dataset.repeat(
            None)  # Repeat forever. Funny way of stating it.
        dataset = dataset.map(lambda pid: sample_batch_ids_for_pid(
            pid, all_pids=pids, batch_p=args.batch_p, all_hard_pids=hard_ids))
        # Unbatch the P PIDs
        dataset = dataset.apply(tf.contrib.data.unbatch())

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    # At this point the dataset yields (fid, pid) pairs, so the mapped function
    # must take exactly those two arguments.
    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    if not args.augment:
        # Note: this re-loads the images with the same settings as the map
        # above (kept as in the original code).
        dataset = dataset.map(
            lambda im, fid, pid: common.fid_to_image(fid,
                                                     pid,
                                                     image_root=args.image_root,
                                                     image_size=pre_crop_size
                                                     if args.crop_augment else
                                                     net_input_size),  # Ergys
            num_parallel_calls=args.loading_threads)

        if args.flip_augment:
            dataset = dataset.map(lambda im, fid, pid: (
                tf.image.random_flip_left_right(im), fid, pid))
        if args.crop_augment:
            dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
                im, net_input_size + (3, )), fid, pid))
    else:
        dataset = dataset.map(lambda im, fid, pid: common.fid_to_image(
            fid, pid, image_root=args.image_root, image_size=net_input_size),
                              num_parallel_calls=args.loading_threads)

        dataset = dataset.map(lambda im, fid, pid: (tf.py_func(
            augment_images, [im], [tf.float32]), fid, pid))
        dataset = dataset.map(lambda im, fid, pid: (tf.reshape(
            im[0],
            (args.net_input_height, args.net_input_width, 3)), fid, pid))
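        # tf.py_func discards static shape information, hence the explicit
        # reshape back to (net_input_height, net_input_width, 3) above.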

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(batch_size * 2)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            args.train_iterations - args.decay_start_iteration, 0.001)
    else:
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we can run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                        step, float(np.min(b_loss)), float(np.mean(b_loss)),
                        float(np.max(b_loss)), args.batch_k - 1,
                        float(b_prec_at_k),
                        timedelta(seconds=int(seconds_todo)), elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
def main():
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)

    decDense = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDense')  # alternatively: activation=tf.nn.relu
    unflat = tf.reshape(decDense, shape=[tf.shape(decDense)[0], 32, 16, 10])
    unp3shape = tf.TensorShape(
        [2 * di for di in unflat.get_shape().as_list()[1:-1]])
    unPool3 = tf.image.resize_nearest_neighbor(unflat,
                                               unp3shape,
                                               name='unpool3')

    deConv3 = tf.layers.conv2d(inputs=unPool3,
                               filters=64,
                               kernel_size=[5, 5],
                               strides=(1, 1),
                               padding='same',
                               activation=tf.nn.relu,
                               name='deConv3')
    unp2shape = tf.TensorShape(
        [2 * di for di in deConv3.get_shape().as_list()[1:-1]])
    unPool2 = tf.image.resize_nearest_neighbor(deConv3,
                                               unp2shape,
                                               name='unpool2')

    deConv2 = tf.layers.conv2d(inputs=unPool2,
                               filters=32,
                               kernel_size=[5, 5],
                               strides=(1, 1),
                               padding='same',
                               activation=tf.nn.relu,
                               name='deConv2')
    unp1shape = tf.TensorShape(
        [2 * di for di in deConv2.get_shape().as_list()[1:-1]])
    unPool1 = tf.image.resize_nearest_neighbor(deConv2,
                                               unp1shape,
                                               name='unpool1')

    deConv1 = tf.layers.conv2d(inputs=unPool1,
                               filters=3,
                               kernel_size=[5, 5],
                               strides=(1, 1),
                               padding='same',
                               activation=None,
                               name='deConv1')
    imClip = deConv1  #tf.clip_by_value(t = deConv1,clip_value_min = -1.0,clip_value_max = 1.0,name='clipRelu')
    print('Reconstructed image: ', imClip.name)

    recLoss = tf.multiply(
        0.01, tf.losses.mean_squared_error(
            labels=images,
            predictions=imClip,
        ))
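    # The decoder above maps the embedding to 5120 units (reshaped to 32x16x10),
    # upsamples three times to 256x128 (which must match the network input for
    # the MSE against `images` to be valid), and its 0.01-weighted MSE serves as
    # an auxiliary reconstruction loss that is added to the triplet loss in
    # `train_op` further down.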
    print('recLoss :  ', recLoss.name)

    decDense1 = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDense1')  # alternatively: activation=tf.nn.relu
    unflat1 = tf.reshape(decDense1, shape=[tf.shape(decDense1)[0], 32, 16, 10])
    unp3shape1 = tf.TensorShape(
        [2 * di for di in unflat1.get_shape().as_list()[1:-1]])
    unPool3_new = tf.image.resize_nearest_neighbor(unflat1,
                                                   unp3shape1,
                                                   name='unpool3_new')

    deConv3_new = tf.layers.conv2d(inputs=unPool3_new,
                                   filters=64,
                                   kernel_size=[5, 5],
                                   strides=(1, 1),
                                   padding='same',
                                   activation=tf.nn.relu,
                                   name='deConv3_new')
    unp2shape_new = tf.TensorShape(
        [2 * di for di in deConv3_new.get_shape().as_list()[1:-1]])
    unPool2_new = tf.image.resize_nearest_neighbor(deConv3_new,
                                                   unp2shape_new,
                                                   name='unpool2_new')

    deConv2_new = tf.layers.conv2d(inputs=unPool2_new,
                                   filters=3,
                                   kernel_size=[5, 5],
                                   strides=(1, 1),
                                   padding='same',
                                   activation=tf.nn.relu,
                                   name='deConv2_new')
    unp1shape_new = tf.TensorShape(
        [2 * di for di in deConv2_new.get_shape().as_list()[1:-1]])
    unPool1_new = tf.image.resize_nearest_neighbor(deConv2_new,
                                                   unp1shape_new,
                                                   name='unpool1_new')

    deConv1_new = tf.layers.conv2d(inputs=unPool1_new,
                                   filters=3,
                                   kernel_size=[5, 5],
                                   strides=(1, 1),
                                   padding='same',
                                   activation=None,
                                   name='deConv1_new')
    imClip1 = deConv2_new
    print('Reconstructed image: ', imClip1.name)
    print(imClip1.shape)

    images2 = tf.image.resize_images(images, [128, 64])
    print(images2.shape)

    recLoss1 = tf.multiply(
        0.01,
        tf.losses.mean_squared_error(
            labels=images2,
            predictions=imClip1,
        ))
    print('recLoss_new :  ', recLoss1.name)

    decDense2 = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDense2')  # alternatively: activation=tf.nn.relu
    unflat12 = tf.reshape(decDense2,
                          shape=[tf.shape(decDense2)[0], 32, 16, 10])
    unp3shape12 = tf.TensorShape(
        [2 * di for di in unflat12.get_shape().as_list()[1:-1]])
    unPool3_new2 = tf.image.resize_nearest_neighbor(unflat12,
                                                    unp3shape12,
                                                    name='unpool3_new2')

    deConv3_new2 = tf.layers.conv2d(inputs=unPool3_new2,
                                    filters=3,
                                    kernel_size=[5, 5],
                                    strides=(1, 1),
                                    padding='same',
                                    activation=tf.nn.relu,
                                    name='deConv3_new2')
    unp2shape_new2 = tf.TensorShape(
        [2 * di for di in deConv3_new2.get_shape().as_list()[1:-1]])
    unPool2_new2 = tf.image.resize_nearest_neighbor(deConv3_new2,
                                                    unp2shape_new2,
                                                    name='unpool2_new2')

    imClip11 = deConv3_new2
    images21 = tf.image.resize_images(images, [64, 32])

    recLoss2 = tf.multiply(
        0.01,
        tf.losses.mean_squared_error(
            labels=images21,
            predictions=imClip11,
        ))
    print('recLoss_new :  ', recLoss2.name)

    decDensel = tf.layers.dense(
        inputs=endpoints['emb'], units=5120,
        name='decDensel')  # alternatively: activation=tf.nn.relu
    unflatl = tf.reshape(decDensel, shape=[tf.shape(decDensel)[0], 32, 16, 10])
    unp3shapel = tf.TensorShape(
        [2 * di for di in unflatl.get_shape().as_list()[1:-1]])
    unPool3l = tf.image.resize_nearest_neighbor(unflatl,
                                                unp3shapel,
                                                name='unpool3l')

    deConv3l = tf.layers.conv2d(inputs=unPool3l,
                                filters=64,
                                kernel_size=[5, 5],
                                strides=(1, 1),
                                padding='same',
                                activation=tf.nn.relu,
                                name='deConv3l')
    unp2shapel = tf.TensorShape(
        [2 * di for di in deConv3l.get_shape().as_list()[1:-1]])
    unPool2l = tf.image.resize_nearest_neighbor(deConv3l,
                                                unp2shapel,
                                                name='unpool2l')

    deConv2l = tf.layers.conv2d(inputs=unPool2l,
                                filters=32,
                                kernel_size=[5, 5],
                                strides=(1, 1),
                                padding='same',
                                activation=tf.nn.relu,
                                name='deConv2l')
    unp1shapel = tf.TensorShape(
        [2 * di for di in deConv2l.get_shape().as_list()[1:-1]])
    unPool1l = tf.image.resize_nearest_neighbor(deConv2l,
                                                unp1shapel,
                                                name='unpool1l')

    deConv1l = tf.layers.conv2d(inputs=unPool1l,
                                filters=3,
                                kernel_size=[5, 5],
                                strides=(1, 1),
                                padding='same',
                                activation=None,
                                name='deConv1l')
    imClipl = deConv1l  #tf.clip_by_value(t = deConv1,clip_value_min = -1.0,clip_value_max = 1.0,name='clipRelu')
    print('Reconstructed image: ', imClipl.name)

    recLossl = tf.multiply(
        0.01, tf.losses.mean_squared_error(
            labels=images,
            predictions=imClipl,
        ))
    print('recLoss :  ', recLossl.name)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            args.train_iterations - args.decay_start_iteration, 0.001)
    else:
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(tf.add(
            loss_mean,
            tf.add(recLoss, tf.add(recLoss1, tf.add(recLoss2, recLossl)))),
                                      global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids, b_rec, b_rec1 = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids, recLoss, recLoss1])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'recLoss: {:.3f} batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.
                    format(step, float(np.min(b_loss)), float(np.mean(b_loss)),
                           float(np.max(b_loss)), b_rec, args.batch_k - 1,
                           float(b_prec_at_k),
                           timedelta(seconds=int(seconds_todo)), elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
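
# A minimal numpy sketch (not part of the example above) of the learning-rate
# schedule the example builds with tf.train.exponential_decay: the rate stays
# at `base_lr` until `decay_start`, then decays smoothly towards
# base_lr * 0.001 by the final iteration. Function and argument names here are
# illustrative only.
import numpy as np

def decayed_lr(step, base_lr, decay_start, total_iters, final_factor=0.001):
    """Learning rate at `step`, mirroring exponential_decay with staircase=False."""
    if step <= decay_start:
        return base_lr
    progress = (step - decay_start) / float(total_iters - decay_start)
    return base_lr * final_factor ** progress

# For instance, with base_lr=3e-4, decay_start=15000 and total_iters=25000,
# the rate is 3e-4 up to step 15000 and reaches 3e-7 at step 25000.
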
Example #4
0
def main(argv):
    args = parser.parse_args(argv)

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    if len(unique_pids) < args.batch_p:
        unique_pids = np.tile(unique_pids,
                              int(np.ceil(args.batch_p / len(unique_pids))))
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.

    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.

    weight_decay = 10e-4
    weights_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay)
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints,
                              args.embedding_dim,
                              is_training=True,
                              weights_regularizer=weights_regularizer)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    # batch_embedding = endpoints['emb']
    batch_embedding = endpoints['emb']
    if args.loss == 'semi_hard_triplet':
        triplet_loss = triplet_semihard_loss(batch_embedding, pids,
                                             args.margin)
    elif args.loss == 'hard_triplet':
        triplet_loss = batch_hard(batch_embedding, pids, args.margin,
                                  args.metric)
    elif args.loss == 'lifted_loss':
        triplet_loss = lifted_loss(pids, batch_embedding, margin=args.margin)
    elif args.loss == 'contrastive_loss':
        assert batch_size % 2 == 0
        assert args.batch_k == 4  # Can work with other numbers but will need tuning

        contrastive_idx = np.tile([0, 1, 4, 3, 2, 5, 6, 7], args.batch_p // 2)
        for i in range(args.batch_p // 2):
            contrastive_idx[i * 8:i * 8 + 8] += i * 8

        contrastive_idx = np.expand_dims(contrastive_idx, 1)
        batch_embedding_ordered = tf.gather_nd(batch_embedding,
                                               contrastive_idx)
        pids_ordered = tf.gather_nd(pids, contrastive_idx)
        # batch_embedding_ordered = tf.Print(batch_embedding_ordered,[pids_ordered],'pids_ordered :: ',summarize=1000)
        embeddings_anchor, embeddings_positive = tf.unstack(
            tf.reshape(batch_embedding_ordered, [-1, 2, args.embedding_dim]),
            2, 1)
        # embeddings_anchor = tf.Print(embeddings_anchor,[pids_ordered,embeddings_anchor,embeddings_positive,batch_embedding,batch_embedding_ordered],"Tensors ", summarize=1000)

        fixed_labels = np.tile([1, 0, 0, 1], args.batch_p // 2)
        # fixed_labels = np.reshape(fixed_labels,(len(fixed_labels),1))
        # print(fixed_labels)
        labels = tf.constant(fixed_labels)
        # labels = tf.Print(labels,[labels],'labels ',summarize=1000)
        triplet_loss = contrastive_loss(labels,
                                        embeddings_anchor,
                                        embeddings_positive,
                                        margin=args.margin)
    elif args.loss == 'angular_loss':
        embeddings_anchor, embeddings_positive = tf.unstack(
            tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]), 2, 1)
        # pids = tf.Print(pids, [pids], 'pids:: ', summarize=100)
        pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
        # pids = tf.Print(pids,[pids],'pids:: ',summarize=100)
        triplet_loss = angular_loss(pids,
                                    embeddings_anchor,
                                    embeddings_positive,
                                    batch_size=args.batch_p,
                                    with_l2reg=True)
    elif args.loss == 'npairs_loss':
        assert args.batch_k == 2  # Single positive pair per class
        embeddings_anchor, embeddings_positive = tf.unstack(
            tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]), 2, 1)
        pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
        pids = tf.reshape(pids, [-1])
        triplet_loss = npairs_loss(pids, embeddings_anchor,
                                   embeddings_positive)
    else:
        raise NotImplementedError(
            'Loss function {} is not implemented.'.format(args.loss))

    loss_mean = tf.reduce_mean(triplet_loss)

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)

    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.polynomial_decay(args.learning_rate,
                                                  global_step,
                                                  args.train_iterations,
                                                  end_learning_rate=1e-7,
                                                  power=1)
    else:
        learning_rate = args.learning_rate

    if args.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate)
    elif args.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
    else:
        raise NotImplementedError('Invalid optimizer {}'.format(
            args.optimizer))
    #
    # learning_rate = tf.train.polynomial_decay(args.learning_rate, global_step,
    #                                           args.train_iterations, end_learning_rate= 1e-7,
    #                                           power=1)
    #

    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=2)
    gpu_options = tf.GPUOptions(allow_growth=True)
    gpu_config = tf.ConfigProto(gpu_options=gpu_options)
    with tf.Session(config=gpu_config) as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            if last_checkpoint is None:
                print('Resume requested, but no previous checkpoint was found.')

                # But if we're starting from scratch, we may need to load some
                # variables from the pre-trained weights, and random init others.
                sess.run(tf.global_variables_initializer())
                if args.initial_checkpoint is not None:
                    saver = tf.train.Saver(model_variables)
                    saver.restore(sess, args.initial_checkpoint)

                # In any case, we also store this initialization as a checkpoint,
                # such that we could run exactly reproducible experiments.
                checkpoint_saver.save(sess,
                                      os.path.join(args.experiment_root,
                                                   'checkpoint'),
                                      global_step=0)

            else:
                log.info(
                    'Restoring from checkpoint: {}'.format(last_checkpoint))
                checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, step, b_embs, b_loss, b_fids = \
                    sess.run([train_op, global_step, endpoints['emb'], triplet_loss, fids])
                elapsed_time = time.time() - start_time

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ETA: {} ({:.2f}s/it)'
                    .format(
                        step,
                        float(np.min(b_loss)),
                        float(np.mean(b_loss)),
                        float(np.max(b_loss)),
                        # args.batch_k - 1, float(b_prec_at_k),
                        timedelta(seconds=int(seconds_todo)),
                        elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
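
# An editor's sketch (numpy only, illustrative names) of the P*K batch
# construction the pipeline above relies on: pick P person IDs, then K file IDs
# per ID, resampling with replacement when an identity has fewer than K images.
# This only mirrors what `sample_k_fids_for_pid` is expected to do; it is not
# the project's own implementation.
import numpy as np

def sample_pk_batch(all_pids, all_fids, batch_p, batch_k, rng=np.random):
    """Return (fids, pids) lists forming one P*K batch."""
    all_pids = np.asarray(all_pids)
    all_fids = np.asarray(all_fids)
    chosen_pids = rng.choice(np.unique(all_pids), size=batch_p, replace=False)
    batch_fids, batch_pids = [], []
    for pid in chosen_pids:
        candidates = all_fids[all_pids == pid]
        replace = len(candidates) < batch_k
        batch_fids.extend(rng.choice(candidates, size=batch_k, replace=replace))
        batch_pids.extend([pid] * batch_k)
    return batch_fids, batch_pids
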
Example #5
0
    tf.global_variables_initializer().run()

    if resume:
        last_checkpoint = tf.train.latest_checkpoint(tr.save_path)
        saver.restore(sess, last_checkpoint)
        start_step = sess.run(global_step)
        logger.debug('Resume training ... Start from step %d / %d .' %
                     (start_step, train_nums))
        resume = False
    else:
        start_step = 0

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
        for i in range(start_step, train_nums):

            _, loss_value, step, ttt = sess.run(
                [train_op, loss, global_step, tt])
            if i % print_steps == 0:
                top1, top5, top10 = sess.run([
                    accuracy_top1_batch, accuracy_top5_batch,
                    accuracy_top10_batch
                ])
                logger.debug(
                    "After %d training step(s),loss on training batch is %g.The batch test accuracy = %g , %g ,%g."
                    % (i, loss_value, top1, top5, top10))
                '''
            losslist.append([step,loss_value])
            accuracy.append([step,top1])
Example #6
0
def main(args):

    best_acc = -1

    logger = bit_common.setup_logger(args)
    cp, cn = smooth_BCE(eps=0.1)
    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Going to train on {device}")

    classes = 5

    train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger)
    logger.info(f"Loading model from {args.model}.npz")
    #model = models.KNOWN_MODELS[args.model](head_size=classes, zero_head=True)
    #model.load_from(np.load(f"{args.model}.npz"))

    model = EfficientNet.from_pretrained(args.model, num_classes=classes)
    logger.info("Moving model onto all GPUs")
    model = torch.nn.DataParallel(model)

    # Optionally resume from a checkpoint.
    # Load it to CPU first as we'll move the model to GPU later.
    # This way, we save a little bit of GPU memory when loading.
    start_epoch = 0

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

    # Resume fine-tuning if we find a saved model.
    savename = pjoin(args.logdir, args.name, "bit.pth.tar")
    try:
        logger.info(f"Model will be saved in '{savename}'")
        checkpoint = torch.load(savename, map_location="cpu")
        logger.info(f"Found saved model to resume from at '{savename}'")

        start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
        logger.info(f"Resumed at epoch {start_epoch}")
    except FileNotFoundError:
        logger.info("Fine-tuning from BiT")

    model = model.to(device)
    optim.zero_grad()

    model.train()
    mixup = bit_hyperrule.get_mixup(len(train_set))
    #mixup = -1
    cri = torch.nn.CrossEntropyLoss().to(device)
    #cri = FocalLoss(cri)
    logger.info("Starting training!")
    chrono = lb.Chrono()
    accum_steps = 0
    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    end = time.time()

    epoches = 10
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optim,
                                                    max_lr=0.01,
                                                    steps_per_epoch=1,
                                                    epochs=epoches)

    with lb.Uninterrupt() as u:
        for epoch in range(start_epoch, epoches):

            pbar = enumerate(train_loader)
            pbar = tqdm.tqdm(pbar, total=len(train_loader))

            scheduler.step()
            all_top1, all_top5 = [], []
            for param_group in optim.param_groups:
                lr = param_group["lr"]
            #for x, y in recycle(train_loader):
            for batch_id, (x, y) in pbar:
                #for batch_id, (x, y) in enumerate(train_loader):
                # measure data loading time, which is spent in the `for` statement.
                chrono._done("load", time.time() - end)

                if u.interrupted:
                    break

                # Schedule sending to GPU(s)
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                # Update learning-rate, including stop training if over.
                #lr = bit_hyperrule.get_lr(step, len(train_set), args.base_lr)
                #if lr is None:
                #  break

                if mixup > 0.0:
                    x, y_a, y_b = mixup_data(x, y, mixup_l)

                # compute output
                with chrono.measure("fprop"):
                    logits = model(x)
                    top1, top5 = topk(logits, y, ks=(1, 5))
                    all_top1.extend(top1.cpu())
                    all_top5.extend(top5.cpu())
                    if mixup > 0.0:
                        c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
                    else:
                        c = cri(logits, y)
                train_loss = c.item()
                train_acc = np.mean(all_top1) * 100.0
                # Accumulate grads
                with chrono.measure("grads"):
                    (c / args.batch_split).backward()
                    accum_steps += 1
                accstep = f"({accum_steps}/{args.batch_split})" if args.batch_split > 1 else ""
                s = f"epoch={epoch} batch {batch_id}{accstep}: loss={train_loss:.5f} train_acc={train_acc:.2f} lr={lr:.1e}"
                #s = f"epoch={epoch} batch {batch_id}{accstep}: loss={c.item():.5f} lr={lr:.1e}"
                pbar.set_description(s)
                #logger.info(f"[batch {batch_id}{accstep}]: loss={c_num:.5f} (lr={lr:.1e})")  # pylint: disable=logging-format-interpolation
                logger.flush()

                # Update params
                with chrono.measure("update"):
                    optim.step()
                    optim.zero_grad()
                # Sample new mixup ratio for next batch
                mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

            # Run evaluation and save the model.
            val_loss, val_acc = run_eval(model, valid_loader, device, chrono,
                                         logger, epoch)

            best = val_acc > best_acc
            if best:
                best_acc = val_acc
                torch.save(
                    {
                        "epoch": epoch,
                        "val_loss": val_loss,
                        "val_acc": val_acc,
                        "train_acc": train_acc,
                        "model": model.state_dict(),
                        "optim": optim.state_dict(),
                    }, savename)
            end = time.time()

    logger.info(f"Timings:\n{chrono}")
Example #7
0
def main():
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.

    unique_pids = np.unique(pids)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.random_crop(
            im, net_input_size + (3, )), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, fids, pids = dataset.make_one_shot_iterator().get_next()

    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        args.loss](dists,
                   pids,
                   args.margin,
                   batch_precision_at_k=args.batch_k - 1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            args.train_iterations - args.decay_start_iteration, 0.001)
    else:
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                        step, float(np.min(b_loss)), float(np.mean(b_loss)),
                        float(np.max(b_loss)), args.batch_k - 1,
                        float(b_prec_at_k),
                        timedelta(seconds=int(seconds_todo)), elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
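
# An editor's numpy sketch of the "batch hard" variant that loss.LOSS_CHOICES
# can select above: for every anchor, take its farthest positive and closest
# negative in the batch and apply a hinge with the given margin. This restates
# the published formulation, not the repository's exact code.
import numpy as np

def batch_hard_triplet(dists, pids, margin):
    """dists: (B, B) pairwise distances, pids: (B,) identity labels."""
    pids = np.asarray(pids)
    same = pids[:, None] == pids[None, :]
    not_self = ~np.eye(len(pids), dtype=bool)
    hardest_pos = np.where(same & not_self, dists, -np.inf).max(axis=1)
    hardest_neg = np.where(~same, dists, np.inf).min(axis=1)
    return np.maximum(hardest_pos - hardest_neg + margin, 0.0)
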
Example #8
0
def main(args):
    logger = common.setup_logger(args)

    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info(f"Going to train on {device}")

    train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger)

    logger.info(f"Loading model from {args.model}.npz")
    model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes),
                                            zero_head=True)
    model.load_from(
        np.load(os.path.join(args.pretrained_dir, f"{args.model}.npz")))

    logger.info("Moving model onto all GPUs")
    model = torch.nn.DataParallel(model)

    # Optionally resume from a checkpoint.
    # Load it to CPU first as we'll move the model to GPU later.
    # This way, we save a little bit of GPU memory when loading.
    step = 0

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=0.9)

    writer = SummaryWriter(os.path.join(args.logdir, args.name))

    # Resume fine-tuning if we find a saved model.
    savename = pjoin(args.logdir, args.name, "model.tar")
    try:
        logger.info(f"Model will be saved in '{savename}'")
        checkpoint = torch.load(savename, map_location="cpu")
        logger.info(f"Found saved model to resume from at '{savename}'")

        step = checkpoint["step"]
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
        logger.info(f"Resumed at step {step}")
    except FileNotFoundError:
        logger.info("Fine-tuning from BiT")

    model = model.to(device)
    optim.zero_grad()

    model.train()
    mixup = hyperrule.get_mixup(len(train_set))
    cri = torch.nn.CrossEntropyLoss().to(device)

    logger.info("Starting training!")
    chrono = lb.Chrono()
    accum_steps = 0
    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    end = time.time()

    with lb.Uninterrupt() as u:
        for x, y in recycle(train_loader):
            # measure data loading time, which is spent in the `for` statement.
            chrono._done("load", time.time() - end)

            if u.interrupted:
                break

            # Schedule sending to GPU(s)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Update learning-rate, including stop training if over.
            lr = hyperrule.get_lr(step, len(train_set), args.base_lr)
            if lr is None:
                break
            for param_group in optim.param_groups:
                param_group["lr"] = lr

            if mixup > 0.0:
                x, y_a, y_b = mixup_data(x, y, mixup_l)

            # compute output
            with chrono.measure("fprop"):
                logits = model(x)
                if mixup > 0.0:
                    c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
                else:
                    c = cri(logits, y)
                c_num = float(
                    c.data.cpu().numpy())  # Also ensures a sync point.

            # Accumulate grads
            with chrono.measure("grads"):
                (c / args.batch_split).backward()
                accum_steps += 1

            accstep = f" ({accum_steps}/{args.batch_split})" if args.batch_split > 1 else ""
            logger.info(
                f"[step {step}{accstep}]: loss={c_num:.5f} (lr={lr:.1e})")  # pylint: disable=logging-format-interpolation
            logger.flush()
            writer.add_scalar('Train/loss', c_num, step)
            writer.add_scalar('Train/lr', lr, step)

            # Update params
            if accum_steps == args.batch_split:
                with chrono.measure("update"):
                    optim.step()
                    optim.zero_grad()
                step += 1
                accum_steps = 0
                # Sample new mixup ratio for next batch
                mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

                # Run evaluation and save the model.
                if args.eval_every and step % args.eval_every == 0:
                    run_eval(model, valid_loader, device, chrono, logger,
                             writer, step)
                if args.save and step % args.save_every == 0:
                    step_savename = pjoin(args.logdir, args.name,
                                          "model_" + str(step) + ".tar")
                    torch.save(
                        {
                            "step": step,
                            "model": model.state_dict(),
                            "optim": optim.state_dict()
                        }, step_savename)

            end = time.time()

        # Final eval at end of training.
        run_eval(model, valid_loader, device, chrono, logger, writer, step)

    logger.info(f"Timings:\n{chrono}")
Example #9
0
def train(args, images, fids, pids, max_fid_len, log):
    '''
    Create the model and train the neural network.

    :param args: all stored arguments
    :param images: prepared images for training
    :param fids: figure id (relative paths from image_root to images)
    :param pids: person id (or car id) for all images
    :param log: log file, where logs from training are stored
    :return: saved files (checkpoints, train log file)
    '''
    ###################################################################################################################
    # CREATE MODEL
    ###################################################################################################################
    # Create the model and an embedding head.

    model = import_module('nets.resnet_v1_50')
    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    drops = {}
    if args.dropout is not None:
        drops = getDropoutProbs(args.dropout)
    b4_layers = None
    try:
        b4_layers = int(args.b4_layers)
        if b4_layers not in [1, 2, 3]:
            raise ValueError()
    except (TypeError, ValueError):
        raise ValueError("Argument exception: b4_layers has to be in [1, 2, 3]")

    endpoints, body_prefix = model.endpoints(images,
                                             b4_layers,
                                             drops,
                                             is_training=True,
                                             resnet_stride=int(
                                                 args.resnet_stride))
    endpoints['emb'] = endpoints['emb_raw'] = slim.fully_connected(
        endpoints['model_output'],
        args.embedding_dim,
        activation_fn=None,
        weights_initializer=tf.orthogonal_initializer(),
        scope='emb')

    step_pl = tf.placeholder(dtype=tf.float32)

    features = endpoints['emb']

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(features, features, metric=args.metric)
    losses, train_top1, prec_at_k, _, probe_neg_dists, pos_dists, neg_dists = loss.loss_function(
        dists,
        pids, [args.alpha1, args.alpha2, args.alpha3],
        batch_precision_at_k=args.batch_k - 1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.scalar('embedding_pos_dists', tf.reduce_mean(pos_dists))
    tf.summary.scalar('embedding_probe_neg_dists',
                      tf.reduce_mean(probe_neg_dists))
    tf.summary.scalar('embedding_neg_dists', tf.reduce_mean(neg_dists))
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_probe_neg_dists', probe_neg_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    batch_size = args.batch_p * args.batch_k
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)

    if args.sgdr:
        learning_rate = tf.train.cosine_decay_restarts(
            learning_rate=args.learning_rate,
            global_step=global_step,
            first_decay_steps=4000,
            t_mul=1.5)
    else:
        if 0 <= args.decay_start_iteration < args.train_iterations:
            learning_rate = tf.train.exponential_decay(
                args.learning_rate,
                tf.maximum(0, global_step - args.decay_start_iteration),
                args.train_iterations - args.decay_start_iteration,
                float(args.lr_decay))
        else:
            learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(tf.convert_to_tensor(learning_rate))

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)
    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        start_step = sess.run(global_step)
        step = start_step
        log.info('Starting training from iteration {}.'.format(start_step))

        ###################################################################################################################
        # TRAINING
        ###################################################################################################################
        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):
                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, features, losses, fids], feed_dict={step_pl: step})
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it), lr={:.4g}'.
                    format(step, float(np.min(b_loss)), float(np.mean(b_loss)),
                           float(np.max(b_loss)), args.batch_k - 1,
                           float(b_prec_at_k),
                           timedelta(seconds=int(seconds_todo)), elapsed_time,
                           sess.run(optimizer._lr)))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is
        # crucial in case intermediate storing was disabled, and it ensures a
        # checkpoint exists if the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
Example #10
def main(argv):

    args = parser.parse_args(argv)

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # tf.compat.v1.disable_eager_execution()

    # physical_devices = tf.config.experimental.list_physical_devices('GPU')
    # tf.config.experimental.set_memory_growth(physical_devices[0], True)

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    args.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args),
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids)
    if len(unique_pids) < args.batch_p:
        unique_pids = np.tile(unique_pids,
                              int(np.ceil(args.batch_p / len(unique_pids))))
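    # (The tiling above is only a safety net: if there are fewer unique PIDs
    # than batch_p, they are repeated so that a full P-sized sample can still
    # be drawn.)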
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.
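    # Just to illustrate the arithmetic of `take`: with 10 unique PIDs and
    # batch_p = 4, we keep (10 // 4) * 4 = 8 PIDs per "epoch".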

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.unbatch()

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)

    dataset = dataset.map(lambda fid, pid: common.fid_to_image(
        fid,
        pid,
        image_root=args.image_root,
        image_size=pre_crop_size if args.crop_augment else net_input_size),
                          num_parallel_calls=args.loading_threads)

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(lambda im, fid, pid:
                              (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(lambda im, fid, pid: (tf.image.random_crop(
            im, net_input_size + (3, )), fid, pid))
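    # Note that the random crop only provides jitter because the images above
    # were decoded at `pre_crop_size` when `crop_augment` is set; `random_crop`
    # then cuts them back down to `net_input_size`.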

    # Create the model and an embedding head.
    tf.keras.backend.set_learning_phase(1)
    emb_model = EmbeddingModel(args)

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.map(lambda im, fid, pid:
                          (emb_model.preprocess_input(im), fid, pid))
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)
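    # At this point each element of `dataset` is one PK batch: an image tensor
    # of shape (batch_p * batch_k, H, W, C) plus matching fid and pid vectors,
    # with the batch_k images of every identity stored contiguously.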

    # Since we repeat the data infinitely, a plain Python iterator over the
    # dataset (created further down as `dataset_iter`) is all we need here.

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.optimizers.schedules.PolynomialDecay(
            args.learning_rate, args.train_iterations, end_learning_rate=1e-7)
    else:
        learning_rate = args.learning_rate

    if args.optimizer == 'adam':
        optimizer = tf.keras.optimizers.Adam(learning_rate)
    elif args.optimizer == 'momentum':
        optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=0.9)
    else:
        raise NotImplementedError('Invalid optimizer {}'.format(
            args.optimizer))

    @tf.function
    def train_step(images, pids):

        with tf.GradientTape() as tape:
            batch_embedding = emb_model(images)
            if args.loss == 'semi_hard_triplet':
                embedding_loss = triplet_semihard_loss(batch_embedding, pids,
                                                       args.margin)
            elif args.loss == 'hard_triplet':
                embedding_loss = batch_hard(batch_embedding, pids, args.margin,
                                            args.metric)
            elif args.loss == 'lifted_loss':
                embedding_loss = lifted_loss(pids,
                                             batch_embedding,
                                             margin=args.margin)
            elif args.loss == 'contrastive_loss':
                assert batch_size % 2 == 0
                assert args.batch_k == 4  # Works with other values too, but would need tuning

                contrastive_idx = np.tile([0, 1, 4, 3, 2, 5, 6, 7],
                                          args.batch_p // 2)
                for i in range(args.batch_p // 2):
                    contrastive_idx[i * 8:i * 8 + 8] += i * 8
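                # What this index pattern does for batch_k == 4: the flat batch
                # is [a0 a1 a2 a3 b0 b1 b2 b3 ...], and [0, 1, 4, 3, 2, 5, 6, 7]
                # reorders each group of 8 so consecutive pairs become
                # (a0,a1), (b0,a3), (a2,b1), (b2,b3): two positive and two
                # negative pairs, matching the fixed labels [1, 0, 0, 1] below.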

                contrastive_idx = np.expand_dims(contrastive_idx, 1)
                batch_embedding_ordered = tf.gather_nd(batch_embedding,
                                                       contrastive_idx)
                pids_ordered = tf.gather_nd(pids, contrastive_idx)
                # batch_embedding_ordered = tf.Print(batch_embedding_ordered,[pids_ordered],'pids_ordered :: ',summarize=1000)
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding_ordered,
                               [-1, 2, args.embedding_dim]), 2, 1)
                # embeddings_anchor = tf.Print(embeddings_anchor,[pids_ordered,embeddings_anchor,embeddings_positive,batch_embedding,batch_embedding_ordered],"Tensors ", summarize=1000)

                fixed_labels = np.tile([1, 0, 0, 1], args.batch_p // 2)
                # fixed_labels = np.reshape(fixed_labels,(len(fixed_labels),1))
                # print(fixed_labels)
                labels = tf.constant(fixed_labels)
                # labels = tf.Print(labels,[labels],'labels ',summarize=1000)
                embedding_loss = contrastive_loss(labels,
                                                  embeddings_anchor,
                                                  embeddings_positive,
                                                  margin=args.margin)
            elif args.loss == 'angular_loss':
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]),
                    2, 1)
                # pids = tf.Print(pids, [pids], 'pids:: ', summarize=100)
                pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
                # pids = tf.Print(pids,[pids],'pids:: ',summarize=100)
                embedding_loss = angular_loss(pids,
                                              embeddings_anchor,
                                              embeddings_positive,
                                              batch_size=args.batch_p,
                                              with_l2reg=True)

            elif args.loss == 'npairs_loss':
                assert args.batch_k == 2  ## Single positive pair per class
                embeddings_anchor, embeddings_positive = tf.unstack(
                    tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]),
                    2, 1)
                pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1)
                pids = tf.reshape(pids, [-1])
                embedding_loss = npairs_loss(pids, embeddings_anchor,
                                             embeddings_positive)

            else:
                raise NotImplementedError('Invalid Loss {}'.format(args.loss))
            loss_mean = tf.reduce_mean(embedding_loss)

        gradients = tape.gradient(loss_mean, emb_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients,
                                      emb_model.trainable_variables))

        return embedding_loss
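
    # Note that because `train_step` is a tf.function, only the branch selected
    # by `args.loss` is traced, and the Python/numpy work inside it (e.g.
    # building `contrastive_idx`) runs once at trace time and is baked into the
    # compiled graph as constants.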

    # sess = tf.compat.v1.Session()
    # start_step = sess.run(global_step)
    # checkpoint_saver = tf.train.Saver(max_to_keep=2)
    start_step = 0
    log.info('Starting training from iteration {}.'.format(start_step))
    dataset_iter = iter(dataset)

    ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                               optimizer=optimizer,
                               net=emb_model)
    manager = tf.train.CheckpointManager(ckpt,
                                         osp.join(args.experiment_root,
                                                  'tf_ckpts'),
                                         max_to_keep=3)

    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
        for i in range(ckpt.step.numpy(), args.train_iterations):
            # for batch_idx, batch in enumerate():
            start_time = time.time()
            images, fids, pids = next(dataset_iter)
            batch_loss = train_step(images, pids)
            elapsed_time = time.time() - start_time
            seconds_todo = (args.train_iterations - i) * elapsed_time
            # print(tf.reduce_min(batch_loss).numpy(),tf.reduce_mean(batch_loss).numpy(),tf.reduce_max(batch_loss).numpy())
            log.info(
                'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ETA: {} ({:.2f}s/it)'
                .format(
                    i,
                    tf.reduce_min(batch_loss).numpy(),
                    tf.reduce_mean(batch_loss).numpy(),
                    tf.reduce_max(batch_loss).numpy(),
                    # args.batch_k - 1, float(b_prec_at_k),
                    timedelta(seconds=int(seconds_todo)),
                    elapsed_time))

            ckpt.step.assign_add(1)
            if (args.checkpoint_frequency > 0
                    and i % args.checkpoint_frequency == 0):

                # uncomment if you want to save the model weight separately
                # emb_model.save_weights(os.path.join(args.experiment_root, 'model_weights_{0:04d}.w'.format(i)))

                manager.save()

            # Stop the main-loop at the end of the step, if requested.
            if u.interrupted:
                log.info("Interrupted on request!")
                break
Example #11
def main():
    # args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.

    train_config = cfg.TrainConfig()

    args_file = os.path.join(train_config.experiment_root, 'args.json')
    if train_config.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in train_config.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    train_config.__dict__[key] = resumed_value
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))

    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(train_config.experiment_root):
            if os.listdir(train_config.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(train_config.experiment_root))
                exit(1)
        else:
            os.makedirs(train_config.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(train_config), f, ensure_ascii=False, indent=2, sort_keys=True)

    log_file = os.path.join(train_config.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(train_config).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not train_config.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not train_config.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    # Load the data from the CSV file.
    pids, fids = common.load_dataset(train_config.train_set, train_config.image_root, is_train=True)

    max_fid_len = max(map(len, fids))  # We'll need this later for logfiles

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.

    unique_pids = np.unique(pids)

    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)

    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // train_config.batch_p) * train_config.batch_p)
    # take(count)  Creates a Dataset with at most count elements from this dataset.

    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.
    # Repeats this dataset count times.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids, all_pids=pids, batch_k=train_config.batch_k))

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())
    # apply(transformation_func) Apply a transformation function to this dataset.
    # apply enables chaining of custom Dataset transformations, which are represented as functions that take one Dataset argument and return a transformed Dataset.


    # Convert filenames to actual image tensors.
    net_input_size = (train_config.net_input_height, train_config.net_input_width)
    # 256,128
    pre_crop_size = (train_config.pre_crop_height, train_config.pre_crop_width)
    # 288,144
    
    dataset = dataset.map(
        lambda fid, pid: common.fid_to_image_label(
            fid, pid, image_root=train_config.image_root,
            image_size=pre_crop_size if train_config.crop_augment else net_input_size),
        num_parallel_calls=train_config.loading_threads)

###########################################################################################
    dataset = dataset.map(
        lambda im, keypt, mask, fid, pid: (tf.concat([im, keypt, mask], 2), fid, pid))

###########################################################################################
    
    # Augment the data if specified by the arguments.
    if train_config.flip_augment:
        dataset = dataset.map(
            lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid))


    # net_input_size_aug = net_input_size + (4,)
    if train_config.crop_augment:
        dataset = dataset.map(
            lambda im, fid, pid: (tf.random_crop(im, net_input_size + (21,)), fid, pid))
    # net_input_size + (21,) = (256, 128, 21)
    # split
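    # The 21 channels are the stack concatenated above: presumably 3 image
    # channels, 17 keypoint heatmaps and a 1-channel mask. Cropping the stacked
    # tensor keeps the augmentation aligned across all of them; `common.split`
    # below separates them again.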

#############################################################################################
    dataset = dataset.map(
        lambda im, fid, pid: (common.split(im, fid, pid)))

#############################################################################################

    # Group it back into PK batches.
    batch_size = train_config.batch_p * train_config.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)
    # prefetch(buffer_size)   Creates a Dataset that prefetches elements from this dataset.

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images, keypts, masks, fids, pids = dataset.make_one_shot_iterator().get_next()
    # tf.summary.image('image',images,10)

    # Create the model and an embedding head.
    model = import_module('nets.' + train_config.model_name)
    head = import_module('heads.' + train_config.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.

    endpoints, body_prefix = model.endpoints(images, is_training=True)
    heatmap_in = endpoints[train_config.model_name + '/block4']
    # resnet_block_4_out = heatmap.resnet_block_4(heatmap_in)
    # resnet_block_3_4_out = heatmap.resnet_block_3_4(heatmap_in)
    # resnet_block_2_3_4_out = heatmap.resnet_block_2_3_4(heatmap_in)
    # head for heatmap
    with tf.name_scope('heatmap'):
        # heatmap_in = endpoints['model_output']
        # heatmap_out_layer_0 = heatmap.hmnet_layer_0(resnet_block_4_out, 1)
        # heatmap_out_layer_0 = heatmap.hmnet_layer_0(resnet_block_3_4_out, 1)
        # heatmap_out_layer_0 = heatmap.hmnet_layer_0(resnet_block_2_3_4_out, 1)
        heatmap_out_layer_0 = VAC.hmnet_layer_0(heatmap_in[:, :, :, 1020:2048], 1)
        heatmap_out_layer_1 = VAC.hmnet_layer_1(heatmap_out_layer_0, 1)
        heatmap_out_layer_2 = VAC.hmnet_layer_2(heatmap_out_layer_1, 1)
        heatmap_out_layer_3 = VAC.hmnet_layer_3(heatmap_out_layer_2, 1)
        heatmap_out_layer_4 = VAC.hmnet_layer_4(heatmap_out_layer_3, 1)
        heatmap_out = heatmap_out_layer_4
        heatmap_loss = VAC.loss_mutilayer(heatmap_out_layer_0, heatmap_out_layer_1, heatmap_out_layer_2,
                                              heatmap_out_layer_3, heatmap_out_layer_4, masks, net_input_size)
        # heatmap_loss = heatmap.loss(heatmap_out, labels, net_input_size)
        # heatmap_loss_mean = heatmap_loss

    with tf.name_scope('head'):
        # heatmap_sum = tf.reduce_sum(heatmap_out, axis=3)
        # heatmap_resize = tf.image.resize_images(tf.expand_dims(heatmap_sum, axis=3), [8, 4])
        # featuremap_tmp = tf.multiply(heatmap_resize, endpoints[args.model_name + '/block4'])
        # endpoints[args.model_name + '/block4'] = featuremap_tmp
        endpoints = head.head(endpoints, train_config.embedding_dim, is_training=True)

        tf.summary.image('feature_map', tf.expand_dims(endpoints[train_config.model_name + '/block4'][:, :, :, 0], axis=3), 4)


    with tf.name_scope('keypoints_pre'):
        keypoints_pre_in = endpoints[train_config.model_name + '/block4']
        # keypoints_pre_in_0 = keypoints_pre_in[:, :, :, 0:256]
        # keypoints_pre_in_1 = keypoints_pre_in[:, :, :, 256:512]
        # keypoints_pre_in_2 = keypoints_pre_in[:, :, :, 512:768]
        # keypoints_pre_in_3 = keypoints_pre_in[:, :, :, 768:1024]
        keypoints_pre_in_0 = keypoints_pre_in[:, :, :, 0:170]
        keypoints_pre_in_1 = keypoints_pre_in[:, :, :, 170:340]
        keypoints_pre_in_2 = keypoints_pre_in[:, :, :, 340:510]
        keypoints_pre_in_3 = keypoints_pre_in[:, :, :, 510:680]
        keypoints_pre_in_4 = keypoints_pre_in[:, :, :, 680:850]
        keypoints_pre_in_5 = keypoints_pre_in[:, :, :, 850:1020]

        labels = tf.image.resize_images(keypts, [128, 64])
        # keypoints_gt_0 = tf.concat([labels[:, :, :, 0:5], labels[:, :, :, 14:15], labels[:, :, :, 15:16], labels[:, :, :, 16:17], labels[:, :, :, 17:18]], 3)
        # keypoints_gt_1 = tf.concat([labels[:, :, :, 1:2], labels[:, :, :, 2:3], labels[:, :, :, 3:4], labels[:, :, :, 5:6]], 3)
        # keypoints_gt_2 = tf.concat([labels[:, :, :, 4:5], labels[:, :, :, 7:8], labels[:, :, :, 8:9], labels[:, :, :, 11:12]], 3)
        # keypoints_gt_3 = tf.concat([labels[:, :, :, 9:10], labels[:, :, :, 10:11], labels[:, :, :, 12:13], labels[:, :, :, 13:14]], 3)

        keypoints_gt_0 = labels[:, :, :, 0:5]
        keypoints_gt_1 = labels[:, :, :, 5:7]
        keypoints_gt_2 = labels[:, :, :, 7:9]
        keypoints_gt_3 = labels[:, :, :, 9:13]
        keypoints_gt_4 = labels[:, :, :, 13:15]
        keypoints_gt_5 = labels[:, :, :, 15:17]

        keypoints_pre_0 = PAC.tran_conv_0(keypoints_pre_in, kp_num=5)
        keypoints_pre_1 = PAC.tran_conv_1(keypoints_pre_in, kp_num=2)
        keypoints_pre_2 = PAC.tran_conv_2(keypoints_pre_in, kp_num=2)
        keypoints_pre_3 = PAC.tran_conv_3(keypoints_pre_in, kp_num=4)
        keypoints_pre_4 = PAC.tran_conv_4(keypoints_pre_in, kp_num=2)
        keypoints_pre_5 = PAC.tran_conv_5(keypoints_pre_in, kp_num=2)

        keypoints_loss_0 = PAC.keypoints_loss(keypoints_pre_0, keypoints_gt_0)
        keypoints_loss_1 = PAC.keypoints_loss(keypoints_pre_1, keypoints_gt_1)
        keypoints_loss_2 = PAC.keypoints_loss(keypoints_pre_2, keypoints_gt_2)
        keypoints_loss_3 = PAC.keypoints_loss(keypoints_pre_3, keypoints_gt_3)
        keypoints_loss_4 = PAC.keypoints_loss(keypoints_pre_4, keypoints_gt_4)
        keypoints_loss_5 = PAC.keypoints_loss(keypoints_pre_5, keypoints_gt_5)

        keypoints_loss = (5/17*keypoints_loss_0 + 2/17*keypoints_loss_1
                          + 2/17*keypoints_loss_2 + 4/17*keypoints_loss_3
                          + 2/17*keypoints_loss_4 + 2/17*keypoints_loss_5)
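        # The weights are each group's share of the 17 keypoints
        # (5 + 2 + 2 + 4 + 2 + 2 = 17), so this is a per-keypoint average of
        # the group losses.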


    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=train_config.metric)
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[train_config.loss](
        dists, pids, train_config.margin, batch_precision_at_k=train_config.batch_k-1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)
    
    scale_rate_0 = 1E-7
    scale_rate_1 = 6E-8
    total_loss = loss_mean + keypoints_loss*scale_rate_0 + heatmap_loss*scale_rate_1
    # total_loss = loss_mean + keypoints_loss * scale_rate_0
    # total_loss = loss_mean
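    # The tiny scale rates keep the auxiliary keypoint and heatmap terms from
    # dominating the triplet loss, presumably because their raw pixel-wise
    # values are orders of magnitude larger; they look empirically tuned.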

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
############################################################################################
    # tf.summary.histogram('hm_loss_distribution', heatmap_loss)
    tf.summary.scalar('keypt_loss_0', keypoints_loss_0)
    tf.summary.scalar('keypt_loss_1', keypoints_loss_1)
    tf.summary.scalar('keypt_loss_2', keypoints_loss_2)
    tf.summary.scalar('keypt_loss_3', keypoints_loss_3)
    tf.summary.scalar('keypt_loss_all', keypoints_loss)
############################################################################################
    tf.summary.scalar('total_loss', total_loss)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(train_config.batch_k-1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if train_config.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(train_config.experiment_root, 'embeddings'),
            dtype=np.float32, shape=(train_config.train_iterations, batch_size, train_config.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(train_config.experiment_root, 'losses'),
            dtype=np.float32, shape=(train_config.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(train_config.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len), shape=(train_config.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    if 0 <= train_config.decay_start_iteration < train_config.train_iterations:
        learning_rate = tf.train.exponential_decay(
            train_config.learning_rate,
            tf.maximum(0, global_step - train_config.decay_start_iteration),
            train_config.train_iterations - train_config.decay_start_iteration, 0.001)
    else:
        learning_rate = train_config.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        # train_op = optimizer.minimize(loss_mean, global_step=global_step)
        train_op = optimizer.minimize(total_loss, global_step=global_step)


    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)


    gpu_options = tf.GPUOptions(allow_growth=True)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        if train_config.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(train_config.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if train_config.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables,
                                       write_version=tf.train.SaverDef.V1)
                saver.restore(sess, train_config.initial_checkpoint)

                # name_11 = 'resnet_v1_50/block4'
                # name_12 = 'resnet_v1_50/block3'
                # name_13 = 'resnet_v1_50/block2'
                # name_21 = 'Resnet_block_2_3_4/block4'
                # name_22 = 'Resnet_block_2_3_4/block3'
                # name_23 = 'Resnet_block_2_3_4/block2'
                # for var in tf.trainable_variables():
                #     var_name = var.name
                #     if re.match(name_11, var_name):
                #         dst_name = var_name.replace(name_11, name_21)
                #         tensor = tf.get_default_graph().get_tensor_by_name(var_name)
                #         dst_tensor = tf.get_default_graph().get_tensor_by_name(dst_name)
                #         tf.assign(dst_tensor, tensor)
                #     if re.match(name_12, var_name):
                #         dst_name = var_name.replace(name_12, name_22)
                #         tensor = tf.get_default_graph().get_tensor_by_name(var_name)
                #         dst_tensor = tf.get_default_graph().get_tensor_by_name(dst_name)
                #         tf.assign(dst_tensor, tensor)
                #     if re.match(name_13, var_name):
                #         dst_name = var_name.replace(name_13, name_23)
                #         tensor = tf.get_default_graph().get_tensor_by_name(var_name)
                #         dst_tensor = tf.get_default_graph().get_tensor_by_name(dst_name)
                #         tf.assign(dst_tensor, tensor)
            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproducible experiments.
            checkpoint_saver.save(sess, os.path.join(
                train_config.experiment_root, 'checkpoint'), global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(train_config.experiment_root, sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, train_config.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids])
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter', simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if train_config.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (train_config.train_iterations - step) * elapsed_time
                log.info('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                         'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                             step,
                             float(np.min(b_loss)),
                             float(np.mean(b_loss)),
                             float(np.max(b_loss)),
                             train_config.batch_k-1, float(b_prec_at_k),
                             timedelta(seconds=int(seconds_todo)),
                             elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (train_config.checkpoint_frequency > 0 and
                        step % train_config.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess, os.path.join(
                        train_config.experiment_root, 'checkpoint'), global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is
        # crucial in case intermediate storing was disabled, and it ensures a
        # checkpoint exists if the process was interrupted.
        checkpoint_saver.save(sess, os.path.join(
            train_config.experiment_root, 'checkpoint'), global_step=step)