def __init__(self):
    """Build a small conv-net graph trained with ``cos_loss``.

    Inputs are flattened 784-value float images, reshaped to 28x28x1
    (presumably MNIST-style digits — confirm with the caller), and int64
    class labels. Three conv/conv/pool stages feed two fully-connected
    layers and a PReLU that produce a 2-D feature (``self.feature``).

    Attributes set:
        x_input, y_input: input placeholders.
        x_image: ``x_input`` reshaped to ``[-1, 28, 28, 1]``.
        feature: the 2-D PReLU output fed to the loss.
        xent: first value returned by ``cos_loss`` (used as the loss).
        y_pred: argmax over ``x_feat_norm @ w_feat_norm`` from ``cos_loss``'s
            third return value (presumably cosine scores per class — verify).
        accuracy: mean of ``y_pred == y_input``.
    """
    self.x_input = tf.placeholder(tf.float32, shape=[None, 784])
    self.y_input = tf.placeholder(tf.int64, shape=[None])
    self.x_image = tf.reshape(self.x_input, [-1, 28, 28, 1])

    with slim.arg_scope([slim.conv2d], kernel_size=3, padding='SAME'):
        with slim.arg_scope([slim.max_pool2d], kernel_size=2):
            x = slim.conv2d(self.x_image, num_outputs=32, scope='conv1_1')
            x = slim.conv2d(x, num_outputs=32, scope='conv1_2')
            x = slim.max_pool2d(x, scope='pool1')

            x = slim.conv2d(x, num_outputs=64, scope='conv2_1')
            x = slim.conv2d(x, num_outputs=64, scope='conv2_2')
            x = slim.max_pool2d(x, scope='pool2')

            x = slim.conv2d(x, num_outputs=128, scope='conv3_1')
            x = slim.conv2d(x, num_outputs=128, scope='conv3_2')
            x = slim.max_pool2d(x, scope='pool3')

            x = slim.flatten(x, scope='flatten')
            # NOTE(review): fc1 and fc2 both use activation_fn=None, so they
            # compose into a single linear map; the only non-linearity before
            # the loss is the PReLU below.
            x = slim.fully_connected(x, num_outputs=32, activation_fn=None, scope='fc1')
            x = slim.fully_connected(x, num_outputs=2, activation_fn=None, scope='fc2')
            self.feature = x = tflearn.prelu(x)

    # The second return value (logits) is not needed here.
    self.xent, _, tmp = cos_loss(x, self.y_input, 10, alpha=0.25)

    # tf.argmax is the supported spelling of the deprecated tf.arg_max alias;
    # both default to an int64 result, matching y_input's dtype.
    self.y_pred = tf.argmax(
        tf.matmul(tmp['x_feat_norm'], tmp['w_feat_norm']), 1)
    self.accuracy = tf.reduce_mean(
        tf.cast(tf.equal(self.y_pred, self.y_input), tf.float32))
