def main(args):
    if not os.path.exists(args.pretrained_model):
        print('invalid pretrained model path')
        return
    weights = np.load(args.pretrained_model)
    pairs = test_utils.read_pairs('/exports_data/czj/data/lfw/files/pairs.txt')
    imglist, labels = test_utils.get_paths(
        '/exports_data/czj/data/lfw/lfw_aligned/', pairs, '_face_.jpg')
    total_images = len(imglist)

    # ---- build graph ---- #
    # [notice: named 'images_pl' rather than 'input' to avoid shadowing the builtin]
    images_pl = tf.placeholder(tf.float32,
                               shape=[None, 160, 160, 3],
                               name='image_batch')
    prelogits, _ = inception_resnet_v1.inference(images_pl, 1, phase_train=False)
    embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10)

    # ---- extract ---- #
    gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                            log_device_placement=False,
                                            allow_soft_placement=True))
    with sess.as_default():
        beg_time = time.time()
        to_assign = [
            v.assign(weights[()][v.name][0])
            for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        ]
        sess.run(to_assign)
        print('restore parameters: %.2fsec' % (time.time() - beg_time))
        beg_time = time.time()
        images = load_data(imglist)
        print('load images: %.2fsec' % (time.time() - beg_time))
        beg_time = time.time()
        batch_size = 32
        beg = 0
        end = 0
        features = np.zeros((total_images, 128))
        while end < total_images:
            end = min(beg + batch_size, total_images)
            features[beg:end] = sess.run(embeddings,
                                         {images_pl: images[beg:end]})
            beg = end
        print('extract features: %.2fsec' % (time.time() - beg_time))
        tpr, fpr, acc, vr, vr_std, far = test_utils.evaluate(features,
                                                             labels,
                                                             num_folds=10)
        # display
        auc = metrics.auc(fpr, tpr)
        eer = brentq(lambda x: 1. - x - interpolate.interp1d(fpr, tpr)(x),
                     0., 1.)
        print('Acc: %1.3f+-%1.3f' % (np.mean(acc), np.std(acc)))
        print('VR@FAR=%2.5f: %2.5f+-%2.5f' % (far, vr, vr_std))
        print('AUC: %1.3f' % auc)
        print('EER: %1.3f' % eer)
    sess.close()
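
# A minimal entry point for the evaluation script above, assuming 'argparse'
# and 'sys' are imported at the top of the file. '--pretrained_model' is the
# only attribute main() reads; the default and help text are assumptions,
# not part of the original script.
def parse_arguments(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained_model', type=str, default='',
                        help='Path to the .npy weight dict to evaluate.')
    return parser.parse_args(argv)


if __name__ == '__main__':
    main(parse_arguments(sys.argv[1:]))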
def main(args):
    pairs = test_utils.read_pairs(args.lfw_pairs)
    model_list = test_utils.get_model_list(args.model_list)
    for t, model in enumerate(model_list):
        # get lfw pair filenames for this model's patch extension
        paths, labels = test_utils.get_paths(args.lfw_dir, pairs, model[1])
        # [notice: a fresh graph per model avoids tensor-name collisions
        # when several models are restored in turn]
        with tf.Graph().as_default(), tf.device('/gpu:%d' % (t + 1)):
            gpu_options = tf.GPUOptions(allow_growth=True)
            sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                    log_device_placement=False,
                                                    allow_soft_placement=True))
            with sess.as_default():
                print("[%d] model: %s" % (t, model[1]))
                # restore model
                test_utils.load_model(sess, model[0])
                # look up the data tensors
                images_pl = tf.get_default_graph().get_tensor_by_name('image_batch:0')
                embeddings = tf.get_default_graph().get_tensor_by_name('embeddings:0')
                phase_train_pl = tf.get_default_graph().get_tensor_by_name('phase_train:0')
                image_size = args.image_size
                emb_size = embeddings.get_shape()[1]
                # extract features
                batch_size = args.lfw_batch_size
                num_images = len(paths)
                # [notice: ceiling division so the final partial batch is kept]
                num_batches = (num_images + batch_size - 1) // batch_size
                emb_arr = np.zeros((num_images, emb_size))
                for i in range(num_batches):
                    print('process %d/%d' % (i + 1, num_batches), end='\r')
                    beg_idx = i * batch_size
                    end_idx = min((i + 1) * batch_size, num_images)
                    images = test_utils.load_data(paths[beg_idx:end_idx],
                                                  image_size)
                    emb = sess.run(embeddings,
                                   feed_dict={images_pl: images,
                                              phase_train_pl: False})
                    emb_arr[beg_idx:end_idx, :] = emb
                print("\ndone.")
            sess.close()
        # concatenate the features, weighted by each model's ensemble coefficient
        if t == 0:
            emb_ensemble = emb_arr * math.sqrt(float(model[2]))
        else:
            emb_ensemble = np.concatenate(
                (emb_ensemble, emb_arr * math.sqrt(float(model[2]))), axis=1)
        print("ensemble feature:", emb_ensemble.shape)
    '''
    norm = np.linalg.norm(emb_ensemble, axis=1)
    for i in range(emb_ensemble.shape[0]):
        emb_ensemble[i] = emb_ensemble[i] / norm[i]
    '''
    tpr, fpr, acc, vr, vr_std, far = test_utils.evaluate(
        emb_ensemble, labels, num_folds=args.num_folds)
    # display
    auc = metrics.auc(fpr, tpr)
    eer = brentq(lambda x: 1. - x - interpolate.interp1d(fpr, tpr)(x), 0., 1.)
    print('Acc: %1.3f+-%1.3f' % (np.mean(acc), np.std(acc)))
    print('VR@FAR=%2.5f: %2.5f+-%2.5f' % (far, vr, vr_std))
    print('AUC: %1.3f' % auc)
    print('EER: %1.3f' % eer)
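
# How the ensemble list file is assumed to be laid out, given how 'model'
# is indexed above: model[0] is the checkpoint path, model[1] the patch file
# extension passed to get_paths(), and model[2] the ensemble weight applied
# as sqrt(w). This is a hypothetical sketch of test_utils.get_model_list,
# not the actual implementation.
def get_model_list(list_path):
    # example line: "/path/to/face_model  _face_.jpg  1.0"
    model_list = []
    with open(list_path) as f:
        for line in f:
            fields = line.strip().split()
            if len(fields) == 3:
                model_list.append((fields[0], fields[1], fields[2]))
    return model_list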
def main(args):
    model_module = importlib.import_module(args.model_def)
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(args.logs_base_dir, subdir)
    model_dir = os.path.join(args.models_base_dir, subdir)
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    print('log dir: %s' % log_dir)
    print('model dir: %s' % model_dir)
    if args.lfw_dir:
        print('LFW directory: %s' % args.lfw_dir)
        pairs = test_utils.read_pairs(args.lfw_pairs)
        lfw_paths, actual_issame = test_utils.get_paths(
            args.lfw_dir, pairs, args.lfw_file_ext)

    with tf.Graph().as_default():
        # ---- data preparation ---- #
        image_list, label_list, num_classes = train_utils.get_datasets(
            args.data_dir, args.imglist_path)
        range_size = len(image_list)
        assert range_size > 0, 'The dataset should not be empty.'
        # random indices producer
        indices_que = tf.train.range_input_producer(range_size)
        deque_op = indices_que.dequeue_many(args.batch_size * args.epoch_size,
                                            'index_dequeue')
        tf.set_random_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        global_step = tf.Variable(0, trainable=False)
        lr_pl = tf.placeholder(tf.float32, name='learning_rate')
        batch_size_pl = tf.placeholder(tf.int32, name='batch_size')
        phase_train_pl = tf.placeholder(tf.bool, name='phase_train')
        imgpaths_pl = tf.placeholder(tf.string, name='image_paths')
        labels_pl = tf.placeholder(tf.int64, name='labels')
        # filename queue
        input_queue = tf.FIFOQueue(
            # [notice: capacity > batch_size*epoch_size]
            capacity=100000,
            dtypes=[tf.string, tf.int64],
            shapes=[(1, ), (1, )],
            shared_name=None,
            name='input_que')
        enque_op = input_queue.enqueue_many([imgpaths_pl, labels_pl],
                                            name='enque_op')
        # define 4 readers
        num_threads = 4
        threads_input_list = []
        for _ in range(num_threads):
            # [notice: 'img_paths' and 'label' are both tensors]
            img_paths, label = input_queue.dequeue()
            images = []
            for img_path in tf.unstack(img_paths):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                if args.random_crop:
                    img = tf.random_crop(img,
                                         [args.image_size, args.image_size, 3])
                else:
                    img = tf.image.resize_image_with_crop_or_pad(
                        img, args.image_size, args.image_size)
                if args.random_flip:
                    img = tf.image.random_flip_left_right(img)
                img.set_shape((args.image_size, args.image_size, 3))
                # per-image standardization acts as pre-whitening
                images.append(tf.image.per_image_standardization(img))
            threads_input_list.append([images, label])
        # define the buffer queue that joins the 4 readers
        image_batch, label_batch = tf.train.batch_join(
            threads_input_list,
            # [notice: here it is 'batch_size_pl', not 'batch_size'!]
            batch_size=batch_size_pl,
            shapes=[(args.image_size, args.image_size, 3), ()],
            enqueue_many=True,
            # [notice: how far the prefetching is allowed to fill the queue]
            capacity=4 * num_threads * args.batch_size,
            allow_smaller_final_batch=True)
        image_batch = tf.identity(image_batch, 'image_batch')
        image_batch = tf.identity(image_batch, 'input')
        label_batch = tf.identity(label_batch, 'label_batch')
        print('Total classes: %d' % num_classes)
        print('Total images: %d' % range_size)
        tf.summary.image('input_images', image_batch, 10)

        # ---- build graph ---- #
        with tf.device('/gpu:%d' % args.gpu_id):
            # embeddings
            prelogits, _ = model_module.inference(
                image_batch,
                args.keep_prob,
                phase_train=phase_train_pl,
                weight_decay=args.weight_decay)
            # logits
            logits = slim.fully_connected(
                prelogits,
                num_classes,
                activation_fn=None,
                weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
                weights_regularizer=slim.l2_regularizer(args.weight_decay),
                scope='Logits',
                reuse=False)
            # normalized features
            # [notice: used in the test stage]
            embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10,
                                            name='embeddings')
            # ---- define loss & train op ---- #
            # center loss
            if args.center_loss_factor > 0.0:
                prelogits_center_loss, _ = train_utils.center_loss(
                    prelogits, label_batch, args.center_loss_alpha, num_classes)
                tf.summary.scalar(
                    'center_loss',
                    prelogits_center_loss * args.center_loss_factor)
                tf.add_to_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES,
                    prelogits_center_loss * args.center_loss_factor)
            # cross-entropy
            cross_entropy_mean = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=label_batch, logits=logits),
                name='cross_entropy')
            tf.add_to_collection('losses', cross_entropy_mean)
            # regularizer: weight decay
            reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            # total loss
            total_loss = tf.add_n([cross_entropy_mean] + reg_loss,
                                  name='total_loss')
            # [notice: the learning rate is decayed manually here]
            lr = tf.train.exponential_decay(lr_pl,
                                            global_step,
                                            args.lr_decay_epochs * args.epoch_size,
                                            args.lr_decay_factor,
                                            staircase=True)
            tf.summary.scalar('learning_rate', lr)
            # [notice: passing tf.global_variables() instead would also include
            # non-trainable variables (e.g. batch-norm moving averages) in the
            # moving-average update]
            train_op = train_utils.get_train_op(total_loss, global_step,
                                                args.optimizer, lr,
                                                args.moving_average_decay,
                                                tf.trainable_variables())

        # ---- training ---- #
        # [notice: use 'allow_growth' instead of memory_fraction]
        gpu_options = tf.GPUOptions(allow_growth=True)
        # [notice: 'allow_soft_placement' works around 'no supported kernel...' errors]
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                log_device_placement=False,
                                                allow_soft_placement=True))
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        # [notice: 'max_to_keep': keep at most 'max_to_keep' checkpoint files]
        saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=5)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)
        with sess.as_default():
            if args.pretrained_model:
                print('Resume training: %s' % args.pretrained_model)
                saver.restore(sess, args.pretrained_model)
            print('Start training ...')
            epoch = 0
            while epoch < args.max_num_epochs:
                step = sess.run(global_step, feed_dict=None)  # training counter
                epoch = step // args.epoch_size
                # run one epoch
                run_epoch(args, sess, epoch, image_list, label_list, deque_op,
                          enque_op, imgpaths_pl, labels_pl, lr_pl,
                          phase_train_pl, batch_size_pl, global_step,
                          total_loss, reg_loss, train_op, summary_op,
                          summary_writer)
                # snapshot the currently learned weights
                snapshot(sess, saver, model_dir, subdir, step)
                # evaluate on LFW
                if args.lfw_dir:
                    evaluate(sess, enque_op, imgpaths_pl, labels_pl,
                             phase_train_pl, batch_size_pl, embeddings,
                             label_batch, lfw_paths, actual_issame,
                             args.lfw_batch_size, args.lfw_num_folds, log_dir,
                             step, summary_writer)
        sess.close()
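
# snapshot() is called by the training loops above but not defined in this
# section. A minimal sketch under the assumption that it follows the usual
# facenet-style convention: write a named checkpoint each time and export
# the metagraph once (it does not change between steps). The signature
# matches the call sites; the body is an assumption.
def snapshot(sess, saver, model_dir, subdir, step):
    # checkpoint: <model_dir>/model-<subdir>.ckpt-<step>
    ckpt_path = os.path.join(model_dir, 'model-%s.ckpt' % subdir)
    saver.save(sess, ckpt_path, global_step=step, write_meta_graph=False)
    # export the metagraph only on the first snapshot
    meta_path = os.path.join(model_dir, 'model-%s.meta' % subdir)
    if not os.path.exists(meta_path):
        saver.export_meta_graph(meta_path)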
def main(args):
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(args.logs_base_dir, subdir, 'logs')
    model_dir = os.path.join(args.logs_base_dir, subdir, 'models')
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    print('log dir: %s' % log_dir)
    print('model dir: %s' % model_dir)

    # ---- load pretrained models ---- #
    pretrained = {}
    # Face model
    pretrained['Face'] = np.load(args.face_model)[()]
    # Nose model
    pretrained['Nose'] = np.load(args.nose_model)[()]
    # Lefteye model
    pretrained['Lefteye'] = np.load(args.lefteye_model)[()]
    # Rightmouth model
    pretrained['Rightmouth'] = np.load(args.rightmouth_model)[()]

    # ---- data preparation ---- #
    image_list, label_list, num_classes = train_utils.get_datasets(
        args.data_dir, args.imglist)
    range_size = len(image_list)
    if args.lfw_dir:
        print('LFW directory: %s' % args.lfw_dir)
        pairs = test_utils.read_pairs(args.lfw_pairs)
        lfw_paths, actual_issame = test_utils.get_paths(
            args.lfw_dir, pairs, args.lfw_file_ext)

    with tf.Graph().as_default():
        # random indices producer
        indices_que = tf.train.range_input_producer(range_size)
        dequeue_op = indices_que.dequeue_many(
            args.batch_size * args.epoch_size, 'index_dequeue')
        tf.set_random_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        global_step = tf.Variable(0, trainable=False)
        # lr_base_pl = tf.placeholder(tf.float32, name='base_learning_rate')
        lr_fusion_pl = tf.placeholder(tf.float32, name='fusion_learning_rate')
        batch_size_pl = tf.placeholder(tf.int32, name='batch_size')
        phase_train_pl = tf.placeholder(tf.bool, name='phase_train')
        face_pl = tf.placeholder(tf.string, name='image_paths1')  # face images
        nose_pl = tf.placeholder(tf.string, name='image_paths2')  # nose images
        lefteye_pl = tf.placeholder(tf.string, name='image_paths3')  # left eye images
        rightmouth_pl = tf.placeholder(tf.string, name='image_paths4')  # right mouth images
        labels_pl = tf.placeholder(tf.int64, name='labels')
        # define a filename queue
        input_queue = tf.FIFOQueue(
            # [notice: capacity > batch_size*epoch_size]
            capacity=100000,
            dtypes=[tf.string, tf.string, tf.string, tf.string, tf.int64],
            shapes=[(1, ), (1, ), (1, ), (1, ), (1, )],
            shared_name=None,
            name='input_que')
        enque_op = input_queue.enqueue_many(
            [face_pl, nose_pl, lefteye_pl, rightmouth_pl, labels_pl],
            name='enque_op')
        # define 4 readers
        num_threads = 4
        threads_input_list = []
        for _ in range(num_threads):
            imgpath1, imgpath2, imgpath3, imgpath4, label = input_queue.dequeue()
            images1 = []
            images2 = []
            images3 = []
            images4 = []
            # face
            for img_path in tf.unstack(imgpath1):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                # [notice: random crop is only used on the face image]
                if args.random_crop:
                    img = tf.random_crop(img, [160, 160, 3])
                else:
                    img = tf.image.resize_image_with_crop_or_pad(img, 160, 160)
                # [notice: flip is only used on the face image or nose patch]
                if args.random_flip:
                    img = tf.image.random_flip_left_right(img)
                img.set_shape((160, 160, 3))
                images1.append(tf.image.per_image_standardization(img))
            # nose
            for img_path in tf.unstack(imgpath2):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                # [notice: flip is only used on the face image or nose patch]
                if args.random_flip:
                    img = tf.image.random_flip_left_right(img)
                img.set_shape((160, 160, 3))
                images2.append(tf.image.per_image_standardization(img))
            # left eye
            for img_path in tf.unstack(imgpath3):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                img.set_shape((160, 160, 3))
                images3.append(tf.image.per_image_standardization(img))
            # right mouth
            for img_path in tf.unstack(imgpath4):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                img.set_shape((160, 160, 3))
                images4.append(tf.image.per_image_standardization(img))
            threads_input_list.append(
                [images1, images2, images3, images4, label])
        # define the buffer queue that joins the 4 readers
        face_batch, nose_batch, lefteye_batch, rightmouth_batch, label_batch = \
            tf.train.batch_join(
                threads_input_list,
                # [notice: here it is 'batch_size_pl', not 'batch_size'!]
                batch_size=batch_size_pl,
                shapes=[
                    # [notice: the shape of each element must be given, otherwise
                    # this raises "tensorflow queue shapes must have the same
                    # length as dtypes"]
                    (args.image_size, args.image_size, 3),
                    (args.image_size, args.image_size, 3),
                    (args.image_size, args.image_size, 3),
                    (args.image_size, args.image_size, 3), ()
                ],
                enqueue_many=True,
                # [notice: how far the prefetching is allowed to fill the queue]
                capacity=4 * num_threads * args.batch_size,
                allow_smaller_final_batch=True)
        print('Total classes: %d' % num_classes)
        print('Total images: %d' % range_size)
        tf.summary.image('face_images', face_batch, 10)
        tf.summary.image('nose_images', nose_batch, 10)
        tf.summary.image('lefteye_images', lefteye_batch, 10)
        tf.summary.image('rightmouth_images', rightmouth_batch, 10)

        # ---- build graph ---- #
        with tf.variable_scope('BaseModel'):
            with tf.device('/gpu:%d' % args.gpu_id1):
                # embeddings for the face model
                features1, _ = inception_resnet_v1.inference(
                    face_batch,
                    args.keep_prob,
                    phase_train=phase_train_pl,
                    weight_decay=args.weight_decay,
                    scope='Face')
            with tf.device('/gpu:%d' % args.gpu_id2):
                # embeddings for the nose model
                features2, _ = inception_resnet_v1.inference(
                    nose_batch,
                    args.keep_prob,
                    phase_train=phase_train_pl,
                    weight_decay=args.weight_decay,
                    scope='Nose')
            with tf.device('/gpu:%d' % args.gpu_id3):
                # embeddings for the left eye model
                features3, _ = inception_resnet_v1.inference(
                    lefteye_batch,
                    args.keep_prob,
                    phase_train=phase_train_pl,
                    weight_decay=args.weight_decay,
                    scope='Lefteye')
            with tf.device('/gpu:%d' % args.gpu_id4):
                # embeddings for the right mouth model
                features4, _ = inception_resnet_v1.inference(
                    rightmouth_batch,
                    args.keep_prob,
                    phase_train=phase_train_pl,
                    weight_decay=args.weight_decay,
                    scope='Rightmouth')
        with tf.device('/gpu:%d' % args.gpu_id5):
            with tf.variable_scope("Fusion"):
                # ---- concatenate ---- #
                concated_features = tf.concat(
                    [features1, features2, features3, features4], 1)
                # prelogits
                prelogits = slim.fully_connected(
                    concated_features,
                    args.fusion_dim,
                    activation_fn=None,
                    weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
                    weights_regularizer=slim.l2_regularizer(args.weight_decay),
                    scope='prelogits',
                    reuse=False)
                # logits
                logits = slim.fully_connected(
                    prelogits,
                    num_classes,
                    activation_fn=None,
                    weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
                    weights_regularizer=slim.l2_regularizer(args.weight_decay),
                    scope='logits',
                    reuse=False)
            # normalized features
            # [notice: used in the test stage]
            embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10,
                                            name='embeddings')
            # ---- define loss & train op ---- #
            cross_entropy = -tf.reduce_sum(
                tf.one_hot(indices=tf.cast(label_batch, tf.int32),
                           depth=num_classes) *
                tf.log(tf.nn.softmax(logits) + 1e-10),
                reduction_indices=[1])
            cross_entropy_mean = tf.reduce_mean(cross_entropy)
            tf.add_to_collection('losses', cross_entropy_mean)
            # weight decay
            reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            # total loss: cross_entropy + weight_decay
            total_loss = tf.add_n([cross_entropy_mean] + reg_loss,
                                  name='total_loss')
            '''
            lr_base = tf.train.exponential_decay(lr_base_pl,
                                                 global_step,
                                                 args.lr_decay_epochs * args.epoch_size,
                                                 args.lr_decay_factor,
                                                 staircase=True)
            '''
            lr_fusion = tf.train.exponential_decay(lr_fusion_pl,
                                                   global_step,
                                                   args.lr_decay_epochs * args.epoch_size,
                                                   args.lr_decay_factor,
                                                   staircase=True)
            # tf.summary.scalar('base_learning_rate', lr_base)
            tf.summary.scalar('fusion_learning_rate', lr_fusion)
            var_list1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='BaseModel')
            var_list2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='Fusion')
            '''
            train_op = train_utils.get_fusion_train_op(
                total_loss, global_step, args.optimizer,
                lr_base, var_list1, lr_fusion, var_list2,
                args.moving_average_decay)
            '''
            # [notice: only the fusion layers are trained; the base models stay frozen]
            train_op = train_utils.get_train_op(total_loss, global_step,
                                                args.optimizer, lr_fusion,
                                                args.moving_average_decay,
                                                var_list2)

        # ---- training ---- #
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=gpu_options,
            log_device_placement=False,
            # [notice: 'allow_soft_placement' falls back to the CPU automatically
            # when an operation has no GPU kernel]
            allow_soft_placement=True))
        saver = tf.train.Saver(var_list1 + var_list2)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)
        with sess.as_default():
            # ---- restore pretrained parameters ---- #
            to_assign = []
            print("restore pretrained parameters...")
            print("total:", len(var_list1))
            for v in var_list1:
                v_name = v.name  # 'BaseModel/Face/xxx'
                v_name = v_name[v_name.find('/') + 1:]  # 'Face/xxx'
                v_name_1 = v_name[:v_name.find('/')]  # 'Face'
                v_name_2 = v_name[v_name.find('/'):]  # '/xxx'
                print("process: %s" % v_name, end=" ")
                if v_name_1 in pretrained:
                    to_assign.append(
                        v.assign(pretrained[v_name_1][v_name_2][0]))
                    print("[ok]")
                else:
                    # [notice: indexing 'pretrained' here would raise a
                    # KeyError, so unmatched variables are only reported]
                    print("[not found]")
            print("done")
            sess.run(to_assign)
            print("start training ...")
            epoch = 0
            while epoch < args.max_num_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size
                # run one epoch
                run_epoch(args, sess, epoch, image_list, label_list,
                          dequeue_op, enque_op, face_pl, nose_pl, lefteye_pl,
                          rightmouth_pl, labels_pl, lr_fusion_pl,
                          phase_train_pl, batch_size_pl, global_step,
                          total_loss, reg_loss, train_op, summary_op,
                          summary_writer)
                # snapshot the currently learned weights
                snapshot(sess, saver, model_dir, subdir, step)
                # evaluate on LFW
                if args.lfw_dir:
                    evaluate(sess, enque_op, face_pl, nose_pl, lefteye_pl,
                             rightmouth_pl, labels_pl, phase_train_pl,
                             batch_size_pl, embeddings, label_batch, lfw_paths,
                             actual_issame, args.lfw_batch_size,
                             args.lfw_num_folds, log_dir, step,
                             summary_writer)
        sess.close()
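
# A worked example of the name mangling in the restore loop above. The
# variable name is hypothetical; the slicing is exactly what the loop does
# to index the 'pretrained' dicts loaded from the .npy files.
v_name = 'BaseModel/Face/Conv2d_1a_3x3/weights:0'
v_name = v_name[v_name.find('/') + 1:]  # 'Face/Conv2d_1a_3x3/weights:0'
v_name_1 = v_name[:v_name.find('/')]    # 'Face' -> selects which .npy dict
v_name_2 = v_name[v_name.find('/'):]    # '/Conv2d_1a_3x3/weights:0' -> key inside it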
def main(args):
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(args.logs_base_dir, 'logs', subdir)
    model_dir = os.path.join(args.logs_base_dir, 'models', subdir)
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    print('log dir: %s' % log_dir)
    print('model dir: %s' % model_dir)
    if args.lfw_dir:
        print('lfw directory: %s' % args.lfw_dir)
        pairs = test_utils.read_pairs(args.lfw_pairs)
        lfw_paths, lfw_label = test_utils.get_paths(args.lfw_dir, pairs,
                                                    args.lfw_file_ext)

    with tf.Graph().as_default():
        # ------------ data preparation ------------ #
        image_list, label_list, num_classes = train_utils.get_datasets(
            args.data_dir, args.imglist_path)
        range_size = len(image_list)
        assert range_size > 0, 'The dataset should not be empty.'
        # random indices producer
        indices_que = tf.train.range_input_producer(range_size)
        deque_op = indices_que.dequeue_many(args.batch_size * args.epoch_size,
                                            'index_dequeue')
        # fix the random seeds for reproducibility
        tf.set_random_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        # filename queue
        imgpaths_pl = tf.placeholder(tf.string, name='image_paths')
        labels_pl = tf.placeholder(tf.int64, name='labels')
        input_queue = tf.FIFOQueue(
            # [notice: capacity > batch_size*epoch_size]
            capacity=100000,
            dtypes=[tf.string, tf.int64],
            shapes=[(1, ), (1, )],
            shared_name=None,
            name='input_que')
        enque_op = input_queue.enqueue_many([imgpaths_pl, labels_pl],
                                            name='enque_op')
        # define 4 readers
        num_threads = 4
        threads_input_list = []
        for _ in range(num_threads):
            # [notice: 'img_paths' and 'label' are both tensors]
            img_paths, label = input_queue.dequeue()
            images = []
            for img_path in tf.unstack(img_paths):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                if args.random_crop:
                    img = tf.random_crop(img,
                                         [args.image_size, args.image_size, 3])
                else:
                    img = tf.image.resize_image_with_crop_or_pad(
                        img, args.image_size, args.image_size)
                if args.random_flip:
                    img = tf.image.random_flip_left_right(img)
                img.set_shape((args.image_size, args.image_size, 3))
                # per-image standardization acts as pre-whitening
                images.append(tf.image.per_image_standardization(img))
            threads_input_list.append([images, label])
        # define the buffer queue that joins the 4 readers
        batch_size_pl = tf.placeholder(tf.int32, name='batch_size')
        image_batch, label_batch = tf.train.batch_join(
            threads_input_list,
            # [notice: here it is 'batch_size_pl', not 'batch_size'!]
            batch_size=batch_size_pl,
            shapes=[(args.image_size, args.image_size, 3), ()],
            enqueue_many=True,
            # [notice: how far the prefetching is allowed to fill the queue]
            capacity=4 * num_threads * args.batch_size,
            allow_smaller_final_batch=True)
        image_batch = tf.identity(image_batch, 'image_batch')
        label_batch = tf.identity(label_batch, 'label_batch')
        print('Total classes: %d' % num_classes)
        print('Total images: %d' % range_size)
        tf.summary.image('input_images', image_batch, 10)

        # ------------ build graph ------------ #
        hps_train = resnet.HParams(batch_size=batch_size_pl,
                                   num_residual_units=5,
                                   use_bottleneck=True,
                                   relu_leakiness=0.1)
        global_step = tf.Variable(0, trainable=False)
        phase_train_pl = tf.placeholder(tf.bool, name='phase_train')
        resnet_model = resnet(hps_train, phase_train_pl)
        with tf.device('/gpu:%d' % args.gpu_id):
            # ---- base graph ---- #
            with tf.variable_scope('ResNet'):
                # prelogits
                prelogits = resnet_model.inference(image_batch)
                # prelogits -> embeddings [notice: used in the test stage]
                embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10,
                                                name='embeddings')
                # prelogits -> logits
                with tf.variable_scope('Logits'):
                    logits = resnet.fully_connected(prelogits, num_classes)
                # predictions = tf.nn.softmax(logits)
            # ---- losses ---- #
            # cross entropy
            with tf.variable_scope('cross_entropy'):
                # [notice: the sum must be negated to be a cross-entropy loss]
                cross_entropy = -tf.reduce_sum(
                    tf.one_hot(indices=tf.cast(label_batch, tf.int32),
                               depth=num_classes) *
                    tf.log(tf.nn.softmax(logits) + 1e-10),
                    reduction_indices=[1])
                cross_entropy_mean = tf.reduce_mean(cross_entropy)
                tf.summary.scalar('cross_entropy', cross_entropy_mean)
            # l2 loss
            reg_loss = resnet.decay(args.weight_decay)
            tf.summary.scalar('reg_loss', reg_loss)
            # total loss
            # [notice: 'reg_loss' is a single tensor here, not a list,
            # so it goes inside the tf.add_n list directly]
            total_loss = tf.add_n([cross_entropy_mean, reg_loss],
                                  name='total_loss')
            train_op = resnet_model.get_train_op(total_loss, global_step,
                                                 args.lr)

        # ------------ training ------------ #
        # [notice: use 'allow_growth' instead of memory_fraction]
        gpu_options = tf.GPUOptions(allow_growth=True)
        # [notice: 'allow_soft_placement' works around 'no supported kernel...' errors]
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                log_device_placement=False,
                                                allow_soft_placement=True))
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        # [notice: 'max_to_keep': keep at most 'max_to_keep' checkpoint files]
        saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=5)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)
        with sess.as_default():
            if args.pretrained_model:
                print('Resume training: %s' % args.pretrained_model)
                saver.restore(sess, args.pretrained_model)
            print('Start training ...')
            epoch = 0
            while epoch < args.max_num_epochs:
                step = sess.run(global_step, feed_dict=None)  # training counter
                epoch = step // args.epoch_size
                # run one epoch
                run_epoch(args, sess, epoch, image_list, label_list, deque_op,
                          enque_op, imgpaths_pl, labels_pl, phase_train_pl,
                          batch_size_pl, global_step, total_loss, reg_loss,
                          train_op, summary_op, summary_writer)
                # snapshot the currently learned weights
                snapshot(sess, saver, model_dir, subdir, step)
                # evaluate on LFW
                if args.lfw_dir:
                    evaluate(sess, enque_op, imgpaths_pl, labels_pl,
                             phase_train_pl, batch_size_pl, embeddings,
                             label_batch, lfw_paths, lfw_label,
                             args.lfw_batch_size, args.lfw_num_folds, log_dir,
                             step, summary_writer)
        sess.close()
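
# run_epoch() is invoked by all of the training scripts above but its body is
# not part of this section. A minimal sketch of the facenet-style pattern it
# presumably follows: pull one epoch worth of shuffled indices, feed the
# corresponding (path, label) pairs into the filename queue, then run
# 'epoch_size' optimization steps. The signature matches the last call site;
# the body is an assumption, not the original implementation.
def run_epoch(args, sess, epoch, image_list, label_list, deque_op, enque_op,
              imgpaths_pl, labels_pl, phase_train_pl, batch_size_pl,
              global_step, total_loss, reg_loss, train_op, summary_op,
              summary_writer):
    # one epoch worth of shuffled sample indices from the range producer
    indices = sess.run(deque_op)
    # shape (N, 1) to match the queue element shapes [(1,), (1,)]
    paths = np.array([[image_list[i]] for i in indices])
    labels = np.array([[label_list[i]] for i in indices])
    sess.run(enque_op, {imgpaths_pl: paths, labels_pl: labels})
    # 'epoch_size' training steps
    for _ in range(args.epoch_size):
        feed = {phase_train_pl: True, batch_size_pl: args.batch_size}
        loss, _, step = sess.run([total_loss, train_op, global_step], feed)
        print('epoch [%d]\tstep %d\tloss %2.4f' % (epoch, step, loss))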