def main():
    args = parser.parse_args()

    # We store all arguments in a json file. This has two advantages:
    # 1. We can always get back and see what exactly that experiment was
    # 2. We can resume an experiment as-is without needing to remember all flags.
    args_file = os.path.join(args.experiment_root, 'args.json')
    if args.resume:
        if not os.path.isfile(args_file):
            raise IOError('`args.json` not found in {}'.format(args_file))

        print('Loading args from {}.'.format(args_file))
        with open(args_file, 'r') as f:
            args_resumed = json.load(f)
        args_resumed['resume'] = True  # This would be overwritten.

        # When resuming, we not only want to populate the args object with the
        # values from the file, but we also want to check for some possible
        # conflicts between loaded and given arguments.
        for key, value in args.__dict__.items():
            if key in args_resumed:
                resumed_value = args_resumed[key]
                if resumed_value != value:
                    print('Warning: For the argument `{}` we are using the'
                          ' loaded value `{}`. The provided value was `{}`'
                          '.'.format(key, resumed_value, value))
                    command = input('Would you like to restore the loaded value? (yes/no) ')
                    if command == 'yes':
                        args.__dict__[key] = resumed_value
                        print('For the argument `{}` we are using the loaded'
                              ' value `{}`.'.format(key, args.__dict__[key]))
                    else:
                        print('For the argument `{}` we are using the provided'
                              ' value `{}`.'.format(key, args.__dict__[key]))
            else:
                print('Warning: A new argument was added since the last run:'
                      ' `{}`. Using the new value: `{}`.'.format(key, value))
        os.remove(args_file)
        with open(args_file, 'w') as f:
            json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True)
    else:
        # If the experiment directory exists already, we bail in fear.
        if os.path.exists(args.experiment_root):
            if os.listdir(args.experiment_root):
                print('The directory {} already exists and is not empty.'
                      ' If you want to resume training, append --resume to'
                      ' your call.'.format(args.experiment_root))
                exit(1)
        else:
            os.makedirs(args.experiment_root)

        # Store the passed arguments for later resuming and grepping in a nice
        # and readable format.
        with open(args_file, 'w') as f:
            json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True)

    log_file = os.path.join(args.experiment_root, "train")
    logging.config.dictConfig(common.get_logging_dict(log_file))
    log = logging.getLogger('train')

    # Also show all parameter values at the start, for ease of reading logs.
    log.info('Training using the following parameters:')
    for key, value in sorted(vars(args).items()):
        log.info('{}: {}'.format(key, value))

    # Check them here, so they are not required when --resume-ing.
    if not args.train_set:
        parser.print_help()
        log.error("You did not specify the `train_set` argument!")
        sys.exit(1)
    if not args.image_root:
        parser.print_help()
        log.error("You did not specify the required `image_root` argument!")
        sys.exit(1)

    ###########################################################################
    # Prepare the training dataset.

    # Load the data from the text file. See common.load_dataset for details.
    pids_train, fids_train = common.load_dataset(args.train_set, args.image_root)
    max_fid_len = max(map(len, fids_train))  # We'll need this later for logfiles.

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids = np.unique(pids_train)
    dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
    dataset = dataset.shuffle(len(unique_pids))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
    dataset = dataset.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids_train, all_pids=pids_train, batch_k=args.batch_k))
    # The dataset now yields (selected_fids, pid), as returned by
    # `sample_k_fids_for_pid`.

    # Ungroup/flatten the batches for easy loading of the files.
    dataset = dataset.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset = dataset.map(
        lambda fid, pid: common.fid_to_image(
            fid, pid, image_root=args.image_root,
            image_size=pre_crop_size if args.crop_augment else net_input_size),
        num_parallel_calls=args.loading_threads)
    # The dataset now yields (image, fid, pid), as returned by
    # `common.fid_to_image`.

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset = dataset.map(
            lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset = dataset.map(
            lambda im, fid, pid: (tf.random_crop(im, net_input_size + (3,)), fid, pid))

    # Group it back into PK batches.
    batch_size = args.batch_p * args.batch_k
    dataset = dataset.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset = dataset.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images_train, fids_train, pids_train = dataset.make_one_shot_iterator().get_next()

    ###########################################################################
    # Prepare the validation set.
    pids_val, fids_val = common.load_dataset(args.validation_set,
                                             args.validation_image_root)

    # Setup a tf.Dataset where one "epoch" loops over all PIDS.
    # PIDS are shuffled after every epoch and continue indefinitely.
    unique_pids_val = np.unique(pids_val)
    dataset_val = tf.data.Dataset.from_tensor_slices(unique_pids_val)
    dataset_val = dataset_val.shuffle(len(unique_pids_val))

    # Constrain the dataset size to a multiple of the batch-size, so that
    # we don't get overlap at the end of each epoch.
    dataset_val = dataset_val.take(
        (len(unique_pids_val) // args.batch_p) * args.batch_p)
    dataset_val = dataset_val.repeat(None)  # Repeat forever. Funny way of stating it.

    # For every PID, get K images.
    dataset_val = dataset_val.map(lambda pid: sample_k_fids_for_pid(
        pid, all_fids=fids_val, all_pids=pids_val, batch_k=args.batch_k))
    # The dataset now yields (selected_fids, pid), as returned by
    # `sample_k_fids_for_pid`.

    # Ungroup/flatten the batches for easy loading of the files.
    dataset_val = dataset_val.apply(tf.contrib.data.unbatch())

    # Convert filenames to actual image tensors.
    net_input_size = (args.net_input_height, args.net_input_width)
    pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
    dataset_val = dataset_val.map(
        lambda fid, pid: common.fid_to_image(
            fid, pid, image_root=args.validation_image_root,
            image_size=pre_crop_size if args.crop_augment else net_input_size),
        num_parallel_calls=args.loading_threads)
    # The dataset now yields (image, fid, pid), as returned by
    # `common.fid_to_image`.

    # Augment the data if specified by the arguments.
    if args.flip_augment:
        dataset_val = dataset_val.map(
            lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid))
    if args.crop_augment:
        dataset_val = dataset_val.map(
            lambda im, fid, pid: (tf.random_crop(im, net_input_size + (3,)), fid, pid))

    # Group it back into PK batches.
    dataset_val = dataset_val.batch(batch_size)

    # Overlap producing and consuming for parallelism.
    dataset_val = dataset_val.prefetch(1)

    # Since we repeat the data infinitely, we only need a one-shot iterator.
    images_val, fids_val, pids_val = dataset_val.make_one_shot_iterator().get_next()

    ###########################################################################
    # Create the model and an embedding head.
    model = import_module('nets.' + args.model_name)
    head = import_module('heads.' + args.head_name)

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    input_images = tf.placeholder(
        dtype=tf.float32,
        shape=[None, args.net_input_height, args.net_input_width, 3],
        name='input')
    pids = tf.placeholder(dtype=tf.string, shape=[None], name='pids')
    fids = tf.placeholder(dtype=tf.string, shape=[None], name='fids')

    endpoints, body_prefix = model.endpoints(input_images, is_training=True)
    with tf.name_scope('head'):
        endpoints = head.head(endpoints, args.embedding_dim, is_training=True)

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    # dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric)
    # losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
    #     dists, pids, args.margin, batch_precision_at_k=args.batch_k-1)
    # The '_' is the boolean matrix that marks, within the top-K, where the
    # correct identity match occurs; shape = (batch_size, K).

    # Modified: compute one loss per feature branch plus the fusion layer.
    # (A standalone NumPy sketch of a batch-hard triplet loss follows this script.)
    dists1 = loss.cdist(endpoints['feature1'], endpoints['feature1'], metric=args.metric)
    losses1, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists1, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists2 = loss.cdist(endpoints['feature2'], endpoints['feature2'], metric=args.metric)
    losses2, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists2, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists3 = loss.cdist(endpoints['feature3'], endpoints['feature3'], metric=args.metric)
    losses3, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists3, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists4 = loss.cdist(endpoints['feature4'], endpoints['feature4'], metric=args.metric)
    losses4, _, _, _, _, _ = loss.LOSS_CHOICES[args.loss](
        dists4, pids, args.margin, batch_precision_at_k=args.batch_k - 1)
    dists_fu = loss.cdist(endpoints['fusion_layer'], endpoints['fusion_layer'],
                          metric=args.metric)
    losses_fu, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
        dists_fu, pids, args.margin, batch_precision_at_k=args.batch_k - 1)

    losses = losses1 + losses2 + losses3 + losses4 + losses_fu
    # End of modification.

    # Alternative single-call loss:
    # losses_fu, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss](
    #     endpoints, pids, model_type=args.model_name, metric=args.metric,
    #     batch_precision_at_k=args.batch_k - 1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    # An entry in `losses` is "active" when the positive-pair distance still
    # exceeds the negative-pair distance plus the margin.
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    # tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # `global_step` counts the number of batches seen by the graph.
    if 0 <= args.decay_start_iteration < args.train_iterations:
        learning_rate = tf.train.exponential_decay(
            args.learning_rate,
            tf.maximum(0, global_step - args.decay_start_iteration),
            # Decay every `lr_decay_steps` steps after `decay_start_iteration`.
            # args.train_iterations - args.decay_start_iteration, args.weight_decay_factor)
            args.lr_decay_steps, args.lr_decay_factor, staircase=True)
    else:
        # The case when `decay_start_iteration` is set to -1.
        learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=1e-3)
    # Feel free to try others!
    # optimizer = tf.train.AdadeltaOptimizer(learning_rate)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)

    with tf.Session(config=config) as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            if args.checkpoint is None:
                last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
                log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
                checkpoint_saver.restore(sess, last_checkpoint)
            else:
                ckpt_path = os.path.join(args.experiment_root, args.checkpoint)
                log.info('Restoring from checkpoint: {}'.format(args.checkpoint))
                checkpoint_saver.restore(sess, ckpt_path)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                saver = tf.train.Saver(model_variables)
                # Restore the pre-trained parameters of the downloaded backbone model.
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root, 'checkpoint'),
                                  global_step=0)

        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root, sess.graph)

        start_step = sess.run(global_step)
        log.info('Starting training from iteration {}.'.format(start_step))

        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):

                # Compute gradients, update weights, store logs!
                start_time = time.time()
                # Fetch one batch in a single run so that images, pids and fids
                # stay aligned; separate .eval() calls would each advance the
                # iterator and return tensors from different batches.
                b_images, b_pids_in, b_fids_in = sess.run(
                    [images_train, pids_train, fids_train])
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, endpoints['emb'], losses, fids],
                             feed_dict={input_images: b_images,
                                        pids: b_pids_in,
                                        fids: b_fids_in})
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter', simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    log_embs[i], log_loss[i], log_fids[i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                         'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format(
                             step,
                             float(np.min(b_loss)),
                             float(np.mean(b_loss)),
                             float(np.max(b_loss)),
                             args.batch_k - 1,
                             float(b_prec_at_k),
                             timedelta(seconds=int(seconds_todo)),
                             elapsed_time))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                if (args.checkpoint_frequency > 0 and
                        step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(
                        sess,
                        os.path.join(args.experiment_root, 'checkpoint'),
                        global_step=step)

                # Get validation results.
                if (args.validation_frequency > 0 and
                        step % args.validation_frequency == 0):
                    # Again, fetch the validation batch in a single run so that
                    # images, pids and fids stay aligned.
                    b_images, b_pids_in, b_fids_in = sess.run(
                        [images_val, pids_val, fids_val])
                    b_prec_at_k_val, b_loss, b_fids = \
                        sess.run([prec_at_k, losses, fids],
                                 feed_dict={input_images: b_images,
                                            pids: b_pids_in,
                                            fids: b_fids_in})
                    log.info('Validation @ iteration {:6d}, '
                             'loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                             'batch-p@{}: {:.2%}'.format(
                                 step,
                                 float(np.min(b_loss)),
                                 float(np.mean(b_loss)),
                                 float(np.max(b_loss)),
                                 args.batch_k - 1,
                                 float(b_prec_at_k_val)))
                    sys.stdout.flush()
                    sys.stderr.flush()

                    summary3 = tf.Summary()
                    summary3.value.add(tag='validation_loss',
                                       simple_value=float(np.mean(b_loss)))
                    summary_writer.add_summary(summary3, step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)
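
###############################################################################
# For reference: a minimal NumPy sketch of the kind of batch-hard triplet loss
# that `loss.LOSS_CHOICES[args.loss]` is assumed to provide above. The real
# implementation lives in `loss.py` and may differ (e.g. a soft-margin variant);
# this only illustrates the idea: for every anchor, take the hardest positive
# and hardest negative in the batch and apply a margin hinge.
import numpy as np


def batch_hard_sketch(dists, pids, margin):
    """dists: (B, B) pairwise distances, pids: (B,) identity labels."""
    pids = np.asarray(pids)
    same = pids[:, None] == pids[None, :]                # positive mask (incl. self)
    pos = np.where(same, dists, -np.inf).max(axis=1)     # hardest positive per anchor
    neg = np.where(~same, dists, np.inf).min(axis=1)     # hardest negative per anchor
    return np.maximum(pos - neg + margin, 0.0)           # hinge loss per anchor


# Example: losses = batch_hard_sketch(pairwise_dists, batch_pids, margin=0.2)
###############################################################################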
def main(): args = parser.parse_args() # Data augmentation global seq_geo global seq_img seq_geo = iaa.SomeOf( (0, 5), [ iaa.Fliplr(0.5), # horizontally flip 50% of the images iaa.PerspectiveTransform(scale=(0, 0.075)), iaa.Affine( scale={ "x": (0.8, 1.0), "y": (0.8, 1.0) }, rotate=(-5, 5), translate_percent={ "x": (-0.1, 0.1), "y": (-0.1, 0.1) }, ), # rotate by -45 to +45 degrees), iaa.Crop(pc=( 0, 0.125 )), # crop images from each side by 0 to 12.5% (randomly chosen) iaa.CoarsePepper(p=0.01, size_percent=0.1) ], random_order=False) # Content transformation seq_img = iaa.SomeOf( (0, 3), [ iaa.GaussianBlur( sigma=(0, 1.0)), # blur images with a sigma of 0 to 2.0 iaa.ContrastNormalization(alpha=(0.9, 1.1)), iaa.Grayscale(alpha=(0, 0.2)), iaa.Multiply((0.9, 1.1)) ]) # We store all arguments in a json file. This has two advantages: # 1. We can always get back and see what exactly that experiment was # 2. We can resume an experiment as-is without needing to remember all flags. args_file = os.path.join(args.experiment_root, 'args.json') if args.resume: if not os.path.isfile(args_file): raise IOError('`args.json` not found in {}'.format(args_file)) print('Loading args from {}.'.format(args_file)) with open(args_file, 'r') as f: args_resumed = json.load(f) args_resumed['resume'] = True # This would be overwritten. # When resuming, we not only want to populate the args object with the # values from the file, but we also want to check for some possible # conflicts between loaded and given arguments. for key, value in args.__dict__.items(): if key in args_resumed: resumed_value = args_resumed[key] if resumed_value != value: print('Warning: For the argument `{}` we are using the' ' loaded value `{}`. The provided value was `{}`' '.'.format(key, resumed_value, value)) args.__dict__[key] = resumed_value else: print('Warning: A new argument was added since the last run:' ' `{}`. Using the new value: `{}`.'.format(key, value)) else: # If the experiment directory exists already, we bail in fear. if os.path.exists(args.experiment_root): if os.listdir(args.experiment_root): print('The directory {} already exists and is not empty.' ' If you want to resume training, append --resume to' ' your call.'.format(args.experiment_root)) exit(1) else: os.makedirs(args.experiment_root) # Store the passed arguments for later resuming and grepping in a nice # and readable format. with open(args_file, 'w') as f: json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True) log_file = os.path.join(args.experiment_root, "train") logging.config.dictConfig(common.get_logging_dict(log_file)) log = logging.getLogger('train') # Also show all parameter values at the start, for ease of reading logs. log.info('Training using the following parameters:') for key, value in sorted(vars(args).items()): log.info('{}: {}'.format(key, value)) # Check them here, so they are not required when --resume-ing. if not args.train_set: parser.print_help() log.error("You did not specify the `train_set` argument!") sys.exit(1) if not args.image_root: parser.print_help() log.error("You did not specify the required `image_root` argument!") sys.exit(1) # Load the data from the CSV file. pids, fids = common.load_dataset(args.train_set, args.image_root) max_fid_len = max(map(len, fids)) # We'll need this later for logfiles. 
# Load feature embeddings if args.hard_pool_size > 0: with h5py.File(args.train_embeddings, 'r') as f_train: train_embs = np.array(f_train['emb']) f_dists = scipy.spatial.distance.cdist(train_embs, train_embs) hard_ids = get_hard_id_pool(pids, f_dists, args.hard_pool_size) # Setup a tf.Dataset where one "epoch" loops over all PIDS. # PIDS are shuffled after every epoch and continue indefinitely. unique_pids = np.unique(pids) dataset = tf.data.Dataset.from_tensor_slices(unique_pids) dataset = dataset.shuffle(len(unique_pids)) # Constrain the dataset size to a multiple of the batch-size, so that # we don't get overlap at the end of each epoch. if args.hard_pool_size == 0: dataset = dataset.take( (len(unique_pids) // args.batch_p) * args.batch_p) dataset = dataset.repeat( None) # Repeat forever. Funny way of stating it. else: dataset = dataset.repeat( None) # Repeat forever. Funny way of stating it. dataset = dataset.map(lambda pid: sample_batch_ids_for_pid( pid, all_pids=pids, batch_p=args.batch_p, all_hard_pids=hard_ids)) # Unbatch the P PIDs dataset = dataset.apply(tf.contrib.data.unbatch()) # For every PID, get K images. dataset = dataset.map(lambda pid: sample_k_fids_for_pid( pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k)) # Ungroup/flatten the batches for easy loading of the files. dataset = dataset.apply(tf.contrib.data.unbatch()) # Convert filenames to actual image tensors. net_input_size = (args.net_input_height, args.net_input_width) pre_crop_size = (args.pre_crop_height, args.pre_crop_width) dataset = dataset.map(lambda im, fid, pid: common.fid_to_image( fid, pid, image_root=args.image_root, image_size=pre_crop_size if args.crop_augment else net_input_size), num_parallel_calls=args.loading_threads) # Augment the data if specified by the arguments. if args.augment == False: dataset = dataset.map( lambda fid, pid: common.fid_to_image(fid, pid, image_root=args.image_root, image_size=pre_crop_size if args.crop_augment else net_input_size), #Ergys num_parallel_calls=args.loading_threads) if args.flip_augment: dataset = dataset.map(lambda im, fid, pid: ( tf.image.random_flip_left_right(im), fid, pid)) if args.crop_augment: dataset = dataset.map(lambda im, fid, pid: (tf.random_crop( im, net_input_size + (3, )), fid, pid)) else: dataset = dataset.map(lambda im, fid, pid: common.fid_to_image( fid, pid, image_root=args.image_root, image_size=net_input_size), num_parallel_calls=args.loading_threads) dataset = dataset.map(lambda im, fid, pid: (tf.py_func( augment_images, [im], [tf.float32]), fid, pid)) dataset = dataset.map(lambda im, fid, pid: (tf.reshape( im[0], (args.net_input_height, args.net_input_width, 3)), fid, pid)) # Group it back into PK batches. batch_size = args.batch_p * args.batch_k dataset = dataset.batch(batch_size) # Overlap producing and consuming for parallelism. dataset = dataset.prefetch(batch_size * 2) # Since we repeat the data infinitely, we only need a one-shot iterator. images, fids, pids = dataset.make_one_shot_iterator().get_next() # Create the model and an embedding head. model = import_module('nets.' + args.model_name) head = import_module('heads.' + args.head_name) # Feed the image through the model. The returned `body_prefix` will be used # further down to load the pre-trained weights for all variables with this # prefix. endpoints, body_prefix = model.endpoints(images, is_training=True) with tf.name_scope('head'): endpoints = head.head(endpoints, args.embedding_dim, is_training=True) # Create the loss in two steps: # 1. 
Compute all pairwise distances according to the specified metric. # 2. For each anchor along the first dimension, compute its loss. dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric) losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[ args.loss](dists, pids, args.margin, batch_precision_at_k=args.batch_k - 1) # Count the number of active entries, and compute the total batch loss. num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32)) loss_mean = tf.reduce_mean(losses) # Some logging for tensorboard. tf.summary.histogram('loss_distribution', losses) tf.summary.scalar('loss', loss_mean) tf.summary.scalar('batch_top1', train_top1) tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k) tf.summary.scalar('active_count', num_active) tf.summary.histogram('embedding_dists', dists) tf.summary.histogram('embedding_pos_dists', pos_dists) tf.summary.histogram('embedding_neg_dists', neg_dists) tf.summary.histogram('embedding_lengths', tf.norm(endpoints['emb_raw'], axis=1)) # Create the mem-mapped arrays in which we'll log all training detail in # addition to tensorboard, because tensorboard is annoying for detailed # inspection and actually discards data in histogram summaries. if args.detailed_logs: log_embs = lb.create_or_resize_dat( os.path.join(args.experiment_root, 'embeddings'), dtype=np.float32, shape=(args.train_iterations, batch_size, args.embedding_dim)) log_loss = lb.create_or_resize_dat( os.path.join(args.experiment_root, 'losses'), dtype=np.float32, shape=(args.train_iterations, batch_size)) log_fids = lb.create_or_resize_dat( os.path.join(args.experiment_root, 'fids'), dtype='S' + str(max_fid_len), shape=(args.train_iterations, batch_size)) # These are collected here before we add the optimizer, because depending # on the optimizer, it might add extra slots, which are also global # variables, with the exact same prefix. model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, body_prefix) # Define the optimizer and the learning-rate schedule. # Unfortunately, we get NaNs if we don't handle no-decay separately. global_step = tf.Variable(0, name='global_step', trainable=False) if 0 <= args.decay_start_iteration < args.train_iterations: learning_rate = tf.train.exponential_decay( args.learning_rate, tf.maximum(0, global_step - args.decay_start_iteration), args.train_iterations - args.decay_start_iteration, 0.001) else: learning_rate = args.learning_rate tf.summary.scalar('learning_rate', learning_rate) optimizer = tf.train.AdamOptimizer(learning_rate) # Feel free to try others! # optimizer = tf.train.AdadeltaOptimizer(learning_rate) # Update_ops are used to update batchnorm stats. with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): train_op = optimizer.minimize(loss_mean, global_step=global_step) # Define a saver for the complete model. checkpoint_saver = tf.train.Saver(max_to_keep=0) with tf.Session() as sess: if args.resume: # In case we're resuming, simply load the full checkpoint to init. last_checkpoint = tf.train.latest_checkpoint(args.experiment_root) log.info('Restoring from checkpoint: {}'.format(last_checkpoint)) checkpoint_saver.restore(sess, last_checkpoint) else: # But if we're starting from scratch, we may need to load some # variables from the pre-trained weights, and random init others. 
sess.run(tf.global_variables_initializer()) if args.initial_checkpoint is not None: saver = tf.train.Saver(model_variables) saver.restore(sess, args.initial_checkpoint) # In any case, we also store this initialization as a checkpoint, # such that we could run exactly reproduceable experiments. checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=0) merged_summary = tf.summary.merge_all() summary_writer = tf.summary.FileWriter(args.experiment_root, sess.graph) start_step = sess.run(global_step) log.info('Starting training from iteration {}.'.format(start_step)) # Finally, here comes the main-loop. This `Uninterrupt` is a handy # utility such that an iteration still finishes on Ctrl+C and we can # stop the training cleanly. with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u: for i in range(start_step, args.train_iterations): # Compute gradients, update weights, store logs! start_time = time.time() _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \ sess.run([train_op, merged_summary, global_step, prec_at_k, endpoints['emb'], losses, fids]) elapsed_time = time.time() - start_time # Compute the iteration speed and add it to the summary. # We did observe some weird spikes that we couldn't track down. summary2 = tf.Summary() summary2.value.add(tag='secs_per_iter', simple_value=elapsed_time) summary_writer.add_summary(summary2, step) summary_writer.add_summary(summary, step) if args.detailed_logs: log_embs[i], log_loss[i], log_fids[ i] = b_embs, b_loss, b_fids # Do a huge print out of the current progress. seconds_todo = (args.train_iterations - step) * elapsed_time log.info( 'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ' 'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format( step, float(np.min(b_loss)), float(np.mean(b_loss)), float(np.max(b_loss)), args.batch_k - 1, float(b_prec_at_k), timedelta(seconds=int(seconds_todo)), elapsed_time)) sys.stdout.flush() sys.stderr.flush() # Save a checkpoint of training every so often. if (args.checkpoint_frequency > 0 and step % args.checkpoint_frequency == 0): checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=step) # Stop the main-loop at the end of the step, if requested. if u.interrupted: log.info("Interrupted on request!") break # Store one final checkpoint. This might be redundant, but it is crucial # in case intermediate storing was disabled and it saves a checkpoint # when the process was interrupted. checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=step)
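
###############################################################################
# For reference: what the PID-level tf.data pipeline above boils down to, as a
# plain NumPy sketch. `sample_pk_batch` is a hypothetical helper (the real code
# uses `sample_k_fids_for_pid` inside the tf.data graph, and the variant above
# can additionally draw PIDs from a precomputed hard-identity pool when
# `--hard_pool_size > 0`, which this sketch ignores): draw P identities, then K
# file names per identity, with replacement if an identity has fewer than K
# images, yielding one P*K batch.
import numpy as np


def sample_pk_batch(all_pids, all_fids, batch_p, batch_k, rng=np.random):
    all_pids = np.asarray(all_pids)
    all_fids = np.asarray(all_fids)
    chosen_pids = rng.choice(np.unique(all_pids), batch_p, replace=False)
    batch_fids, batch_pids = [], []
    for pid in chosen_pids:
        fids_of_pid = all_fids[all_pids == pid]
        replace = len(fids_of_pid) < batch_k
        batch_fids.extend(rng.choice(fids_of_pid, batch_k, replace=replace))
        batch_pids.extend([pid] * batch_k)
    return batch_fids, batch_pids
###############################################################################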
def main(): args = parser.parse_args() # We store all arguments in a json file. This has two advantages: # 1. We can always get back and see what exactly that experiment was # 2. We can resume an experiment as-is without needing to remember all flags. args_file = os.path.join(args.experiment_root, 'args.json') if args.resume: if not os.path.isfile(args_file): raise IOError('`args.json` not found in {}'.format(args_file)) print('Loading args from {}.'.format(args_file)) with open(args_file, 'r') as f: args_resumed = json.load(f) args_resumed['resume'] = True # This would be overwritten. # When resuming, we not only want to populate the args object with the # values from the file, but we also want to check for some possible # conflicts between loaded and given arguments. for key, value in args.__dict__.items(): if key in args_resumed: resumed_value = args_resumed[key] if resumed_value != value: print('Warning: For the argument `{}` we are using the' ' loaded value `{}`. The provided value was `{}`' '.'.format(key, resumed_value, value)) args.__dict__[key] = resumed_value else: print('Warning: A new argument was added since the last run:' ' `{}`. Using the new value: `{}`.'.format(key, value)) else: # If the experiment directory exists already, we bail in fear. if os.path.exists(args.experiment_root): if os.listdir(args.experiment_root): print('The directory {} already exists and is not empty.' ' If you want to resume training, append --resume to' ' your call.'.format(args.experiment_root)) exit(1) else: os.makedirs(args.experiment_root) # Store the passed arguments for later resuming and grepping in a nice # and readable format. with open(args_file, 'w') as f: json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True) log_file = os.path.join(args.experiment_root, "train") logging.config.dictConfig(common.get_logging_dict(log_file)) log = logging.getLogger('train') # Also show all parameter values at the start, for ease of reading logs. log.info('Training using the following parameters:') for key, value in sorted(vars(args).items()): log.info('{}: {}'.format(key, value)) # Check them here, so they are not required when --resume-ing. if not args.train_set: parser.print_help() log.error("You did not specify the `train_set` argument!") sys.exit(1) if not args.image_root: parser.print_help() log.error("You did not specify the required `image_root` argument!") sys.exit(1) # Load the data from the CSV file. pids, fids = common.load_dataset(args.train_set, args.image_root) max_fid_len = max(map(len, fids)) # We'll need this later for logfiles. # Setup a tf.Dataset where one "epoch" loops over all PIDS. # PIDS are shuffled after every epoch and continue indefinitely. unique_pids = np.unique(pids) dataset = tf.data.Dataset.from_tensor_slices(unique_pids) dataset = dataset.shuffle(len(unique_pids)) # Constrain the dataset size to a multiple of the batch-size, so that # we don't get overlap at the end of each epoch. dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p) dataset = dataset.repeat(None) # Repeat forever. Funny way of stating it. # For every PID, get K images. dataset = dataset.map(lambda pid: sample_k_fids_for_pid( pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k)) # Ungroup/flatten the batches for easy loading of the files. dataset = dataset.apply(tf.contrib.data.unbatch()) # Convert filenames to actual image tensors. 
net_input_size = (args.net_input_height, args.net_input_width) pre_crop_size = (args.pre_crop_height, args.pre_crop_width) dataset = dataset.map(lambda fid, pid: common.fid_to_image( fid, pid, image_root=args.image_root, image_size=pre_crop_size if args.crop_augment else net_input_size), num_parallel_calls=args.loading_threads) # Augment the data if specified by the arguments. if args.flip_augment: dataset = dataset.map(lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid)) if args.crop_augment: dataset = dataset.map(lambda im, fid, pid: (tf.random_crop( im, net_input_size + (3, )), fid, pid)) # Group it back into PK batches. batch_size = args.batch_p * args.batch_k dataset = dataset.batch(batch_size) # Overlap producing and consuming for parallelism. dataset = dataset.prefetch(1) # Since we repeat the data infinitely, we only need a one-shot iterator. images, fids, pids = dataset.make_one_shot_iterator().get_next() # Create the model and an embedding head. model = import_module('nets.' + args.model_name) head = import_module('heads.' + args.head_name) # Feed the image through the model. The returned `body_prefix` will be used # further down to load the pre-trained weights for all variables with this # prefix. endpoints, body_prefix = model.endpoints(images, is_training=True) with tf.name_scope('head'): endpoints = head.head(endpoints, args.embedding_dim, is_training=True) # Create the loss in two steps: # 1. Compute all pairwise distances according to the specified metric. # 2. For each anchor along the first dimension, compute its loss. dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric) losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[ args.loss](dists, pids, args.margin, batch_precision_at_k=args.batch_k - 1) decDense = tf.layers.dense( inputs=endpoints['emb'], units=5120, name='decDense') # ,activation = tf.nn.relu ################ unflat = tf.reshape(decDense, shape=[tf.shape(decDense)[0], 32, 16, 10]) unp3shape = tf.TensorShape( [2 * di for di in unflat.get_shape().as_list()[1:-1]]) unPool3 = tf.image.resize_nearest_neighbor(unflat, unp3shape, name='unpool3') deConv3 = tf.layers.conv2d(inputs=unPool3, filters=64, kernel_size=[5, 5], strides=(1, 1), padding='same', activation=tf.nn.relu, name='deConv3') unp2shape = tf.TensorShape( [2 * di for di in deConv3.get_shape().as_list()[1:-1]]) unPool2 = tf.image.resize_nearest_neighbor(deConv3, unp2shape, name='unpool2') deConv2 = tf.layers.conv2d(inputs=unPool2, filters=32, kernel_size=[5, 5], strides=(1, 1), padding='same', activation=tf.nn.relu, name='deConv2') unp1shape = tf.TensorShape( [2 * di for di in deConv2.get_shape().as_list()[1:-1]]) unPool1 = tf.image.resize_nearest_neighbor(deConv2, unp1shape, name='unpool1') deConv1 = tf.layers.conv2d(inputs=unPool1, filters=3, kernel_size=[5, 5], strides=(1, 1), padding='same', activation=None, name='deConv1') imClip = deConv1 #tf.clip_by_value(t = deConv1,clip_value_min = -1.0,clip_value_max = 1.0,name='clipRelu') print('RconstructeddImage : ', imClip.name) recLoss = tf.multiply( 0.01, tf.losses.mean_squared_error( labels=images, predictions=imClip, )) print('recLoss : ', recLoss.name) decDense1 = tf.layers.dense( inputs=endpoints['emb'], units=5120, name='decDense1') # ,activation = tf.nn.relu ################ unflat1 = tf.reshape(decDense1, shape=[tf.shape(decDense1)[0], 32, 16, 10]) unp3shape1 = tf.TensorShape( [2 * di for di in unflat1.get_shape().as_list()[1:-1]]) unPool3_new = tf.image.resize_nearest_neighbor(unflat1, 
unp3shape1, name='unpool3_new') deConv3_new = tf.layers.conv2d(inputs=unPool3_new, filters=64, kernel_size=[5, 5], strides=(1, 1), padding='same', activation=tf.nn.relu, name='deConv3_new') unp2shape_new = tf.TensorShape( [2 * di for di in deConv3_new.get_shape().as_list()[1:-1]]) unPool2_new = tf.image.resize_nearest_neighbor(deConv3_new, unp2shape_new, name='unpool2_new') deConv2_new = tf.layers.conv2d(inputs=unPool2_new, filters=3, kernel_size=[5, 5], strides=(1, 1), padding='same', activation=tf.nn.relu, name='deConv2_new') unp1shape_new = tf.TensorShape( [2 * di for di in deConv2_new.get_shape().as_list()[1:-1]]) unPool1_new = tf.image.resize_nearest_neighbor(deConv2_new, unp1shape_new, name='unpool1_new') deConv1_new = tf.layers.conv2d(inputs=unPool1_new, filters=3, kernel_size=[5, 5], strides=(1, 1), padding='same', activation=None, name='deConv1_new') imClip1 = deConv2_new print('RconstructeddImage : ', imClip1.name) print(imClip1.shape) images2 = tf.image.resize_images(images, [128, 64]) print(images2.shape) recLoss1 = tf.multiply( 0.01, tf.losses.mean_squared_error( labels=images2, predictions=imClip1, )) print('recLoss_new : ', recLoss1.name) decDense2 = tf.layers.dense( inputs=endpoints['emb'], units=5120, name='decDense2') # ,activation = tf.nn.relu ################ unflat12 = tf.reshape(decDense2, shape=[tf.shape(decDense2)[0], 32, 16, 10]) unp3shape12 = tf.TensorShape( [2 * di for di in unflat12.get_shape().as_list()[1:-1]]) unPool3_new2 = tf.image.resize_nearest_neighbor(unflat12, unp3shape12, name='unpool3_new2') deConv3_new2 = tf.layers.conv2d(inputs=unPool3_new2, filters=3, kernel_size=[5, 5], strides=(1, 1), padding='same', activation=tf.nn.relu, name='deConv3_new2') unp2shape_new2 = tf.TensorShape( [2 * di for di in deConv3_new2.get_shape().as_list()[1:-1]]) unPool2_new2 = tf.image.resize_nearest_neighbor(deConv3_new2, unp2shape_new2, name='unpool2_new2') imClip11 = deConv3_new2 images21 = tf.image.resize_images(images, [64, 32]) recLoss2 = tf.multiply( 0.01, tf.losses.mean_squared_error( labels=images21, predictions=imClip11, )) print('recLoss_new : ', recLoss2.name) decDensel = tf.layers.dense( inputs=endpoints['emb'], units=5120, name='decDensel') # ,activation = tf.nn.relu ################ unflatl = tf.reshape(decDensel, shape=[tf.shape(decDensel)[0], 32, 16, 10]) unp3shapel = tf.TensorShape( [2 * di for di in unflatl.get_shape().as_list()[1:-1]]) unPool3l = tf.image.resize_nearest_neighbor(unflatl, unp3shapel, name='unpool3l') deConv3l = tf.layers.conv2d(inputs=unPool3l, filters=64, kernel_size=[5, 5], strides=(1, 1), padding='same', activation=tf.nn.relu, name='deConv3l') unp2shapel = tf.TensorShape( [2 * di for di in deConv3l.get_shape().as_list()[1:-1]]) unPool2l = tf.image.resize_nearest_neighbor(deConv3l, unp2shapel, name='unpool2l') deConv2l = tf.layers.conv2d(inputs=unPool2l, filters=32, kernel_size=[5, 5], strides=(1, 1), padding='same', activation=tf.nn.relu, name='deConv2l') unp1shapel = tf.TensorShape( [2 * di for di in deConv2l.get_shape().as_list()[1:-1]]) unPool1l = tf.image.resize_nearest_neighbor(deConv2l, unp1shapel, name='unpool1l') deConv1l = tf.layers.conv2d(inputs=unPool1l, filters=3, kernel_size=[5, 5], strides=(1, 1), padding='same', activation=None, name='deConv1l') imClipl = deConv1l #tf.clip_by_value(t = deConv1,clip_value_min = -1.0,clip_value_max = 1.0,name='clipRelu') print('RconstructeddImage : ', imClipl.name) recLossl = tf.multiply( 0.01, tf.losses.mean_squared_error( labels=images, predictions=imClipl, )) print('recLoss : ', 
recLossl.name) # Count the number of active entries, and compute the total batch loss. num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32)) loss_mean = tf.reduce_mean(losses) # Some logging for tensorboard. tf.summary.histogram('loss_distribution', losses) tf.summary.scalar('loss', loss_mean) tf.summary.scalar('batch_top1', train_top1) tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k) tf.summary.scalar('active_count', num_active) tf.summary.histogram('embedding_dists', dists) tf.summary.histogram('embedding_pos_dists', pos_dists) tf.summary.histogram('embedding_neg_dists', neg_dists) tf.summary.histogram('embedding_lengths', tf.norm(endpoints['emb_raw'], axis=1)) # Create the mem-mapped arrays in which we'll log all training detail in # addition to tensorboard, because tensorboard is annoying for detailed # inspection and actually discards data in histogram summaries. if args.detailed_logs: log_embs = lb.create_or_resize_dat( os.path.join(args.experiment_root, 'embeddings'), dtype=np.float32, shape=(args.train_iterations, batch_size, args.embedding_dim)) log_loss = lb.create_or_resize_dat( os.path.join(args.experiment_root, 'losses'), dtype=np.float32, shape=(args.train_iterations, batch_size)) log_fids = lb.create_or_resize_dat( os.path.join(args.experiment_root, 'fids'), dtype='S' + str(max_fid_len), shape=(args.train_iterations, batch_size)) # These are collected here before we add the optimizer, because depending # on the optimizer, it might add extra slots, which are also global # variables, with the exact same prefix. model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, body_prefix) # Define the optimizer and the learning-rate schedule. # Unfortunately, we get NaNs if we don't handle no-decay separately. global_step = tf.Variable(0, name='global_step', trainable=False) if 0 <= args.decay_start_iteration < args.train_iterations: learning_rate = tf.train.exponential_decay( args.learning_rate, tf.maximum(0, global_step - args.decay_start_iteration), args.train_iterations - args.decay_start_iteration, 0.001) else: learning_rate = args.learning_rate tf.summary.scalar('learning_rate', learning_rate) optimizer = tf.train.AdamOptimizer(learning_rate) # Feel free to try others! # optimizer = tf.train.AdadeltaOptimizer(learning_rate) # Update_ops are used to update batchnorm stats. with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): train_op = optimizer.minimize(tf.add( loss_mean, tf.add(recLoss, tf.add(recLoss1, tf.add(recLoss2, recLossl)))), global_step=global_step) # Define a saver for the complete model. checkpoint_saver = tf.train.Saver(max_to_keep=0) with tf.Session() as sess: if args.resume: # In case we're resuming, simply load the full checkpoint to init. last_checkpoint = tf.train.latest_checkpoint(args.experiment_root) log.info('Restoring from checkpoint: {}'.format(last_checkpoint)) checkpoint_saver.restore(sess, last_checkpoint) else: # But if we're starting from scratch, we may need to load some # variables from the pre-trained weights, and random init others. sess.run(tf.global_variables_initializer()) if args.initial_checkpoint is not None: saver = tf.train.Saver(model_variables) saver.restore(sess, args.initial_checkpoint) # In any case, we also store this initialization as a checkpoint, # such that we could run exactly reproduceable experiments. 
checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=0) merged_summary = tf.summary.merge_all() summary_writer = tf.summary.FileWriter(args.experiment_root, sess.graph) start_step = sess.run(global_step) log.info('Starting training from iteration {}.'.format(start_step)) # Finally, here comes the main-loop. This `Uninterrupt` is a handy # utility such that an iteration still finishes on Ctrl+C and we can # stop the training cleanly. with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u: for i in range(start_step, args.train_iterations): # Compute gradients, update weights, store logs! start_time = time.time() _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids ,b_rec, b_rec1= \ sess.run([train_op, merged_summary, global_step, prec_at_k, endpoints['emb'], losses, fids,recLoss, recLoss1]) elapsed_time = time.time() - start_time # Compute the iteration speed and add it to the summary. # We did observe some weird spikes that we couldn't track down. summary2 = tf.Summary() summary2.value.add(tag='secs_per_iter', simple_value=elapsed_time) summary_writer.add_summary(summary2, step) summary_writer.add_summary(summary, step) if args.detailed_logs: log_embs[i], log_loss[i], log_fids[ i] = b_embs, b_loss, b_fids # Do a huge print out of the current progress. seconds_todo = (args.train_iterations - step) * elapsed_time log.info( 'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ' 'recLoss: {:.3f} batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'. format(step, float(np.min(b_loss)), float(np.mean(b_loss)), float(np.max(b_loss)), b_rec, args.batch_k - 1, float(b_prec_at_k), timedelta(seconds=int(seconds_todo)), elapsed_time)) sys.stdout.flush() sys.stderr.flush() # Save a checkpoint of training every so often. if (args.checkpoint_frequency > 0 and step % args.checkpoint_frequency == 0): checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=step) # Stop the main-loop at the end of the step, if requested. if u.interrupted: log.info("Interrupted on request!") break # Store one final checkpoint. This might be redundant, but it is crucial # in case intermediate storing was disabled and it saves a checkpoint # when the process was interrupted. checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=step)
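
###############################################################################
# For reference: the decoder blocks above repeatedly "unpool" with
# `tf.image.resize_nearest_neighbor` to twice the spatial size before each
# convolution. A minimal NumPy sketch of that upsampling step (assuming NHWC
# tensors), independent of the TensorFlow graph code:
import numpy as np


def nn_upsample_2x(x):
    """x: (N, H, W, C) -> (N, 2H, 2W, C) by repeating each pixel 2x2."""
    return x.repeat(2, axis=1).repeat(2, axis=2)


# Example: a (1, 32, 16, 10) feature map becomes (1, 64, 32, 10), matching the
# `unflat` -> `unpool3` step in the first decoder.
###############################################################################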
def main(argv): args = parser.parse_args(argv) if args.gpu: os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu # We store all arguments in a json file. This has two advantages: # 1. We can always get back and see what exactly that experiment was # 2. We can resume an experiment as-is without needing to remember all flags. args_file = os.path.join(args.experiment_root, 'args.json') if args.resume: if not os.path.isfile(args_file): raise IOError('`args.json` not found in {}'.format(args_file)) print('Loading args from {}.'.format(args_file)) with open(args_file, 'r') as f: args_resumed = json.load(f) args_resumed['resume'] = True # This would be overwritten. # When resuming, we not only want to populate the args object with the # values from the file, but we also want to check for some possible # conflicts between loaded and given arguments. for key, value in args.__dict__.items(): if key in args_resumed: resumed_value = args_resumed[key] if resumed_value != value: print('Warning: For the argument `{}` we are using the' ' loaded value `{}`. The provided value was `{}`' '.'.format(key, resumed_value, value)) args.__dict__[key] = resumed_value else: print('Warning: A new argument was added since the last run:' ' `{}`. Using the new value: `{}`.'.format(key, value)) else: # If the experiment directory exists already, we bail in fear. if os.path.exists(args.experiment_root): if os.listdir(args.experiment_root): print('The directory {} already exists and is not empty.' ' If you want to resume training, append --resume to' ' your call.'.format(args.experiment_root)) exit(1) else: os.makedirs(args.experiment_root) # Store the passed arguments for later resuming and grepping in a nice # and readable format. with open(args_file, 'w') as f: json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True) log_file = os.path.join(args.experiment_root, "train") logging.config.dictConfig(common.get_logging_dict(log_file)) log = logging.getLogger('train') # Also show all parameter values at the start, for ease of reading logs. log.info('Training using the following parameters:') for key, value in sorted(vars(args).items()): log.info('{}: {}'.format(key, value)) # Check them here, so they are not required when --resume-ing. if not args.train_set: parser.print_help() log.error("You did not specify the `train_set` argument!") sys.exit(1) if not args.image_root: parser.print_help() log.error("You did not specify the required `image_root` argument!") sys.exit(1) # Load the data from the CSV file. pids, fids = common.load_dataset(args.train_set, args.image_root) max_fid_len = max(map(len, fids)) # We'll need this later for logfiles. # Setup a tf.Dataset where one "epoch" loops over all PIDS. # PIDS are shuffled after every epoch and continue indefinitely. unique_pids = np.unique(pids) if len(unique_pids) < args.batch_p: unique_pids = np.tile(unique_pids, int(np.ceil(args.batch_p / len(unique_pids)))) dataset = tf.data.Dataset.from_tensor_slices(unique_pids) dataset = dataset.shuffle(len(unique_pids)) # Constrain the dataset size to a multiple of the batch-size, so that # we don't get overlap at the end of each epoch. dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p) dataset = dataset.repeat(None) # Repeat forever. Funny way of stating it. # For every PID, get K images. dataset = dataset.map(lambda pid: sample_k_fids_for_pid( pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k)) # Ungroup/flatten the batches for easy loading of the files. 
dataset = dataset.apply(tf.contrib.data.unbatch()) # Convert filenames to actual image tensors. net_input_size = (args.net_input_height, args.net_input_width) pre_crop_size = (args.pre_crop_height, args.pre_crop_width) dataset = dataset.map(lambda fid, pid: common.fid_to_image( fid, pid, image_root=args.image_root, image_size=pre_crop_size if args.crop_augment else net_input_size), num_parallel_calls=args.loading_threads) # Augment the data if specified by the arguments. dataset = dataset.map( lambda im, fid, pid: common.fid_to_image(fid, pid, image_root=args.image_root, image_size=pre_crop_size if args.crop_augment else net_input_size), # Ergys num_parallel_calls=args.loading_threads) if args.flip_augment: dataset = dataset.map(lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid)) if args.crop_augment: dataset = dataset.map(lambda im, fid, pid: (tf.random_crop( im, net_input_size + (3, )), fid, pid)) # Group it back into PK batches. batch_size = args.batch_p * args.batch_k dataset = dataset.batch(batch_size) # Overlap producing and consuming for parallelism. dataset = dataset.prefetch(1) # Since we repeat the data infinitely, we only need a one-shot iterator. images, fids, pids = dataset.make_one_shot_iterator().get_next() # Create the model and an embedding head. model = import_module('nets.' + args.model_name) head = import_module('heads.' + args.head_name) # Feed the image through the model. The returned `body_prefix` will be used # further down to load the pre-trained weights for all variables with this # prefix. weight_decay = 10e-4 weights_regularizer = tf.contrib.layers.l2_regularizer(scale=weight_decay) endpoints, body_prefix = model.endpoints(images, is_training=True) with tf.name_scope('head'): endpoints = head.head(endpoints, args.embedding_dim, is_training=True, weights_regularizer=weights_regularizer) # Create the loss in two steps: # 1. Compute all pairwise distances according to the specified metric. # 2. For each anchor along the first dimension, compute its loss. 
# batch_embedding = endpoints['emb'] batch_embedding = endpoints['emb'] if args.loss == 'semi_hard_triplet': triplet_loss = triplet_semihard_loss(batch_embedding, pids, args.margin) elif args.loss == 'hard_triplet': triplet_loss = batch_hard(batch_embedding, pids, args.margin, args.metric) elif args.loss == 'lifted_loss': triplet_loss = lifted_loss(pids, batch_embedding, margin=args.margin) elif args.loss == 'contrastive_loss': assert batch_size % 2 == 0 assert args.batch_k == 4 ## Can work with other number but will need tuning contrastive_idx = np.tile([0, 1, 4, 3, 2, 5, 6, 7], args.batch_p // 2) for i in range(args.batch_p // 2): contrastive_idx[i * 8:i * 8 + 8] += i * 8 contrastive_idx = np.expand_dims(contrastive_idx, 1) batch_embedding_ordered = tf.gather_nd(batch_embedding, contrastive_idx) pids_ordered = tf.gather_nd(pids, contrastive_idx) # batch_embedding_ordered = tf.Print(batch_embedding_ordered,[pids_ordered],'pids_ordered :: ',summarize=1000) embeddings_anchor, embeddings_positive = tf.unstack( tf.reshape(batch_embedding_ordered, [-1, 2, args.embedding_dim]), 2, 1) # embeddings_anchor = tf.Print(embeddings_anchor,[pids_ordered,embeddings_anchor,embeddings_positive,batch_embedding,batch_embedding_ordered],"Tensors ", summarize=1000) fixed_labels = np.tile([1, 0, 0, 1], args.batch_p // 2) # fixed_labels = np.reshape(fixed_labels,(len(fixed_labels),1)) # print(fixed_labels) labels = tf.constant(fixed_labels) # labels = tf.Print(labels,[labels],'labels ',summarize=1000) triplet_loss = contrastive_loss(labels, embeddings_anchor, embeddings_positive, margin=args.margin) elif args.loss == 'angular_loss': embeddings_anchor, embeddings_positive = tf.unstack( tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]), 2, 1) # pids = tf.Print(pids, [pids], 'pids:: ', summarize=100) pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1) # pids = tf.Print(pids,[pids],'pids:: ',summarize=100) triplet_loss = angular_loss(pids, embeddings_anchor, embeddings_positive, batch_size=args.batch_p, with_l2reg=True) elif args.loss == 'npairs_loss': assert args.batch_k == 2 ## Single positive pair per class embeddings_anchor, embeddings_positive = tf.unstack( tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]), 2, 1) pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1) pids = tf.reshape(pids, [-1]) triplet_loss = npairs_loss(pids, embeddings_anchor, embeddings_positive) else: raise NotImplementedError('loss function {} NotImplemented'.format( args.loss)) loss_mean = tf.reduce_mean(triplet_loss) # These are collected here before we add the optimizer, because depending # on the optimizer, it might add extra slots, which are also global # variables, with the exact same prefix. model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, body_prefix) # Define the optimizer and the learning-rate schedule. # Unfortunately, we get NaNs if we don't handle no-decay separately. 
global_step = tf.Variable(0, name='global_step', trainable=False) if 0 <= args.decay_start_iteration < args.train_iterations: learning_rate = tf.train.polynomial_decay(args.learning_rate, global_step, args.train_iterations, end_learning_rate=1e-7, power=1) else: learning_rate = args.learning_rate if args.optimizer == 'adam': optimizer = tf.train.AdamOptimizer(learning_rate) elif args.optimizer == 'momentum': optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9) else: raise NotImplementedError('Invalid optimizer {}'.format( args.optimizer)) # # learning_rate = tf.train.polynomial_decay(args.learning_rate, global_step, # args.train_iterations, end_learning_rate= 1e-7, # power=1) # # Feel free to try others! # optimizer = tf.train.AdadeltaOptimizer(learning_rate) # Update_ops are used to update batchnorm stats. with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): train_op = optimizer.minimize(loss_mean, global_step=global_step) # Define a saver for the complete model. checkpoint_saver = tf.train.Saver(max_to_keep=2) gpu_options = tf.GPUOptions(allow_growth=True) gpu_config = tf.ConfigProto(gpu_options=gpu_options) with tf.Session(config=gpu_config) as sess: if args.resume: # In case we're resuming, simply load the full checkpoint to init. last_checkpoint = tf.train.latest_checkpoint(args.experiment_root) if last_checkpoint == None: print('Resume with No previous checkpoint') # But if we're starting from scratch, we may need to load some # variables from the pre-trained weights, and random init others. sess.run(tf.global_variables_initializer()) if args.initial_checkpoint is not None: saver = tf.train.Saver(model_variables) saver.restore(sess, args.initial_checkpoint) # In any case, we also store this initialization as a checkpoint, # such that we could run exactly reproduceable experiments. checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=0) else: log.info( 'Restoring from checkpoint: {}'.format(last_checkpoint)) checkpoint_saver.restore(sess, last_checkpoint) else: # But if we're starting from scratch, we may need to load some # variables from the pre-trained weights, and random init others. sess.run(tf.global_variables_initializer()) if args.initial_checkpoint is not None: saver = tf.train.Saver(model_variables) saver.restore(sess, args.initial_checkpoint) # In any case, we also store this initialization as a checkpoint, # such that we could run exactly reproduceable experiments. checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=0) start_step = sess.run(global_step) log.info('Starting training from iteration {}.'.format(start_step)) # Finally, here comes the main-loop. This `Uninterrupt` is a handy # utility such that an iteration still finishes on Ctrl+C and we can # stop the training cleanly. with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u: for i in range(start_step, args.train_iterations): # Compute gradients, update weights, store logs! start_time = time.time() _, step, b_embs, b_loss, b_fids = \ sess.run([train_op, global_step, endpoints['emb'], triplet_loss, fids]) elapsed_time = time.time() - start_time # Do a huge print out of the current progress. 
seconds_todo = (args.train_iterations - step) * elapsed_time log.info( 'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ETA: {} ({:.2f}s/it)' .format( step, float(np.min(b_loss)), float(np.mean(b_loss)), float(np.max(b_loss)), # args.batch_k - 1, float(b_prec_at_k), timedelta(seconds=int(seconds_todo)), elapsed_time)) sys.stdout.flush() sys.stderr.flush() # Save a checkpoint of training every so often. if (args.checkpoint_frequency > 0 and step % args.checkpoint_frequency == 0): checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=step) # Stop the main-loop at the end of the step, if requested. if u.interrupted: log.info("Interrupted on request!") break # Store one final checkpoint. This might be redundant, but it is crucial # in case intermediate storing was disabled and it saves a checkpoint # when the process was interrupted. checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=step)
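# A standalone NumPy sketch (not part of the script above) of how its `contrastive_loss` branch
# reorders a PK batch with batch_k = 4: each group of 8 embeddings (two identities, 4 images each)
# is permuted with [0, 1, 4, 3, 2, 5, 6, 7] so that consecutive pairs alternate between
# same-identity pairs (label 1) and cross-identity pairs (label 0), matching fixed_labels = [1, 0, 0, 1].
import numpy as np

batch_p, batch_k = 4, 4                          # 4 identities, 4 images each (illustrative values)
pids = np.repeat(np.arange(batch_p), batch_k)    # [0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3]

idx = np.tile([0, 1, 4, 3, 2, 5, 6, 7], batch_p // 2)
for i in range(batch_p // 2):
    idx[i * 8:i * 8 + 8] += i * 8

pairs = pids[idx].reshape(-1, 2)                 # identity of anchor / partner in each pair
labels = np.tile([1, 0, 0, 1], batch_p // 2)     # 1 = genuine pair, 0 = impostor pair

for (a, b), l in zip(pairs, labels):
    assert (a == b) == bool(l)                   # same identity exactly when the label is 1
print(pairs.tolist(), labels.tolist())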
tf.global_variables_initializer().run() if resume: last_checkpoint = tf.train.latest_checkpoint(tr.save_path) saver.restore(sess, last_checkpoint) start_step = sess.run(global_step) logger.debug('Resume training ... Start from step %d / %d .' % (start_step, train_nums)) resume = False else: start_step = 0 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u: for i in range(start_step, train_nums): _, loss_value, step, ttt = sess.run( [train_op, loss, global_step, tt]) if i % print_steps == 0: top1, top5, top10 = sess.run([ accuracy_top1_batch, accuracy_top5_batch, accuracy_top10_batch ]) logger.debug( "After %d training step(s),loss on training batch is %g.The batch test accuracy = %g , %g ,%g." % (i, loss_value, top1, top5, top10)) ''' losslist.append([step,loss_value]) accuracy.append([step,top1])
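# The snippet above logs `accuracy_top1_batch` / `accuracy_top5_batch` / `accuracy_top10_batch`
# without showing how they are built. As a rough, framework-free sketch of what batch top-k
# accuracy measures (illustrative only, not the script's actual TF ops):
import numpy as np

def topk_accuracy(logits, labels, k):
    """Fraction of rows whose true label is among the k highest-scoring classes."""
    topk = np.argsort(-logits, axis=1)[:, :k]
    return float(np.mean([labels[i] in topk[i] for i in range(len(labels))]))

logits = np.random.randn(32, 100)                # 32 samples, 100 classes (dummy data)
labels = np.random.randint(0, 100, size=32)
print(topk_accuracy(logits, labels, 1), topk_accuracy(logits, labels, 5), topk_accuracy(logits, labels, 10))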
def main(args): best_acc = -1 logger = bit_common.setup_logger(args) cp, cn = smooth_BCE(eps=0.1) # Lets cuDNN benchmark conv implementations and choose the fastest. # Only good if sizes stay the same within the main loop! torch.backends.cudnn.benchmark = True device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Going to train on {device}") classes = 5 train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger) logger.info(f"Loading model from {args.model}.npz") #model = models.KNOWN_MODELS[args.model](head_size=classes, zero_head=True) #model.load_from(np.load(f"{args.model}.npz")) model = EfficientNet.from_pretrained(args.model, num_classes=classes) logger.info("Moving model onto all GPUs") model = torch.nn.DataParallel(model) # Optionally resume from a checkpoint. # Load it to CPU first as we'll move the model to GPU later. # This way, we save a little bit of GPU memory when loading. start_epoch = 0 # Note: no weight-decay! optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9) # Resume fine-tuning if we find a saved model. savename = pjoin(args.logdir, args.name, "bit.pth.tar") try: logger.info(f"Model will be saved in '{savename}'") checkpoint = torch.load(savename, map_location="cpu") logger.info(f"Found saved model to resume from at '{savename}'") start_epoch = checkpoint["epoch"] model.load_state_dict(checkpoint["model"]) optim.load_state_dict(checkpoint["optim"]) logger.info(f"Resumed at epoch {start_epoch}") except FileNotFoundError: logger.info("Fine-tuning from BiT") model = model.to(device) optim.zero_grad() model.train() mixup = bit_hyperrule.get_mixup(len(train_set)) #mixup = -1 cri = torch.nn.CrossEntropyLoss().to(device) #cri = FocalLoss(cri) logger.info("Starting training!") chrono = lb.Chrono() accum_steps = 0 mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1 end = time.time() epoches = 10 scheduler = torch.optim.lr_scheduler.OneCycleLR(optim, max_lr=0.01, steps_per_epoch=1, epochs=epoches) with lb.Uninterrupt() as u: for epoch in range(start_epoch, epoches): pbar = enumerate(train_loader) pbar = tqdm.tqdm(pbar, total=len(train_loader)) scheduler.step() all_top1, all_top5 = [], [] for param_group in optim.param_groups: lr = param_group["lr"] #for x, y in recycle(train_loader): for batch_id, (x, y) in pbar: #for batch_id, (x, y) in enumerate(train_loader): # measure data loading time, which is spent in the `for` statement. chrono._done("load", time.time() - end) if u.interrupted: break # Schedule sending to GPU(s) x = x.to(device, non_blocking=True) y = y.to(device, non_blocking=True) # Update learning-rate, including stop training if over. 
#lr = bit_hyperrule.get_lr(step, len(train_set), args.base_lr) #if lr is None: # break if mixup > 0.0: x, y_a, y_b = mixup_data(x, y, mixup_l) # compute output with chrono.measure("fprop"): logits = model(x) top1, top5 = topk(logits, y, ks=(1, 5)) all_top1.extend(top1.cpu()) all_top5.extend(top5.cpu()) if mixup > 0.0: c = mixup_criterion(cri, logits, y_a, y_b, mixup_l) else: c = cri(logits, y) train_loss = c.item() train_acc = np.mean(all_top1) * 100.0 # Accumulate grads with chrono.measure("grads"): (c / args.batch_split).backward() accum_steps += 1 accstep = f"({accum_steps}/{args.batch_split})" if args.batch_split > 1 else "" s = f"epoch={epoch} batch {batch_id}{accstep}: loss={train_loss:.5f} train_acc={train_acc:.2f} lr={lr:.1e}" #s = f"epoch={epoch} batch {batch_id}{accstep}: loss={c.item():.5f} lr={lr:.1e}" pbar.set_description(s) #logger.info(f"[batch {batch_id}{accstep}]: loss={c_num:.5f} (lr={lr:.1e})") # pylint: disable=logging-format-interpolation logger.flush() # Update params with chrono.measure("update"): optim.step() optim.zero_grad() # Sample new mixup ratio for next batch mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1 # Run evaluation and save the model. val_loss, val_acc = run_eval(model, valid_loader, device, chrono, logger, epoch) best = val_acc > best_acc if best: best_acc = val_acc torch.save( { "epoch": epoch, "val_loss": val_loss, "val_acc": val_acc, "train_acc": train_acc, "model": model.state_dict(), "optim": optim.state_dict(), }, savename) end = time.time() logger.info(f"Timings:\n{chrono}")
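# `mixup_data` and `mixup_criterion` are called in the loop above but are not defined in this
# excerpt. The following is a minimal sketch of the standard mixup formulation they are assumed
# to implement; the mixing coefficient `lam` is sampled outside the call, exactly as in the loop.
import torch

def mixup_data(x, y, lam):
    """Blend each sample with a randomly permuted partner from the same batch."""
    idx = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1.0 - lam) * x[idx]
    return mixed_x, y, y[idx]                    # original labels and partner labels

def mixup_criterion(criterion, logits, y_a, y_b, lam):
    """Apply the same convex combination to the losses of the two label sets."""
    return lam * criterion(logits, y_a) + (1.0 - lam) * criterion(logits, y_b)

# Tiny usage demonstration with dummy tensors.
x, y = torch.randn(8, 3, 32, 32), torch.randint(0, 5, (8,))
mx, y_a, y_b = mixup_data(x, y, lam=0.7)
print(mx.shape, y_a.shape, y_b.shape)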
def main(): args = parser.parse_args() # We store all arguments in a json file. This has two advantages: # 1. We can always get back and see what exactly that experiment was # 2. We can resume an experiment as-is without needing to remember all flags. args_file = os.path.join(args.experiment_root, 'args.json') if args.resume: if not os.path.isfile(args_file): raise IOError('`args.json` not found in {}'.format(args_file)) print('Loading args from {}.'.format(args_file)) with open(args_file, 'r') as f: args_resumed = json.load(f) args_resumed['resume'] = True # This would be overwritten. # When resuming, we not only want to populate the args object with the # values from the file, but we also want to check for some possible # conflicts between loaded and given arguments. for key, value in args.__dict__.items(): if key in args_resumed: resumed_value = args_resumed[key] if resumed_value != value: print('Warning: For the argument `{}` we are using the' ' loaded value `{}`. The provided value was `{}`' '.'.format(key, resumed_value, value)) args.__dict__[key] = resumed_value else: print('Warning: A new argument was added since the last run:' ' `{}`. Using the new value: `{}`.'.format(key, value)) else: # If the experiment directory exists already, we bail in fear. if os.path.exists(args.experiment_root): if os.listdir(args.experiment_root): print('The directory {} already exists and is not empty.' ' If you want to resume training, append --resume to' ' your call.'.format(args.experiment_root)) exit(1) else: os.makedirs(args.experiment_root) # Store the passed arguments for later resuming and grepping in a nice # and readable format. with open(args_file, 'w') as f: json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True) log_file = os.path.join(args.experiment_root, "train") logging.config.dictConfig(common.get_logging_dict(log_file)) log = logging.getLogger('train') # Also show all parameter values at the start, for ease of reading logs. log.info('Training using the following parameters:') for key, value in sorted(vars(args).items()): log.info('{}: {}'.format(key, value)) # Check them here, so they are not required when --resume-ing. if not args.train_set: parser.print_help() log.error("You did not specify the `train_set` argument!") sys.exit(1) if not args.image_root: parser.print_help() log.error("You did not specify the required `image_root` argument!") sys.exit(1) # Load the data from the CSV file. pids, fids = common.load_dataset(args.train_set, args.image_root) max_fid_len = max(map(len, fids)) # We'll need this later for logfiles. # Setup a tf.Dataset where one "epoch" loops over all PIDS. # PIDS are shuffled after every epoch and continue indefinitely. unique_pids = np.unique(pids) dataset = tf.data.Dataset.from_tensor_slices(unique_pids) dataset = dataset.shuffle(len(unique_pids)) # Constrain the dataset size to a multiple of the batch-size, so that # we don't get overlap at the end of each epoch. dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p) dataset = dataset.repeat(None) # Repeat forever. Funny way of stating it. 
# For every PID, get K images. dataset = dataset.map(lambda pid: sample_k_fids_for_pid( pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k)) # Ungroup/flatten the batches for easy loading of the files. dataset = dataset.apply(tf.contrib.data.unbatch()) # Convert filenames to actual image tensors. net_input_size = (args.net_input_height, args.net_input_width) pre_crop_size = (args.pre_crop_height, args.pre_crop_width) dataset = dataset.map(lambda fid, pid: common.fid_to_image( fid, pid, image_root=args.image_root, image_size=pre_crop_size if args.crop_augment else net_input_size), num_parallel_calls=args.loading_threads) # Augment the data if specified by the arguments. if args.flip_augment: dataset = dataset.map(lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid)) if args.crop_augment: dataset = dataset.map(lambda im, fid, pid: (tf.random_crop( im, net_input_size + (3, )), fid, pid)) # Group it back into PK batches. batch_size = args.batch_p * args.batch_k dataset = dataset.batch(batch_size) # Overlap producing and consuming for parallelism. dataset = dataset.prefetch(1) # Since we repeat the data infinitely, we only need a one-shot iterator. images, fids, pids = dataset.make_one_shot_iterator().get_next() # Create the model and an embedding head. model = import_module('nets.' + args.model_name) head = import_module('heads.' + args.head_name) # Feed the image through the model. The returned `body_prefix` will be used # further down to load the pre-trained weights for all variables with this # prefix. endpoints, body_prefix = model.endpoints(images, is_training=True) with tf.name_scope('head'): endpoints = head.head(endpoints, args.embedding_dim, is_training=True) # Create the loss in two steps: # 1. Compute all pairwise distances according to the specified metric. # 2. For each anchor along the first dimension, compute its loss. dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=args.metric) losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[args.loss]( dists, pids, args.margin, batch_precision_at_k=args.batch_k - 1) # Count the number of active entries, and compute the total batch loss. num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32)) loss_mean = tf.reduce_mean(losses) # Some logging for tensorboard. tf.summary.histogram('loss_distribution', losses) tf.summary.scalar('loss', loss_mean) tf.summary.scalar('batch_top1', train_top1) tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k) tf.summary.scalar('active_count', num_active) tf.summary.histogram('embedding_dists', dists) tf.summary.histogram('embedding_pos_dists', pos_dists) tf.summary.histogram('embedding_neg_dists', neg_dists) tf.summary.histogram('embedding_lengths', tf.norm(endpoints['emb_raw'], axis=1)) # Create the mem-mapped arrays in which we'll log all training detail in # addition to tensorboard, because tensorboard is annoying for detailed # inspection and actually discards data in histogram summaries. 
if args.detailed_logs: log_embs = lb.create_or_resize_dat( os.path.join(args.experiment_root, 'embeddings'), dtype=np.float32, shape=(args.train_iterations, batch_size, args.embedding_dim)) log_loss = lb.create_or_resize_dat( os.path.join(args.experiment_root, 'losses'), dtype=np.float32, shape=(args.train_iterations, batch_size)) log_fids = lb.create_or_resize_dat( os.path.join(args.experiment_root, 'fids'), dtype='S' + str(max_fid_len), shape=(args.train_iterations, batch_size)) # These are collected here before we add the optimizer, because depending # on the optimizer, it might add extra slots, which are also global # variables, with the exact same prefix. model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, body_prefix) # Define the optimizer and the learning-rate schedule. # Unfortunately, we get NaNs if we don't handle no-decay separately. global_step = tf.Variable(0, name='global_step', trainable=False) if 0 <= args.decay_start_iteration < args.train_iterations: learning_rate = tf.train.exponential_decay( args.learning_rate, tf.maximum(0, global_step - args.decay_start_iteration), args.train_iterations - args.decay_start_iteration, 0.001) else: learning_rate = args.learning_rate tf.summary.scalar('learning_rate', learning_rate) optimizer = tf.train.AdamOptimizer(learning_rate) # Feel free to try others! # optimizer = tf.train.AdadeltaOptimizer(learning_rate) # Update_ops are used to update batchnorm stats. with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): train_op = optimizer.minimize(loss_mean, global_step=global_step) # Define a saver for the complete model. checkpoint_saver = tf.train.Saver(max_to_keep=0) with tf.Session() as sess: if args.resume: # In case we're resuming, simply load the full checkpoint to init. last_checkpoint = tf.train.latest_checkpoint(args.experiment_root) log.info('Restoring from checkpoint: {}'.format(last_checkpoint)) checkpoint_saver.restore(sess, last_checkpoint) else: # But if we're starting from scratch, we may need to load some # variables from the pre-trained weights, and random init others. sess.run(tf.global_variables_initializer()) if args.initial_checkpoint is not None: saver = tf.train.Saver(model_variables) saver.restore(sess, args.initial_checkpoint) # In any case, we also store this initialization as a checkpoint, # such that we could run exactly reproducible experiments. checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=0) merged_summary = tf.summary.merge_all() summary_writer = tf.summary.FileWriter(args.experiment_root, sess.graph) start_step = sess.run(global_step) log.info('Starting training from iteration {}.'.format(start_step)) # Finally, here comes the main-loop. This `Uninterrupt` is a handy # utility such that an iteration still finishes on Ctrl+C and we can # stop the training cleanly. with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u: for i in range(start_step, args.train_iterations): # Compute gradients, update weights, store logs! 
start_time = time.time() _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \ sess.run([train_op, merged_summary, global_step, prec_at_k, endpoints['emb'], losses, fids]) elapsed_time = time.time() - start_time # Compute the iteration speed and add it to the summary. # We did observe some weird spikes that we couldn't track down. summary2 = tf.Summary() summary2.value.add(tag='secs_per_iter', simple_value=elapsed_time) summary_writer.add_summary(summary2, step) summary_writer.add_summary(summary, step) if args.detailed_logs: log_embs[i], log_loss[i], log_fids[i] = b_embs, b_loss, b_fids # Do a huge print out of the current progress. seconds_todo = (args.train_iterations - step) * elapsed_time log.info( 'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ' 'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format( step, float(np.min(b_loss)), float(np.mean(b_loss)), float(np.max(b_loss)), args.batch_k - 1, float(b_prec_at_k), timedelta(seconds=int(seconds_todo)), elapsed_time)) sys.stdout.flush() sys.stderr.flush() # Save a checkpoint of training every so often. if (args.checkpoint_frequency > 0 and step % args.checkpoint_frequency == 0): checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=step) # Stop the main-loop at the end of the step, if requested. if u.interrupted: log.info("Interrupted on request!") break # Store one final checkpoint. This might be redundant, but it is crucial # in case intermediate storing was disabled and it saves a checkpoint # when the process was interrupted. checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=step)
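# The schedule above clamps the step offset to zero until `decay_start_iteration` and then decays
# exponentially down to 0.001x the base rate at `train_iterations`. A plain-Python sketch of that
# behaviour; the concrete numbers below are illustrative defaults, not values taken from the script.
import numpy as np

def lr_at(step, base_lr=3e-4, decay_start=15000, train_iterations=25000, decay_rate=0.001):
    decay_steps = train_iterations - decay_start
    t = max(0, step - decay_start)               # constant phase before decay_start
    return base_lr * decay_rate ** (t / decay_steps)

for s in (0, 15000, 20000, 25000):
    print(s, lr_at(s))                           # flat until 15k, then decays to base_lr * 0.001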
def main(args): logger = common.setup_logger(args) # Lets cuDNN benchmark conv implementations and choose the fastest. # Only good if sizes stay the same within the main loop! torch.backends.cudnn.benchmark = True device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") logger.info(f"Going to train on {device}") train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger) logger.info(f"Loading model from {args.model}.npz") model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes), zero_head=True) model.load_from( np.load(os.path.join(args.pretrained_dir, f"{args.model}.npz"))) logger.info("Moving model onto all GPUs") model = torch.nn.DataParallel(model) # Optionally resume from a checkpoint. # Load it to CPU first as we'll move the model to GPU later. # This way, we save a little bit of GPU memory when loading. step = 0 # Note: no weight-decay! optim = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=0.9) writer = SummaryWriter(os.path.join(args.logdir, args.name)) # Resume fine-tuning if we find a saved model. savename = pjoin(args.logdir, args.name, "model.tar") try: logger.info(f"Model will be saved in '{savename}'") checkpoint = torch.load(savename, map_location="cpu") logger.info(f"Found saved model to resume from at '{savename}'") step = checkpoint["step"] model.load_state_dict(checkpoint["model"]) optim.load_state_dict(checkpoint["optim"]) logger.info(f"Resumed at step {step}") except FileNotFoundError: logger.info("Fine-tuning from BiT") model = model.to(device) optim.zero_grad() model.train() mixup = hyperrule.get_mixup(len(train_set)) cri = torch.nn.CrossEntropyLoss().to(device) logger.info("Starting training!") chrono = lb.Chrono() accum_steps = 0 mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1 end = time.time() with lb.Uninterrupt() as u: for x, y in recycle(train_loader): # measure data loading time, which is spent in the `for` statement. chrono._done("load", time.time() - end) if u.interrupted: break # Schedule sending to GPU(s) x = x.to(device, non_blocking=True) y = y.to(device, non_blocking=True) # Update learning-rate, including stop training if over. lr = hyperrule.get_lr(step, len(train_set), args.base_lr) if lr is None: break for param_group in optim.param_groups: param_group["lr"] = lr if mixup > 0.0: x, y_a, y_b = mixup_data(x, y, mixup_l) # compute output with chrono.measure("fprop"): logits = model(x) if mixup > 0.0: c = mixup_criterion(cri, logits, y_a, y_b, mixup_l) else: c = cri(logits, y) c_num = float( c.data.cpu().numpy()) # Also ensures a sync point. # Accumulate grads with chrono.measure("grads"): (c / args.batch_split).backward() accum_steps += 1 accstep = f" ({accum_steps}/{args.batch_split})" if args.batch_split > 1 else "" logger.info( f"[step {step}{accstep}]: loss={c_num:.5f} (lr={lr:.1e})") # pylint: disable=logging-format-interpolation logger.flush() writer.add_scalar('Train/loss', c_num, step) writer.add_scalar('Train/lr', lr, step) # Update params if accum_steps == args.batch_split: with chrono.measure("update"): optim.step() optim.zero_grad() step += 1 accum_steps = 0 # Sample new mixup ratio for next batch mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1 # Run evaluation and save the model. 
if args.eval_every and step % args.eval_every == 0: run_eval(model, valid_loader, device, chrono, logger, writer, step) if args.save and step % args.save_every == 0: step_savename = pjoin(args.logdir, args.name, "model_" + str(step) + ".tar") torch.save( { "step": step, "model": model.state_dict(), "optim": optim.state_dict() }, step_savename) end = time.time() # Final eval at end of training. run_eval(model, valid_loader, device, chrono, logger, writer, step) logger.info(f"Timings:\n{chrono}")
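# `recycle(train_loader)` drives the loop above but is not shown in this excerpt; a minimal
# sketch of such a helper, which simply restarts the loader whenever it is exhausted.
def recycle(iterable):
    """Repeatedly iterate over `iterable`, restarting it whenever it runs out."""
    while True:
        for item in iterable:
            yield item

# Tiny demonstration with a list standing in for the DataLoader.
stream = recycle([1, 2, 3])
print([next(stream) for _ in range(7)])          # -> [1, 2, 3, 1, 2, 3, 1]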
def train(args, images, fids, pids, max_fid_len, log): ''' Create the model and train the neural network. :param args: all stored arguments :param images: prepared images for training :param fids: figure ids (relative paths from image_root to the images) :param pids: person ids (or car ids) for all images :param max_fid_len: length of the longest fid, needed for the detailed logfiles :param log: log file, where logs from training are stored :return: saved files (checkpoints, train log file) ''' ################################################################################################################### # CREATE MODEL ################################################################################################################### # Create the model and an embedding head. model = import_module('nets.resnet_v1_50') # Feed the image through the model. The returned `body_prefix` will be used # further down to load the pre-trained weights for all variables with this # prefix. drops = {} if args.dropout is not None: drops = getDropoutProbs(args.dropout) b4_layers = None try: b4_layers = int(args.b4_layers) if b4_layers not in [1, 2, 3]: raise ValueError() except (TypeError, ValueError): raise ValueError('Argument exception: b4_layers has to be in [1, 2, 3]') endpoints, body_prefix = model.endpoints(images, b4_layers, drops, is_training=True, resnet_stride=int(args.resnet_stride)) endpoints['emb'] = endpoints['emb_raw'] = slim.fully_connected( endpoints['model_output'], args.embedding_dim, activation_fn=None, weights_initializer=tf.orthogonal_initializer(), scope='emb') step_pl = tf.placeholder(dtype=tf.float32) features = endpoints['emb'] # Create the loss in two steps: # 1. Compute all pairwise distances according to the specified metric. # 2. For each anchor along the first dimension, compute its loss. dists = loss.cdist(features, features, metric=args.metric) losses, train_top1, prec_at_k, _, probe_neg_dists, pos_dists, neg_dists = loss.loss_function( dists, pids, [args.alpha1, args.alpha2, args.alpha3], batch_precision_at_k=args.batch_k - 1) # Count the number of active entries, and compute the total batch loss. num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32)) loss_mean = tf.reduce_mean(losses) # Some logging for tensorboard. tf.summary.histogram('loss_distribution', losses) tf.summary.scalar('loss', loss_mean) tf.summary.scalar('batch_top1', train_top1) tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k) tf.summary.scalar('active_count', num_active) tf.summary.scalar('embedding_pos_dists', tf.reduce_mean(pos_dists)) tf.summary.scalar('embedding_probe_neg_dists', tf.reduce_mean(probe_neg_dists)) tf.summary.scalar('embedding_neg_dists', tf.reduce_mean(neg_dists)) tf.summary.histogram('embedding_dists', dists) tf.summary.histogram('embedding_pos_dists', pos_dists) tf.summary.histogram('embedding_probe_neg_dists', probe_neg_dists) tf.summary.histogram('embedding_neg_dists', neg_dists) tf.summary.histogram('embedding_lengths', tf.norm(endpoints['emb_raw'], axis=1)) # Create the mem-mapped arrays in which we'll log all training detail in # addition to tensorboard, because tensorboard is annoying for detailed # inspection and actually discards data in histogram summaries. 
batch_size = args.batch_p * args.batch_k if args.detailed_logs: log_embs = lb.create_or_resize_dat( os.path.join(args.experiment_root, 'embeddings'), dtype=np.float32, shape=(args.train_iterations, batch_size, args.embedding_dim)) log_loss = lb.create_or_resize_dat( os.path.join(args.experiment_root, 'losses'), dtype=np.float32, shape=(args.train_iterations, batch_size)) log_fids = lb.create_or_resize_dat( os.path.join(args.experiment_root, 'fids'), dtype='S' + str(max_fid_len), shape=(args.train_iterations, batch_size)) # These are collected here before we add the optimizer, because depending # on the optimizer, it might add extra slots, which are also global # variables, with the exact same prefix. model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, body_prefix) # Define the optimizer and the learning-rate schedule. # Unfortunately, we get NaNs if we don't handle no-decay separately. global_step = tf.Variable(0, name='global_step', trainable=False) if args.sgdr: learning_rate = tf.train.cosine_decay_restarts( learning_rate=args.learning_rate, global_step=global_step, first_decay_steps=4000, t_mul=1.5) else: if 0 <= args.decay_start_iteration < args.train_iterations: learning_rate = tf.train.exponential_decay( args.learning_rate, tf.maximum(0, global_step - args.decay_start_iteration), args.train_iterations - args.decay_start_iteration, float(args.lr_decay)) else: learning_rate = args.learning_rate tf.summary.scalar('learning_rate', learning_rate) optimizer = tf.train.AdamOptimizer(tf.convert_to_tensor(learning_rate)) # Update_ops are used to update batchnorm stats. with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): train_op = optimizer.minimize(loss_mean, global_step=global_step) # Define a saver for the complete model. checkpoint_saver = tf.train.Saver(max_to_keep=0) with tf.Session() as sess: if args.resume: # In case we're resuming, simply load the full checkpoint to init. last_checkpoint = tf.train.latest_checkpoint(args.experiment_root) log.info('Restoring from checkpoint: {}'.format(last_checkpoint)) checkpoint_saver.restore(sess, last_checkpoint) else: # But if we're starting from scratch, we may need to load some # variables from the pre-trained weights, and random init others. sess.run(tf.global_variables_initializer()) if args.initial_checkpoint is not None: saver = tf.train.Saver(model_variables) saver.restore(sess, args.initial_checkpoint) # In any case, we also store this initialization as a checkpoint, # such that we could run exactly reproduceable experiments. checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=0) merged_summary = tf.summary.merge_all() summary_writer = tf.summary.FileWriter(args.experiment_root, sess.graph) start_step = sess.run(global_step) step = start_step log.info('Starting training from iteration {}.'.format(start_step)) ################################################################################################################### # TRAINING ################################################################################################################### # Finally, here comes the main-loop. This `Uninterrupt` is a handy # utility such that an iteration still finishes on Ctrl+C and we can # stop the training cleanly. with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u: for i in range(start_step, args.train_iterations): # Compute gradients, update weights, store logs! 
start_time = time.time() _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \ sess.run([train_op, merged_summary, global_step, prec_at_k, features, losses, fids], feed_dict={step_pl: step}) elapsed_time = time.time() - start_time # Compute the iteration speed and add it to the summary. # We did observe some weird spikes that we couldn't track down. summary2 = tf.Summary() summary2.value.add(tag='secs_per_iter', simple_value=elapsed_time) summary_writer.add_summary(summary2, step) summary_writer.add_summary(summary, step) if args.detailed_logs: log_embs[i], log_loss[i], log_fids[ i] = b_embs, b_loss, b_fids # Do a huge print out of the current progress. Maybe steal from here. seconds_todo = (args.train_iterations - step) * elapsed_time log.info( 'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ' 'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it), lr={:.4g}'. format(step, float(np.min(b_loss)), float(np.mean(b_loss)), float(np.max(b_loss)), args.batch_k - 1, float(b_prec_at_k), timedelta(seconds=int(seconds_todo)), elapsed_time, sess.run(optimizer._lr))) sys.stdout.flush() sys.stderr.flush() # Save a checkpoint of training every so often. if (args.checkpoint_frequency > 0 and step % args.checkpoint_frequency == 0): checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=step) # Stop the main-loop at the end of the step, if requested. if u.interrupted: log.info("Interrupted on request!") break # Store one final checkpoint. This might be redundant, but it is crucial # in case intermediate storing was disabled and it saves a checkpoint # when the process was interrupted. checkpoint_saver.save(sess, os.path.join(args.experiment_root, 'checkpoint'), global_step=step)
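# The `--sgdr` branch above uses `tf.train.cosine_decay_restarts`. The sketch below approximates
# that warm-restart schedule in plain Python: within each cycle the rate follows half a cosine,
# and every restart the cycle length grows by `t_mul` (1.5 above). The base rate is illustrative,
# and TF's `alpha` / `m_mul` parameters are left at their defaults here.
import math

def sgdr_lr(step, base_lr=3e-4, first_decay_steps=4000, t_mul=1.5):
    cycle_len, cycle_start = first_decay_steps, 0
    while step >= cycle_start + cycle_len:       # find the cycle this step falls into
        cycle_start += cycle_len
        cycle_len = int(cycle_len * t_mul)
    frac = (step - cycle_start) / cycle_len
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * frac))

for s in (0, 2000, 3999, 4000, 8000):
    print(s, round(sgdr_lr(s), 6))               # drops within a cycle, jumps back up at each restart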
def main(argv): args = parser.parse_args(argv) if args.gpu: os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu # tf.compat.v1.disable_eager_execution() # physical_devices = tf.config.experimental.list_physical_devices('GPU') # tf.config.experimental.set_memory_growth(physical_devices[0], True) # We store all arguments in a json file. This has two advantages: # 1. We can always get back and see what exactly that experiment was # 2. We can resume an experiment as-is without needing to remember all flags. args_file = os.path.join(args.experiment_root, 'args.json') if args.resume: if not os.path.isfile(args_file): raise IOError('`args.json` not found in {}'.format(args_file)) print('Loading args from {}.'.format(args_file)) with open(args_file, 'r') as f: args_resumed = json.load(f) args_resumed['resume'] = True # This would be overwritten. # When resuming, we not only want to populate the args object with the # values from the file, but we also want to check for some possible # conflicts between loaded and given arguments. for key, value in args.__dict__.items(): if key in args_resumed: resumed_value = args_resumed[key] if resumed_value != value: print('Warning: For the argument `{}` we are using the' ' loaded value `{}`. The provided value was `{}`' '.'.format(key, resumed_value, value)) args.__dict__[key] = resumed_value else: print('Warning: A new argument was added since the last run:' ' `{}`. Using the new value: `{}`.'.format(key, value)) else: # If the experiment directory exists already, we bail in fear. if os.path.exists(args.experiment_root): if os.listdir(args.experiment_root): print('The directory {} already exists and is not empty.' ' If you want to resume training, append --resume to' ' your call.'.format(args.experiment_root)) exit(1) else: os.makedirs(args.experiment_root) # Store the passed arguments for later resuming and grepping in a nice # and readable format. with open(args_file, 'w') as f: json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True) log_file = os.path.join(args.experiment_root, "train") logging.config.dictConfig(common.get_logging_dict(log_file)) log = logging.getLogger('train') # Also show all parameter values at the start, for ease of reading logs. log.info('Training using the following parameters:') for key, value in sorted(vars(args).items()): log.info('{}: {}'.format(key, value)) # Check them here, so they are not required when --resume-ing. if not args.train_set: parser.print_help() log.error("You did not specify the `train_set` argument!") sys.exit(1) if not args.image_root: parser.print_help() log.error("You did not specify the required `image_root` argument!") sys.exit(1) # Load the data from the CSV file. pids, fids = common.load_dataset(args.train_set, args.image_root) max_fid_len = max(map(len, fids)) # We'll need this later for logfiles. # Setup a tf.Dataset where one "epoch" loops over all PIDS. # PIDS are shuffled after every epoch and continue indefinitely. unique_pids = np.unique(pids) if len(unique_pids) < args.batch_p: unique_pids = np.tile(unique_pids, int(np.ceil(args.batch_p / len(unique_pids)))) dataset = tf.data.Dataset.from_tensor_slices(unique_pids) dataset = dataset.shuffle(len(unique_pids)) # Constrain the dataset size to a multiple of the batch-size, so that # we don't get overlap at the end of each epoch. dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p) dataset = dataset.repeat(None) # Repeat forever. Funny way of stating it. # For every PID, get K images. 
dataset = dataset.map(lambda pid: sample_k_fids_for_pid( pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k)) # Ungroup/flatten the batches for easy loading of the files. dataset = dataset.unbatch() # Convert filenames to actual image tensors. net_input_size = (args.net_input_height, args.net_input_width) pre_crop_size = (args.pre_crop_height, args.pre_crop_width) dataset = dataset.map(lambda fid, pid: common.fid_to_image( fid, pid, image_root=args.image_root, image_size=pre_crop_size if args.crop_augment else net_input_size), num_parallel_calls=args.loading_threads) # Augment the data if specified by the arguments. if args.flip_augment: dataset = dataset.map(lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid)) if args.crop_augment: dataset = dataset.map(lambda im, fid, pid: (tf.image.random_crop( im, net_input_size + (3, )), fid, pid)) # Create the model and an embedding head. tf.keras.backend.set_learning_phase(1) emb_model = EmbeddingModel(args) # Group it back into PK batches. batch_size = args.batch_p * args.batch_k dataset = dataset.map(lambda im, fid, pid: (emb_model.preprocess_input(im), fid, pid)) dataset = dataset.batch(batch_size) # Overlap producing and consuming for parallelism. dataset = dataset.prefetch(1) # Since we repeat the data infinitely, we only need a one-shot iterator. # Feed the image through the model. The returned `body_prefix` will be used # further down to load the pre-trained weights for all variables with this # prefix. # all_trainable_variables = embedding_head.trainable_variables+base_model.trainable_variables # Define the optimizer and the learning-rate schedule. # Unfortunately, we get NaNs if we don't handle no-decay separately. 
if 0 <= args.decay_start_iteration < args.train_iterations: learning_rate = tf.optimizers.schedules.PolynomialDecay( args.learning_rate, args.train_iterations, end_learning_rate=1e-7) else: learning_rate = args.learning_rate if args.optimizer == 'adam': optimizer = tf.keras.optimizers.Adam(learning_rate) elif args.optimizer == 'momentum': optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=0.9) else: raise NotImplementedError('Invalid optimizer {}'.format( args.optimizer)) @tf.function def train_step(images, pids): with tf.GradientTape() as tape: batch_embedding = emb_model(images) if args.loss == 'semi_hard_triplet': embedding_loss = triplet_semihard_loss(batch_embedding, pids, args.margin) elif args.loss == 'hard_triplet': embedding_loss = batch_hard(batch_embedding, pids, args.margin, args.metric) elif args.loss == 'lifted_loss': embedding_loss = lifted_loss(pids, batch_embedding, margin=args.margin) elif args.loss == 'contrastive_loss': assert batch_size % 2 == 0 assert args.batch_k == 4 ## Can work with other number but will need tuning contrastive_idx = np.tile([0, 1, 4, 3, 2, 5, 6, 7], args.batch_p // 2) for i in range(args.batch_p // 2): contrastive_idx[i * 8:i * 8 + 8] += i * 8 contrastive_idx = np.expand_dims(contrastive_idx, 1) batch_embedding_ordered = tf.gather_nd(batch_embedding, contrastive_idx) pids_ordered = tf.gather_nd(pids, contrastive_idx) # batch_embedding_ordered = tf.Print(batch_embedding_ordered,[pids_ordered],'pids_ordered :: ',summarize=1000) embeddings_anchor, embeddings_positive = tf.unstack( tf.reshape(batch_embedding_ordered, [-1, 2, args.embedding_dim]), 2, 1) # embeddings_anchor = tf.Print(embeddings_anchor,[pids_ordered,embeddings_anchor,embeddings_positive,batch_embedding,batch_embedding_ordered],"Tensors ", summarize=1000) fixed_labels = np.tile([1, 0, 0, 1], args.batch_p // 2) # fixed_labels = np.reshape(fixed_labels,(len(fixed_labels),1)) # print(fixed_labels) labels = tf.constant(fixed_labels) # labels = tf.Print(labels,[labels],'labels ',summarize=1000) embedding_loss = contrastive_loss(labels, embeddings_anchor, embeddings_positive, margin=args.margin) elif args.loss == 'angular_loss': embeddings_anchor, embeddings_positive = tf.unstack( tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]), 2, 1) # pids = tf.Print(pids, [pids], 'pids:: ', summarize=100) pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1) # pids = tf.Print(pids,[pids],'pids:: ',summarize=100) embedding_loss = angular_loss(pids, embeddings_anchor, embeddings_positive, batch_size=args.batch_p, with_l2reg=True) elif args.loss == 'npairs_loss': assert args.batch_k == 2 ## Single positive pair per class embeddings_anchor, embeddings_positive = tf.unstack( tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]), 2, 1) pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1) pids = tf.reshape(pids, [-1]) embedding_loss = npairs_loss(pids, embeddings_anchor, embeddings_positive) else: raise NotImplementedError('Invalid Loss {}'.format(args.loss)) loss_mean = tf.reduce_mean(embedding_loss) gradients = tape.gradient(loss_mean, emb_model.trainable_variables) optimizer.apply_gradients(zip(gradients, emb_model.trainable_variables)) return embedding_loss # sess = tf.compat.v1.Session() # start_step = sess.run(global_step) # checkpoint_saver = tf.train.Saver(max_to_keep=2) start_step = 0 log.info('Starting training from iteration {}.'.format(start_step)) dataset_iter = iter(dataset) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=emb_model) manager = 
tf.train.CheckpointManager(ckpt, osp.join(args.experiment_root, 'tf_ckpts'), max_to_keep=3) ckpt.restore(manager.latest_checkpoint) if manager.latest_checkpoint: print("Restored from {}".format(manager.latest_checkpoint)) else: print("Initializing from scratch.") with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u: for i in range(ckpt.step.numpy(), args.train_iterations): # for batch_idx, batch in enumerate(): start_time = time.time() images, fids, pids = next(dataset_iter) batch_loss = train_step(images, pids) elapsed_time = time.time() - start_time seconds_todo = (args.train_iterations - i) * elapsed_time # print(tf.reduce_min(batch_loss).numpy(),tf.reduce_mean(batch_loss).numpy(),tf.reduce_max(batch_loss).numpy()) log.info( 'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ETA: {} ({:.2f}s/it)' .format( i, tf.reduce_min(batch_loss).numpy(), tf.reduce_mean(batch_loss).numpy(), tf.reduce_max(batch_loss).numpy(), # args.batch_k - 1, float(b_prec_at_k), timedelta(seconds=int(seconds_todo)), elapsed_time)) ckpt.step.assign_add(1) if (args.checkpoint_frequency > 0 and i % args.checkpoint_frequency == 0): # uncomment if you want to save the model weight separately # emb_model.save_weights(os.path.join(args.experiment_root, 'model_weights_{0:04d}.w'.format(i))) manager.save() # Stop the main-loop at the end of the step, if requested. if u.interrupted: log.info("Interrupted on request!") break
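# For readers new to the eager-mode script above, this is a self-contained miniature of its
# `GradientTape` -> `tape.gradient` -> `apply_gradients` pattern, with a toy Keras model and a
# placeholder loss standing in for `EmbeddingModel` and the metric-learning losses (illustration only).
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(8)])
optimizer = tf.keras.optimizers.Adam(1e-3)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        emb = model(x)
        loss = tf.reduce_mean(tf.square(emb - y))            # placeholder loss
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

x = tf.random.normal([16, 4])
y = tf.random.normal([16, 8])
print(float(train_step(x, y)))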
def main(): # args = parser.parse_args() # We store all arguments in a json file. This has two advantages: # 1. We can always get back and see what exactly that experiment was # 2. We can resume an experiment as-is without needing to remember all flags. train_config = cfg.TrainConfig() args_file = os.path.join(train_config.experiment_root, 'args.json') if train_config.resume: if not os.path.isfile(args_file): raise IOError('`args.json` not found in {}'.format(args_file)) print('Loading args from {}.'.format(args_file)) with open(args_file, 'r') as f: args_resumed = json.load(f) args_resumed['resume'] = True # This would be overwritten. # When resuming, we not only want to populate the args object with the # values from the file, but we also want to check for some possible # conflicts between loaded and given arguments. for key, value in train_config.__dict__.items(): if key in args_resumed: resumed_value = args_resumed[key] if resumed_value != value: print('Warning: For the argument `{}` we are using the' ' loaded value `{}`. The provided value was `{}`' '.'.format(key, resumed_value, value)) train_config.__dict__[key] = resumed_value else: print('Warning: A new argument was added since the last run:' ' `{}`. Using the new value: `{}`.'.format(key, value)) else: # If the experiment directory exists already, we bail in fear. if os.path.exists(train_config.experiment_root): if os.listdir(train_config.experiment_root): print('The directory {} already exists and is not empty.' ' If you want to resume training, append --resume to' ' your call.'.format(train_config.experiment_root)) exit(1) else: os.makedirs(train_config.experiment_root) # Store the passed arguments for later resuming and grepping in a nice # and readable format. with open(args_file, 'w') as f: json.dump(vars(train_config), f, ensure_ascii=False, indent=2, sort_keys=True) log_file = os.path.join(train_config.experiment_root, "train") logging.config.dictConfig(common.get_logging_dict(log_file)) log = logging.getLogger('train') # Also show all parameter values at the start, for ease of reading logs. log.info('Training using the following parameters:') for key, value in sorted(vars(train_config).items()): log.info('{}: {}'.format(key, value)) # Check them here, so they are not required when --resume-ing. if not train_config.train_set: parser.print_help() log.error("You did not specify the `train_set` argument!") sys.exit(1) if not train_config.image_root: parser.print_help() log.error("You did not specify the required `image_root` argument!") sys.exit(1) # Load the data from the CSV file. pids, fids = common.load_dataset(train_config.train_set, train_config.image_root, is_train=True) max_fid_len = max(map(len, fids)) # We'll need this later for logfiles. # Setup a tf.Dataset where one "epoch" loops over all PIDS. # PIDS are shuffled after every epoch and continue indefinitely. unique_pids = np.unique(pids) dataset = tf.data.Dataset.from_tensor_slices(unique_pids) dataset = dataset.shuffle(len(unique_pids)) # Constrain the dataset size to a multiple of the batch-size, so that # we don't get overlap at the end of each epoch. dataset = dataset.take((len(unique_pids) // train_config.batch_p) * train_config.batch_p) # take(count) creates a Dataset with at most count elements from this dataset. dataset = dataset.repeat(None) # Repeat forever. Funny way of stating it. # repeat(count) repeats this dataset count times; None means indefinitely. # For every PID, get K images. 
dataset = dataset.map(lambda pid: sample_k_fids_for_pid( pid, all_fids=fids, all_pids=pids, batch_k=train_config.batch_k)) # Ungroup/flatten the batches for easy loading of the files. dataset = dataset.apply(tf.contrib.data.unbatch()) # apply(transformation_func) Apply a transformation function to this dataset. # apply enables chaining of custom Dataset transformations, which are represented as functions that take one Dataset argument and return a transformed Dataset. # Convert filenames to actual image tensors. net_input_size = (train_config.net_input_height, train_config.net_input_width) # 256,128 pre_crop_size = (train_config.pre_crop_height, train_config.pre_crop_width) # 288,144 dataset = dataset.map( lambda fid, pid: common.fid_to_image_label( fid, pid, image_root=train_config.image_root, image_size=pre_crop_size if train_config.crop_augment else net_input_size), num_parallel_calls=train_config.loading_threads) ########################################################################################### dataset = dataset.map( lambda im, keypt, mask, fid, pid: (tf.concat([im, keypt, mask], 2), fid, pid)) ########################################################################################### # Augment the data if specified by the arguments. if train_config.flip_augment: dataset = dataset.map( lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid)) # net_input_size_aug = net_input_size + (4,) if train_config.crop_augment: dataset = dataset.map( lambda im, fid, pid: (tf.random_crop(im, net_input_size + (21,)), fid, pid)) # net_input_size + (21,) = (256, 128, 21) # split ############################################################################################# dataset = dataset.map( lambda im, fid, pid: (common.split(im, fid, pid))) ############################################################################################# # Group it back into PK batches. batch_size = train_config.batch_p * train_config.batch_k dataset = dataset.batch(batch_size) # Overlap producing and consuming for parallelism. dataset = dataset.prefetch(1) # prefetch(buffer_size) Creates a Dataset that prefetches elements from this dataset. # Since we repeat the data infinitely, we only need a one-shot iterator. images, keypts, masks, fids, pids = dataset.make_one_shot_iterator().get_next() # tf.summary.image('image',images,10) # Create the model and an embedding head. model = import_module('nets.' + train_config.model_name) head = import_module('heads.' + train_config.head_name) # Feed the image through the model. The returned `body_prefix` will be used # further down to load the pre-trained weights for all variables with this # prefix. 
endpoints, body_prefix = model.endpoints(images, is_training=True) heatmap_in = endpoints[train_config.model_name + '/block4'] # resnet_block_4_out = heatmap.resnet_block_4(heatmap_in) # resnet_block_3_4_out = heatmap.resnet_block_3_4(heatmap_in) # resnet_block_2_3_4_out = heatmap.resnet_block_2_3_4(heatmap_in) # head for heatmap with tf.name_scope('heatmap'): # heatmap_in = endpoints['model_output'] # heatmap_out_layer_0 = heatmap.hmnet_layer_0(resnet_block_4_out, 1) # heatmap_out_layer_0 = heatmap.hmnet_layer_0(resnet_block_3_4_out, 1) # heatmap_out_layer_0 = heatmap.hmnet_layer_0(resnet_block_2_3_4_out, 1) heatmap_out_layer_0 = VAC.hmnet_layer_0(heatmap_in[:, :, :, 1020:2048], 1) heatmap_out_layer_1 = VAC.hmnet_layer_1(heatmap_out_layer_0, 1) heatmap_out_layer_2 = VAC.hmnet_layer_2(heatmap_out_layer_1, 1) heatmap_out_layer_3 = VAC.hmnet_layer_3(heatmap_out_layer_2, 1) heatmap_out_layer_4 = VAC.hmnet_layer_4(heatmap_out_layer_3, 1) heatmap_out = heatmap_out_layer_4 heatmap_loss = VAC.loss_mutilayer(heatmap_out_layer_0, heatmap_out_layer_1, heatmap_out_layer_2, heatmap_out_layer_3, heatmap_out_layer_4, masks, net_input_size) # heatmap_loss = heatmap.loss(heatmap_out, labels, net_input_size) # heatmap_loss_mean = heatmap_loss with tf.name_scope('head'): # heatmap_sum = tf.reduce_sum(heatmap_out, axis=3) # heatmap_resize = tf.image.resize_images(tf.expand_dims(heatmap_sum, axis=3), [8, 4]) # featuremap_tmp = tf.multiply(heatmap_resize, endpoints[args.model_name + '/block4']) # endpoints[args.model_name + '/block4'] = featuremap_tmp endpoints = head.head(endpoints, train_config.embedding_dim, is_training=True) tf.summary.image('feature_map', tf.expand_dims(endpoints[train_config.model_name + '/block4'][:, :, :, 0], axis=3), 4) with tf.name_scope('keypoints_pre'): keypoints_pre_in = endpoints[train_config.model_name + '/block4'] # keypoints_pre_in_0 = keypoints_pre_in[:, :, :, 0:256] # keypoints_pre_in_1 = keypoints_pre_in[:, :, :, 256:512] # keypoints_pre_in_2 = keypoints_pre_in[:, :, :, 512:768] # keypoints_pre_in_3 = keypoints_pre_in[:, :, :, 768:1024] keypoints_pre_in_0 = keypoints_pre_in[:, :, :, 0:170] keypoints_pre_in_1 = keypoints_pre_in[:, :, :, 170:340] keypoints_pre_in_2 = keypoints_pre_in[:, :, :, 340:510] keypoints_pre_in_3 = keypoints_pre_in[:, :, :, 510:680] keypoints_pre_in_4 = keypoints_pre_in[:, :, :, 680:850] keypoints_pre_in_5 = keypoints_pre_in[:, :, :, 850:1020] labels = tf.image.resize_images(keypts, [128, 64]) # keypoints_gt_0 = tf.concat([labels[:, :, :, 0:5], labels[:, :, :, 14:15], labels[:, :, :, 15:16], labels[:, :, :, 16:17], labels[:, :, :, 17:18]], 3) # keypoints_gt_1 = tf.concat([labels[:, :, :, 1:2], labels[:, :, :, 2:3], labels[:, :, :, 3:4], labels[:, :, :, 5:6]], 3) # keypoints_gt_2 = tf.concat([labels[:, :, :, 4:5], labels[:, :, :, 7:8], labels[:, :, :, 8:9], labels[:, :, :, 11:12]], 3) # keypoints_gt_3 = tf.concat([labels[:, :, :, 9:10], labels[:, :, :, 10:11], labels[:, :, :, 12:13], labels[:, :, :, 13:14]], 3) keypoints_gt_0 = labels[:, :, :, 0:5] keypoints_gt_1 = labels[:, :, :, 5:7] keypoints_gt_2 = labels[:, :, :, 7:9] keypoints_gt_3 = labels[:, :, :, 9:13] keypoints_gt_4 = labels[:, :, :, 13:15] keypoints_gt_5 = labels[:, :, :, 15:17] keypoints_pre_0 = PAC.tran_conv_0(keypoints_pre_in, kp_num=5) keypoints_pre_1 = PAC.tran_conv_1(keypoints_pre_in, kp_num=2) keypoints_pre_2 = PAC.tran_conv_2(keypoints_pre_in, kp_num=2) keypoints_pre_3 = PAC.tran_conv_3(keypoints_pre_in, kp_num=4) keypoints_pre_4 = PAC.tran_conv_4(keypoints_pre_in, kp_num=2) 
keypoints_pre_5 = PAC.tran_conv_5(keypoints_pre_in, kp_num=2) keypoints_loss_0 = PAC.keypoints_loss(keypoints_pre_0, keypoints_gt_0) keypoints_loss_1 = PAC.keypoints_loss(keypoints_pre_1, keypoints_gt_1) keypoints_loss_2 = PAC.keypoints_loss(keypoints_pre_2, keypoints_gt_2) keypoints_loss_3 = PAC.keypoints_loss(keypoints_pre_3, keypoints_gt_3) keypoints_loss_4 = PAC.keypoints_loss(keypoints_pre_4, keypoints_gt_4) keypoints_loss_5 = PAC.keypoints_loss(keypoints_pre_5, keypoints_gt_5) keypoints_loss = 5/17*keypoints_loss_0 + 2/17*keypoints_loss_1 + 2/17*keypoints_loss_2 + 4/17*keypoints_loss_3 + 2/17*keypoints_loss_4 + 2/17*keypoints_loss_5 # Create the loss in two steps: # 1. Compute all pairwise distances according to the specified metric. # 2. For each anchor along the first dimension, compute its loss. dists = loss.cdist(endpoints['emb'], endpoints['emb'], metric=train_config.metric) losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[train_config.loss]( dists, pids, train_config.margin, batch_precision_at_k=train_config.batch_k-1) # Count the number of active entries, and compute the total batch loss. num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32)) loss_mean = tf.reduce_mean(losses) scale_rate_0 = 1E-7 scale_rate_1 = 6E-8 total_loss = loss_mean + keypoints_loss*scale_rate_0 + heatmap_loss*scale_rate_1 # total_loss = loss_mean + keypoints_loss * scale_rate_0 # total_loss = loss_mean # Some logging for tensorboard. tf.summary.histogram('loss_distribution', losses) tf.summary.scalar('loss', loss_mean) ############################################################################################ # tf.summary.histogram('hm_loss_distribution', heatmap_loss) tf.summary.scalar('keypt_loss_0', keypoints_loss_0) tf.summary.scalar('keypt_loss_1', keypoints_loss_1) tf.summary.scalar('keypt_loss_2', keypoints_loss_2) tf.summary.scalar('keypt_loss_3', keypoints_loss_3) tf.summary.scalar('keypt_loss_all', keypoints_loss) ############################################################################################ tf.summary.scalar('total_loss', total_loss) tf.summary.scalar('batch_top1', train_top1) tf.summary.scalar('batch_prec_at_{}'.format(train_config.batch_k-1), prec_at_k) tf.summary.scalar('active_count', num_active) tf.summary.histogram('embedding_dists', dists) tf.summary.histogram('embedding_pos_dists', pos_dists) tf.summary.histogram('embedding_neg_dists', neg_dists) tf.summary.histogram('embedding_lengths', tf.norm(endpoints['emb_raw'], axis=1)) # Create the mem-mapped arrays in which we'll log all training detail in # addition to tensorboard, because tensorboard is annoying for detailed # inspection and actually discards data in histogram summaries. if train_config.detailed_logs: log_embs = lb.create_or_resize_dat( os.path.join(train_config.experiment_root, 'embeddings'), dtype=np.float32, shape=(train_config.train_iterations, batch_size, train_config.embedding_dim)) log_loss = lb.create_or_resize_dat( os.path.join(train_config.experiment_root, 'losses'), dtype=np.float32, shape=(train_config.train_iterations, batch_size)) log_fids = lb.create_or_resize_dat( os.path.join(train_config.experiment_root, 'fids'), dtype='S' + str(max_fid_len), shape=(train_config.train_iterations, batch_size)) # These are collected here before we add the optimizer, because depending # on the optimizer, it might add extra slots, which are also global # variables, with the exact same prefix. 
model_variables = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, body_prefix) # Define the optimizer and the learning-rate schedule. # Unfortunately, we get NaNs if we don't handle no-decay separately. global_step = tf.Variable(0, name='global_step', trainable=False) if 0 <= train_config.decay_start_iteration < train_config.train_iterations: learning_rate = tf.train.exponential_decay( train_config.learning_rate, tf.maximum(0, global_step - train_config.decay_start_iteration), train_config.train_iterations - train_config.decay_start_iteration, 0.001) else: learning_rate = train_config.learning_rate tf.summary.scalar('learning_rate', learning_rate) optimizer = tf.train.AdamOptimizer(learning_rate) # Feel free to try others! # optimizer = tf.train.AdadeltaOptimizer(learning_rate) # Update_ops are used to update batchnorm stats. with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): # train_op = optimizer.minimize(loss_mean, global_step=global_step) train_op = optimizer.minimize(total_loss, global_step=global_step) # Define a saver for the complete model. checkpoint_saver = tf.train.Saver(max_to_keep=0) gpu_options = tf.GPUOptions(allow_growth=True) with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: if train_config.resume: # In case we're resuming, simply load the full checkpoint to init. last_checkpoint = tf.train.latest_checkpoint(train_config.experiment_root) log.info('Restoring from checkpoint: {}'.format(last_checkpoint)) checkpoint_saver.restore(sess, last_checkpoint) else: # But if we're starting from scratch, we may need to load some # variables from the pre-trained weights, and random init others. sess.run(tf.global_variables_initializer()) if train_config.initial_checkpoint is not None: saver = tf.train.Saver(model_variables, write_version=tf.train.SaverDef.V1) saver.restore(sess, train_config.initial_checkpoint) # name_11 = 'resnet_v1_50/block4' # name_12 = 'resnet_v1_50/block3' # name_13 = 'resnet_v1_50/block2' # name_21 = 'Resnet_block_2_3_4/block4' # name_22 = 'Resnet_block_2_3_4/block3' # name_23 = 'Resnet_block_2_3_4/block2' # for var in tf.trainable_variables(): # var_name = var.name # if re.match(name_11, var_name): # dst_name = var_name.replace(name_11, name_21) # tensor = tf.get_default_graph().get_tensor_by_name(var_name) # dst_tensor = tf.get_default_graph().get_tensor_by_name(dst_name) # tf.assign(dst_tensor, tensor) # if re.match(name_12, var_name): # dst_name = var_name.replace(name_12, name_22) # tensor = tf.get_default_graph().get_tensor_by_name(var_name) # dst_tensor = tf.get_default_graph().get_tensor_by_name(dst_name) # tf.assign(dst_tensor, tensor) # if re.match(name_13, var_name): # dst_name = var_name.replace(name_13, name_23) # tensor = tf.get_default_graph().get_tensor_by_name(var_name) # dst_tensor = tf.get_default_graph().get_tensor_by_name(dst_name) # tf.assign(dst_tensor, tensor) # In any case, we also store this initialization as a checkpoint, # such that we could run exactly reproducible experiments. checkpoint_saver.save(sess, os.path.join( train_config.experiment_root, 'checkpoint'), global_step=0) merged_summary = tf.summary.merge_all() summary_writer = tf.summary.FileWriter(train_config.experiment_root, sess.graph) start_step = sess.run(global_step) log.info('Starting training from iteration {}.'.format(start_step)) # Finally, here comes the main-loop. This `Uninterrupt` is a handy # utility such that an iteration still finishes on Ctrl+C and we can # stop the training cleanly. 
with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u: for i in range(start_step, train_config.train_iterations): # Compute gradients, update weights, store logs! start_time = time.time() _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \ sess.run([train_op, merged_summary, global_step, prec_at_k, endpoints['emb'], losses, fids]) elapsed_time = time.time() - start_time # Compute the iteration speed and add it to the summary. # We did observe some weird spikes that we couldn't track down. summary2 = tf.Summary() summary2.value.add(tag='secs_per_iter', simple_value=elapsed_time) summary_writer.add_summary(summary2, step) summary_writer.add_summary(summary, step) if train_config.detailed_logs: log_embs[i], log_loss[i], log_fids[i] = b_embs, b_loss, b_fids # Do a huge print out of the current progress. seconds_todo = (train_config.train_iterations - step) * elapsed_time log.info('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ' 'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format( step, float(np.min(b_loss)), float(np.mean(b_loss)), float(np.max(b_loss)), train_config.batch_k-1, float(b_prec_at_k), timedelta(seconds=int(seconds_todo)), elapsed_time)) sys.stdout.flush() sys.stderr.flush() # Save a checkpoint of training every so often. if (train_config.checkpoint_frequency > 0 and step % train_config.checkpoint_frequency == 0): checkpoint_saver.save(sess, os.path.join( train_config.experiment_root, 'checkpoint'), global_step=step) # Stop the main-loop at the end of the step, if requested. if u.interrupted: log.info("Interrupted on request!") break # Store one final checkpoint. This might be redundant, but it is crucial # in case intermediate storing was disabled and it saves a checkpoint # when the process was interrupted. checkpoint_saver.save(sess, os.path.join( train_config.experiment_root, 'checkpoint'), global_step=step)
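# All of the TF pipelines above map `sample_k_fids_for_pid` over the PID stream, but none of the
# excerpts define it. The NumPy sketch below shows the behaviour they rely on (pick `batch_k` file
# ids for one PID, repeating ids when that PID has fewer than `batch_k` images); it is an assumption
# about the helper for illustration, not its actual TF implementation.
import numpy as np

def sample_k_fids_for_pid(pid, all_fids, all_pids, batch_k):
    # All file ids belonging to this PID.
    candidates = all_fids[all_pids == pid]
    # Only sample with replacement when the PID has fewer than batch_k images.
    replace = len(candidates) < batch_k
    chosen = np.random.choice(candidates, size=batch_k, replace=replace)
    return chosen, np.full(batch_k, pid)

all_pids = np.array(['0001', '0001', '0001', '0002', '0002'])
all_fids = np.array(['0001/a.jpg', '0001/b.jpg', '0001/c.jpg', '0002/a.jpg', '0002/b.jpg'])
print(sample_k_fids_for_pid('0002', all_fids, all_pids, batch_k=4))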