def main(args):
    """Train a multi-tower embedding network with a softmax or cosface loss.

    Creates timestamped log and model directories, records the arguments and
    git revision info, and builds a ``tf_data`` input pipeline. The pipeline
    walks a shuffled list of every training image. Each batch is split across
    ``args.num_gpus`` towers. Training runs until
    ``global_step // args.epoch_size`` reaches ``args.max_nrof_epochs``, with a
    checkpoint saved after every epoch.

    NOTE(review): a later ``def main(args)`` in this file rebinds the name
    ``main``, so this definition is shadowed once the module is imported.

    Args:
        args: parsed command-line namespace (directories, batch sizes,
            optimizer, network, loss type, learning-rate settings, ...).

    Returns:
        Path of the model directory created for this run.

    Raises:
        Exception: for an unsupported optimizer, network or loss type.
    """
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)

    # Write arguments to a text file
    utils.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    utils.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    np.random.seed(seed=args.seed)
    train_set = utils.get_dataset(args.data_dir)
    nrof_classes = len(train_set)
    print('nrof_classes: ', nrof_classes)
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    image_list = np.array(image_list)
    label_list = np.array(label_list, dtype=np.int32)

    dataset_size = len(image_list)
    single_batch_size = args.people_per_batch * args.images_per_person
    # np.random.shuffle permutes in place, so this must be a mutable list:
    # shuffling a Python 3 ``range`` object raises TypeError.
    indices = list(range(dataset_size))
    np.random.shuffle(indices)

    def _sample_people_softmax(x):
        """Return the next chunk of (image paths, int32 labels).

        ``x`` (the element of the driving range dataset) is ignored. The read
        position lives in the module-level ``softmax_ind``. Once it passes the
        end of the data, ``indices`` is reshuffled and reading restarts at 0.
        The last chunk of a pass may be shorter than ``single_batch_size``.
        """
        global softmax_ind
        if softmax_ind >= dataset_size:
            np.random.shuffle(indices)
            softmax_ind = 0
        true_num_batch = min(single_batch_size, dataset_size - softmax_ind)
        batch_indices = indices[softmax_ind:softmax_ind + true_num_batch]

        sample_paths = image_list[batch_indices]
        sample_labels = label_list[batch_indices]

        softmax_ind += true_num_batch

        return (np.array(sample_paths), np.array(sample_labels, dtype=np.int32))

    def _parse_function(filename, label):
        """Decode one image file, crop or resize it, optionally flip it, and normalize it."""
        file_contents = tf.read_file(filename)
        image = tf.image.decode_image(file_contents, channels=3)
        print(image.shape)

        if args.random_crop:
            print('use random crop')
            image = tf.random_crop(image, [args.image_size, args.image_size, 3])
        else:
            print('Not use random crop')
            # decode_image has no static shape; give it rank 3 so resize works.
            image.set_shape((None, None, 3))
            image = tf.image.resize_images(image, size=(args.image_height, args.image_width))
        if args.random_flip:
            image = tf.image.random_flip_left_right(image)

        #pylint: disable=no-member
        image.set_shape((args.image_height, args.image_width, 3))
        if debug:
            image = tf.cast(image, tf.float32)
        else:
            image = tf.image.per_image_standardization(image)
        return image, label

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' % os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Placeholder for the learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate')
        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')

        # The image is generated by sequence: each range element triggers one
        # _sample_people_softmax call; the range only bounds the call count.
        with tf.device("/cpu:0"):
            softmax_dataset = tf_data.Dataset.range(args.epoch_size * args.max_nrof_epochs * 100)
            softmax_dataset = softmax_dataset.map(lambda x: tf.py_func(
                _sample_people_softmax, [x], [tf.string, tf.int32]))
            softmax_dataset = softmax_dataset.flat_map(_from_tensor_slices)
            # Same tf_data API as the other pipeline in this file:
            # num_parallel_calls + prefetch replace num_threads/output_buffer_size.
            softmax_dataset = softmax_dataset.map(_parse_function, num_parallel_calls=8)
            softmax_dataset = softmax_dataset.prefetch(2000)
            softmax_dataset = softmax_dataset.batch(args.num_gpus * single_batch_size)
            softmax_iterator = softmax_dataset.make_initializable_iterator()
            softmax_next_element = softmax_iterator.get_next()
            softmax_next_element[0].set_shape(
                (args.num_gpus * single_batch_size, args.image_height, args.image_width, 3))
            softmax_next_element[1].set_shape((args.num_gpus * single_batch_size,))
            batch_image_split = tf.split(softmax_next_element[0], args.num_gpus)
            batch_label_split = tf.split(softmax_next_element[1], args.num_gpus)

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder, global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor, staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        print('Using optimizer: {}'.format(args.optimizer))
        if args.optimizer == 'ADAGRAD':
            opt = tf.train.AdagradOptimizer(learning_rate)
        elif args.optimizer == 'MOM':
            opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
        elif args.optimizer == 'ADAM':
            opt = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999, epsilon=0.1)
        else:
            raise Exception("Not supported optimizer: {}".format(args.optimizer))

        tower_losses = []
        tower_reg = []
        for i in range(args.num_gpus):
            with tf.device("/gpu:" + str(i)):
                with tf.name_scope("tower_" + str(i)):
                    with slim.arg_scope([slim.model_variable, slim.variable], device="/cpu:0"):
                        with tf.variable_scope(tf.get_variable_scope()):
                            # Tower 0 creates the variables; later towers reuse them.
                            reuse = i != 0
                            if args.network == 'sphere_network':
                                prelogits = network.infer(batch_image_split[i])
                            elif args.network == 'resface':
                                prelogits, _ = resface.inference(
                                    batch_image_split[i], 1.0,
                                    weight_decay=args.weight_decay, reuse=reuse)
                            elif args.network == 'inception_net':
                                prelogits, _ = inception_net.inference(
                                    batch_image_split[i], 1, phase_train=True,
                                    bottleneck_layer_size=args.embedding_size,
                                    weight_decay=args.weight_decay, reuse=reuse)
                            elif args.network == 'resnet_v2':
                                with slim.arg_scope(resnet_v2.resnet_arg_scope(args.weight_decay)):
                                    prelogits, _ = resnet_v2.resnet_v2_50(
                                        batch_image_split[i], is_training=True,
                                        output_stride=16,
                                        num_classes=args.embedding_size, reuse=reuse)
                                    prelogits = tf.squeeze(prelogits, axis=[1, 2])
                            else:
                                raise Exception("Not supported network: {}".format(args.network))

                            if args.fc_bn:
                                prelogits = slim.batch_norm(
                                    prelogits, is_training=True, decay=0.997,
                                    epsilon=1e-5, scale=True,
                                    updates_collections=tf.GraphKeys.UPDATE_OPS,
                                    reuse=reuse, scope='softmax_bn')

                            if args.loss_type == 'softmax':
                                data_loss = utils.softmax_loss(
                                    prelogits, batch_label_split[i], len(train_set),
                                    args.weight_decay, reuse)
                            elif args.loss_type == 'cosface':
                                label_reshape = tf.reshape(batch_label_split[i], [single_batch_size])
                                label_reshape = tf.cast(label_reshape, tf.int64)
                                data_loss = utils.cos_loss(
                                    prelogits, label_reshape, len(train_set), reuse,
                                    alpha=args.alpha, scale=args.scale)
                            else:
                                raise Exception("Not supported loss type: {}".format(args.loss_type))

                            # Sum the regularization terms into one scalar. Adding
                            # the raw collection list to a scalar broadcasts it to a
                            # vector, so the tower loss would silently use the mean
                            # of the regularizers instead of their sum.
                            reg_loss = tf.add_n(
                                tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
                            tower_losses.append(data_loss + reg_loss)
                            tower_reg.append(reg_loss)
                            tf.get_variable_scope().reuse_variables()

        total_loss = tf.reduce_mean(tower_losses)
        total_reg = tf.reduce_mean(tower_reg)
        losses = {}
        losses['total_loss'] = total_loss
        losses['total_reg'] = total_reg

        grads = opt.compute_gradients(total_loss, tf.trainable_variables(),
                                      colocate_gradients_with_ops=True)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.group(apply_gradient_op)

        save_vars = [
            var for var in tf.global_variables()
            if 'Adagrad' not in var.name and 'global_step' not in var.name
        ]
        saver = tf.train.Saver(save_vars, max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))

        # Initialize variables
        sess.run(tf.global_variables_initializer(), feed_dict={phase_train_placeholder: True})
        sess.run(tf.local_variables_initializer(), feed_dict={phase_train_placeholder: True})
        sess.run(softmax_iterator.initializer)

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                saver.restore(sess, os.path.expanduser(args.pretrained_model))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size
                if debug:
                    # NOTE(review): most arguments below (image_batch_gather,
                    # enqueue_op, gpus, ...) are not defined in this function, so
                    # this path raises NameError unless they exist at module level.
                    debug_train(args, sess, train_set, epoch, image_batch_gather,
                                enqueue_op, batch_size_placeholder, image_batch_split,
                                image_paths_split, num_per_class_split,
                                image_paths_placeholder, image_paths_split_placeholder,
                                labels_placeholder, labels_batch,
                                num_per_class_placeholder,
                                num_per_class_split_placeholder, len(gpus))
                # Train for one epoch
                train(args, sess, epoch, learning_rate_placeholder,
                      phase_train_placeholder, global_step, losses, train_op,
                      summary_op, summary_writer, args.learning_rate_schedule_file)

                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer,
                                             model_dir, subdir, step)

    return model_dir
def main_train(args):
    """Train a single-device network on ``.npy`` volumes with a configurable loss.

    Creates timestamped log and model directories and records the arguments
    and git revision info. Each image path from ``utils.dataset_from_list`` is
    loaded with ``np.load``. Samples are stacked into a float32 batch with a
    trailing channel axis (``np.expand_dims(..., axis=4)``), so each file is
    expected to hold a 3-D array. The batch is fed to ``network.infer``.

    Supported ``args.loss_type`` values are 'softmax', 'lmcl', 'center' and
    'lmccl'. Each is combined with ``args.weight_decay`` times the summed
    regularization losses. Training runs until
    ``global_step // args.epoch_size`` reaches ``args.max_nrof_epochs``, with a
    checkpoint saved after every epoch.

    Args:
        args: parsed command-line namespace.

    Returns:
        Path of the model directory created for this run.

    Raises:
        Exception: for an unsupported optimizer, network or loss type.
    """
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)

    # Write arguments to a text file
    utils.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    utils.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    np.random.seed(seed=args.seed)
    # Presumably a list with one ImageClass-like object per class (``name``,
    # ``image_paths``) — verify against utils.dataset_from_list.
    train_set = utils.dataset_from_list(args.train_data_dir, args.train_list_dir)

    nrof_classes = len(train_set)
    print('nrof_classes: ', nrof_classes)
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    print('total images: ', len(image_list))
    # label is in the form scalar.
    image_list = np.array(image_list)
    label_list = np.array(label_list, dtype=np.int32)

    dataset_size = len(image_list)
    single_batch_size = args.class_per_batch * args.images_per_class
    indices = list(range(dataset_size))
    np.random.shuffle(indices)

    def _sample_people_softmax(x):
        """Load the next chunk of volumes and their int32 labels.

        ``x`` is ignored. The read position lives in the module-level
        ``softmax_ind``. Once it passes the end of the data, ``indices`` is
        reshuffled and reading restarts at 0. The last chunk of a pass may be
        shorter than ``single_batch_size``.
        """
        global softmax_ind
        if softmax_ind >= dataset_size:
            np.random.shuffle(indices)
            softmax_ind = 0
        true_num_batch = min(single_batch_size, dataset_size - softmax_ind)
        batch_indices = indices[softmax_ind:softmax_ind + true_num_batch]

        sample_paths = image_list[batch_indices]
        sample_images = [np.load(str(item)) for item in sample_paths]
        sample_labels = label_list[batch_indices]

        softmax_ind += true_num_batch

        return (np.expand_dims(np.array(sample_images, dtype=np.float32), axis=4),
                np.array(sample_labels, dtype=np.int32))

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' % os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Placeholder for the learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate')
        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')

        with tf.device("/cpu:0"):
            # _sample_people_softmax ignores its argument and wraps around the
            # data itself, so the range only bounds how many chunks can be
            # drawn. Chunks at the end of each pass are short, so exactly
            # epoch_size * max_nrof_epochs chunks yield fewer full batches than
            # training consumes. Keep the same 100x headroom as the other input
            # pipelines in this file.
            softmax_dataset = tf.data.Dataset.range(args.epoch_size * args.max_nrof_epochs * 100)
            softmax_dataset = softmax_dataset.map(lambda x: tf.py_func(
                _sample_people_softmax, [x], [tf.float32, tf.int32]))
            softmax_dataset = softmax_dataset.flat_map(_from_tensor_slices)
            softmax_dataset = softmax_dataset.batch(single_batch_size)
            softmax_iterator = softmax_dataset.make_initializable_iterator()
            softmax_next_element = softmax_iterator.get_next()
            # NOTE(review): the third spatial dimension reuses image_width;
            # confirm the stored volumes really have that extent.
            softmax_next_element[0].set_shape(
                (single_batch_size, args.image_height, args.image_width,
                 args.image_width, 1))
            softmax_next_element[1].set_shape(single_batch_size)
            batch_image_split = softmax_next_element[0]
            batch_label_split = softmax_next_element[1]

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder, global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor, staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        print('Using optimizer: {}'.format(args.optimizer))
        if args.optimizer == 'ADAGRAD':
            opt = tf.train.AdagradOptimizer(learning_rate)
        elif args.optimizer == 'SGD':
            opt = tf.train.GradientDescentOptimizer(learning_rate)
        elif args.optimizer == 'MOM':
            opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
        elif args.optimizer == 'ADAM':
            opt = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999, epsilon=0.1)
        else:
            raise Exception("Not supported optimizer: {}".format(args.optimizer))

        losses = {}
        with slim.arg_scope([slim.model_variable, slim.variable], device="/cpu:0"):
            with tf.variable_scope(tf.get_variable_scope()):
                reuse = False
                if args.network == 'sphere_network':
                    prelogits = network.infer(batch_image_split, args.embedding_size)
                else:
                    raise Exception("Not supported network: {}".format(args.network))
                if args.fc_bn:
                    prelogits = slim.batch_norm(
                        prelogits, is_training=True, decay=0.997, epsilon=1e-5,
                        scale=True, updates_collections=tf.GraphKeys.UPDATE_OPS,
                        reuse=reuse, scope='softmax_bn')

                # Labels as a flat int64 vector, shared by the lmcl, center and
                # lmccl losses. The 'center' branch previously used this name
                # without defining it (UnboundLocalError).
                label_reshape = tf.cast(
                    tf.reshape(batch_label_split, [single_batch_size]), tf.int64)

                def _weighted_reg_loss():
                    """Return weight_decay times the sum of the current REGULARIZATION_LOSSES."""
                    # Called after each branch builds its loss ops so that any
                    # regularizers those ops register are included.
                    return args.weight_decay * tf.add_n(
                        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

                if args.loss_type == 'softmax':
                    cross_entropy_mean = utils.softmax_loss(
                        prelogits, batch_label_split, len(train_set), 1.0, reuse)
                    reg_loss = _weighted_reg_loss()
                    loss = cross_entropy_mean + reg_loss
                    print('************************' + ' Computing the softmax loss')
                    losses['total_loss'] = cross_entropy_mean
                    losses['total_reg'] = reg_loss
                elif args.loss_type == 'lmcl':
                    coco_loss = utils.cos_loss(prelogits, label_reshape, len(train_set),
                                               reuse, alpha=args.alpha, scale=args.scale)
                    reg_loss = _weighted_reg_loss()
                    loss = coco_loss + reg_loss
                    print('************************' + ' Computing the lmcl loss')
                    losses['total_loss'] = coco_loss
                    losses['total_reg'] = reg_loss
                elif args.loss_type == 'center':
                    center_loss, centers, centers_update_op = get_center_loss(
                        prelogits, label_reshape, args.center_loss_alfa,
                        args.num_class_train)
                    reg_loss = _weighted_reg_loss()
                    loss = center_loss + reg_loss
                    print('************************' + ' Computing the center loss')
                    losses['total_loss'] = center_loss
                    losses['total_reg'] = reg_loss
                elif args.loss_type == 'lmccl':
                    cross_entropy_mean = utils.softmax_loss(
                        prelogits, batch_label_split, len(train_set), 1.0, reuse)
                    coco_loss = utils.cos_loss(prelogits, label_reshape, len(train_set),
                                               reuse, alpha=args.alpha, scale=args.scale)
                    center_loss, centers, centers_update_op = get_center_loss(
                        prelogits, label_reshape, args.center_loss_alfa,
                        args.num_class_train)
                    reg_loss = _weighted_reg_loss()
                    loss = (coco_loss + reg_loss
                            + args.center_weighting * center_loss + cross_entropy_mean)
                    losses['total_loss_center'] = args.center_weighting * center_loss
                    losses['total_loss_lmcl'] = coco_loss
                    losses['total_loss_softmax'] = cross_entropy_mean
                    losses['total_reg'] = reg_loss
                else:
                    raise Exception("Not supported loss type: {}".format(args.loss_type))

        grads = opt.compute_gradients(loss, tf.trainable_variables(),
                                      colocate_gradients_with_ops=True)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Ops that must run with every training step: the UPDATE_OPS collection
        # (where slim.batch_norm above registers its updates) plus, for the
        # center-based losses, the op that updates the centers.
        uses_centers = args.loss_type in ('lmccl', 'center')
        train_deps = list(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
        if uses_centers:
            train_deps.append(centers_update_op)
        with tf.control_dependencies(train_deps):
            train_op = tf.group(apply_gradient_op)

        save_vars = [
            var for var in tf.global_variables()
            if 'Adagrad' not in var.name and 'global_step' not in var.name
        ]
        saver = tf.train.Saver(save_vars, max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))

        # Initialize variables
        sess.run(tf.global_variables_initializer(), feed_dict={phase_train_placeholder: True})
        sess.run(tf.local_variables_initializer(), feed_dict={phase_train_placeholder: True})
        sess.run(softmax_iterator.initializer)

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                saver.restore(sess, os.path.expanduser(args.pretrained_model))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size
                if debug:
                    # NOTE(review): most arguments below (image_batch_gather,
                    # enqueue_op, gpus, ...) are not defined in this function, so
                    # this path raises NameError unless they exist at module level.
                    debug_train(args, sess, train_set, epoch, image_batch_gather,
                                enqueue_op, batch_size_placeholder, image_batch_split,
                                image_paths_split, num_per_class_split,
                                image_paths_placeholder, image_paths_split_placeholder,
                                labels_placeholder, labels_batch,
                                num_per_class_placeholder,
                                num_per_class_split_placeholder, len(gpus))
                # Train for one epoch
                if uses_centers:
                    train_contain_center(args, sess, epoch, learning_rate_placeholder,
                                         phase_train_placeholder, global_step, losses,
                                         train_op, summary_op, summary_writer, '',
                                         centers_update_op)
                else:
                    train(args, sess, epoch, learning_rate_placeholder,
                          phase_train_placeholder, global_step, losses, train_op,
                          summary_op, summary_writer, '')

                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer,
                                             model_dir, subdir, step)

    return model_dir
def main(args):
    """Train a multi-tower embedding network with triplet loss and in-graph mining.

    Each batch holds ``args.people_per_batch`` identities times
    ``args.images_per_person`` images per GPU; for 'simi_online' mining,
    ``args.scale`` times as many identities are sampled. Towers produce
    L2-normalized embeddings. These are concatenated and the pairwise distance
    matrix is computed. A ``tf.py_func`` selector chosen by ``args.strategy``
    returns triplet indices. For ``args.mine_method == 'online'``, those indices
    reorder the embeddings before the triplet loss is computed. An optional
    cosface term (``args.with_softmax``) is added with weight
    ``args.softmax_loss_weight``. With ``args.pretrain_softmax``, the
    'centers_var' variables are first trained alone on that term.

    Args:
        args: parsed command-line namespace.

    Returns:
        Path of the model directory created for this run.

    Raises:
        ValueError: for an unsupported dataset, network, strategy or mining
            method, or ``pretrain_softmax`` without ``with_softmax``.
    """
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)

    np.random.seed(seed=args.seed)
    print('load data...')
    if args.dataset == 'webface':
        train_set = utils.get_dataset(args.data_dir)
    elif args.dataset == 'mega':
        train_set = utils.dataset_from_cache(args.data_dir)
    else:
        raise ValueError('Not supported dataset {}'.format(args.dataset))
    print('Loaded dataset: {} persons'.format(len(train_set)))
    nrof_classes = len(train_set)
    class_indices = list(np.arange(nrof_classes))

    def _sample_people(x):
        '''We sample people based on tf.data, where we can use transform and prefetch.

        ``x`` is ignored; each call draws a fresh set of identities.
        '''
        scale = 1 if args.mine_method != 'simi_online' else args.scale
        image_paths, labels = sample_people(
            train_set, class_indices,
            args.people_per_batch * args.num_gpus * scale, args.images_per_person)
        return (np.array(image_paths), np.array(labels, dtype=np.int32))

    def _parse_function(filename, label):
        """Decode, crop/resize, optionally flip, and scale one image to about [-1, 1]."""
        file_contents = tf.read_file(filename)
        image = tf.image.decode_image(file_contents, channels=3)
        if args.random_crop:
            print('use random crop')
            image = tf.random_crop(image, [args.image_size, args.image_size, 3])
        else:
            print('Not use random crop')
            # decode_image has no static shape; give it rank 3 so resize works.
            image.set_shape((None, None, 3))
            image = tf.image.resize_images(image, size=(args.image_size, args.image_size))
        if args.random_flip:
            image = tf.image.random_flip_left_right(image)

        #pylint: disable=no-member
        image.set_shape((args.image_size, args.image_size, 3))
        print('img shape', image.shape)
        image = tf.cast(image, tf.float32)
        image = tf.subtract(image, 127.5)
        image = tf.div(image, 128.)
        return image, label

    gpus = range(args.num_gpus)

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' % os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Placeholder for the learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate')
        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')

        # The image is generated by sequence
        single_batch_size = args.people_per_batch * args.images_per_person
        total_batch_size = args.num_gpus * single_batch_size
        with tf.device("/cpu:0"):
            dataset = tf_data.Dataset.range(args.epoch_size * args.max_nrof_epochs * 100)
            # sample people based map
            dataset = dataset.map(lambda x: tf.py_func(_sample_people, [x], [tf.string, tf.int32]))
            dataset = dataset.flat_map(_from_tensor_slices)
            dataset = dataset.map(_parse_function, num_parallel_calls=8)
            dataset = dataset.batch(total_batch_size)
            iterator = dataset.make_initializable_iterator()
            next_element = iterator.get_next()
            batch_image_split = tf.split(next_element[0], args.num_gpus)
            batch_label = next_element[1]

        global trip_thresh
        trip_thresh = args.num_gpus * args.people_per_batch * args.images_per_person * 10

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder, global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor, staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        opt = utils.get_opt(args.optimizer, learning_rate)

        tower_embeddings = []
        for i, gpu in enumerate(gpus):
            with tf.device("/gpu:" + str(gpu)):
                with tf.name_scope("tower_" + str(gpu)):
                    with slim.arg_scope([slim.model_variable, slim.variable], device="/cpu:0"):
                        # Build the inference graph
                        with tf.variable_scope(tf.get_variable_scope()):
                            # Tower 0 creates the variables; later towers reuse them.
                            reuse = i != 0
                            if args.network == 'resnet_v2':
                                with slim.arg_scope(resnet_v2.resnet_arg_scope(args.weight_decay)):
                                    prelogits, end_points = resnet_v2.resnet_v2_50(
                                        batch_image_split[i], is_training=True,
                                        output_stride=16,
                                        num_classes=args.embedding_size, reuse=reuse)
                                    prelogits = tf.squeeze(prelogits, [1, 2],
                                                           name='SpatialSqueeze')
                            elif args.network == 'resface':
                                prelogits, end_points = resface.inference(
                                    batch_image_split[i], 1.0,
                                    bottleneck_layer_size=args.embedding_size,
                                    weight_decay=args.weight_decay, reuse=reuse)
                                print('res face prelogits', prelogits)
                            elif args.network == 'mobilenet':
                                prelogits, net_points = mobilenet.inference(
                                    batch_image_split[i],
                                    bottleneck_layer_size=args.embedding_size,
                                    phase_train=True,
                                    weight_decay=args.weight_decay, reuse=reuse)
                            else:
                                raise ValueError('Not supported network {}'.format(args.network))

                            # Start from the raw prelogits so ``embeddings`` is always
                            # this tower's tensor. Without fc_bn it was previously
                            # unbound (tower 0) or left over from the prior tower.
                            embeddings = prelogits
                            if args.fc_bn:
                                print('use fc bn')
                                embeddings = slim.batch_norm(
                                    prelogits, is_training=True, decay=0.997,
                                    epsilon=1e-5, scale=True,
                                    updates_collections=tf.GraphKeys.UPDATE_OPS,
                                    reuse=reuse, scope='softmax_bn')
                            embeddings = tf.nn.l2_normalize(embeddings, 1, 1e-10,
                                                            name='embeddings')
                            tf.get_variable_scope().reuse_variables()
                            tower_embeddings.append(embeddings)

        embeddings_gather = tf.concat(tower_embeddings, axis=0, name='embeddings_concat')
        if args.with_softmax:
            coco_loss = utils.cos_loss(embeddings_gather, batch_label, len(train_set))

        # select triplet pair by tf op
        with tf.name_scope('triplet_part'):
            distances = utils._pairwise_distances(embeddings_gather, squared=True)
            print('triplet strategy', args.strategy)
            if args.strategy == 'min_and_min':
                pair = tf.py_func(select_triplets_min_min,
                                  [distances, batch_label, args.alpha], tf.int64)
            elif args.strategy == 'min_and_max':
                pair = tf.py_func(select_triplets_min_max,
                                  [distances, batch_label, args.alpha], tf.int64)
            elif args.strategy == 'hardest':
                pair = tf.py_func(select_triplets_hardest,
                                  [distances, batch_label, args.alpha], tf.int64)
            elif args.strategy == 'batch_random':
                pair = tf.py_func(select_triplets_batch_random,
                                  [distances, batch_label, args.alpha], tf.int64)
            elif args.strategy == 'batch_all':
                pair = tf.py_func(select_triplets_batch_all,
                                  [distances, batch_label, args.alpha], tf.int64)
            else:
                raise ValueError('Not supported strategy {}'.format(args.strategy))
            triplet_handle = {}
            triplet_handle['embeddings'] = embeddings_gather
            triplet_handle['labels'] = batch_label
            triplet_handle['pair'] = pair

        if args.mine_method == 'online':
            # Reorder embeddings by the selected triplet indices.
            pair_reshape = tf.reshape(pair, [-1])
            embeddings_gather = tf.gather(embeddings_gather, pair_reshape)
        # Consecutive rows are read as (anchor, positive, negative). This runs for
        # every mining method: train_simi_online below needs pos_d / neg_d.
        anchor, positive, negative = tf.unstack(
            tf.reshape(embeddings_gather, [-1, 3, args.embedding_size]), 3, 1)
        triplet_loss, pos_d, neg_d = utils.triplet_loss(anchor, positive, negative, args.alpha)

        # Calculate the total losses
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        triplet_loss = tf.add_n([triplet_loss])
        total_loss = triplet_loss + tf.add_n(regularization_losses)
        if args.with_softmax:
            total_loss = total_loss + args.softmax_loss_weight * coco_loss

        losses = {}
        losses['triplet_loss'] = triplet_loss
        losses['total_loss'] = total_loss

        update_vars = tf.trainable_variables()
        with tf.device("/gpu:" + str(gpus[0])):
            grads = opt.compute_gradients(total_loss, update_vars,
                                          colocate_gradients_with_ops=True)

        if args.pretrain_softmax:
            if not args.with_softmax:
                # coco_loss is only built when with_softmax is set.
                raise ValueError('pretrain_softmax requires with_softmax')
            softmax_vars = [var for var in update_vars if 'centers_var' in var.name]
            print('softmax vars', softmax_vars)
            softmax_grads = opt.compute_gradients(coco_loss, softmax_vars)
            softmax_update_op = opt.apply_gradients(softmax_grads)

        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        print('update ops', update_ops)
        with tf.control_dependencies(update_ops):
            train_op_dep = tf.group(apply_gradient_op)
        # NOTE(review): train_op_dep is created outside the cond branches. tf.cond
        # only gates ops created inside true_fn/false_fn, so this probably does
        # NOT skip the update when triplet_loss is NaN. Verify before relying on it.
        train_op = tf.cond(tf.is_nan(triplet_loss),
                           lambda: tf.no_op('no_train'), lambda: train_op_dep)

        save_vars = [
            var for var in tf.global_variables()
            if 'Adagrad' not in var.name and 'global_step' not in var.name
        ]
        # Checkpoint restore skips mining and softmax-center variables.
        restore_vars = [
            var for var in tf.global_variables()
            if 'Adagrad' not in var.name and 'global_step' not in var.name
            and 'pair_part' not in var.name and 'centers_var' not in var.name
        ]
        print('restore vars', restore_vars)
        saver = tf.train.Saver(save_vars, max_to_keep=3)
        restorer = tf.train.Saver(restore_vars, max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))

        # Initialize variables
        sess.run(tf.global_variables_initializer(), feed_dict={phase_train_placeholder: True})
        sess.run(tf.local_variables_initializer(), feed_dict={phase_train_placeholder: True})
        sess.run(iterator.initializer)

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                restorer.restore(sess, os.path.expanduser(args.pretrained_model))

            if args.pretrain_softmax:
                # NOTE(review): assumes roughly 20 images per identity — confirm.
                total_images = len(train_set) * 20
                lr_init = 0.1
                for epoch in range(args.softmax_epoch):
                    # Step the pre-training learning rate down at fixed epochs.
                    if epoch == 4:
                        lr_init = 0.05
                    if epoch == 7:
                        lr_init = 0.01
                    for i in range(total_images // total_batch_size):
                        coco_loss_err, _ = sess.run(
                            [coco_loss, softmax_update_op],
                            feed_dict={
                                phase_train_placeholder: False,
                                learning_rate_placeholder: lr_init
                            })
                        print('{}/{} {} coco loss err:{} lr:{}'.format(
                            i, total_images // total_batch_size, epoch,
                            coco_loss_err, lr_init))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size
                # Train for one epoch
                if args.mine_method == 'simi_online':
                    train_simi_online(args, sess, epoch, len(gpus), embeddings_gather,
                                      batch_label, next_element[0], batch_image_split,
                                      learning_rate_placeholder, learning_rate,
                                      phase_train_placeholder, global_step, pos_d,
                                      neg_d, triplet_handle, losses, train_op,
                                      summary_op, summary_writer,
                                      args.learning_rate_schedule_file)
                elif args.mine_method == 'online':
                    train_online(args, sess, epoch, learning_rate,
                                 phase_train_placeholder, global_step, losses,
                                 train_op, summary_op, summary_writer,
                                 args.learning_rate_schedule_file)
                else:
                    raise ValueError('Not supported mini method {}'.format(
                        args.mine_method))

                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer,
                                             model_dir, subdir, step)

    return model_dir