def load_images(paths, group=None, verbose=True):
    """Loads and rescales input images to the diagonal of the reference shape.

    Args:
      paths: a list of strings containing the data directories.
      group: landmark group containing the ground truth landmarks.
      verbose: boolean, print debugging info.
    Returns:
      padded_images: a numpy array of images padded to a common size.
      shapes: a list of the ground truth landmarks.
      reference_shape: a numpy array [num_landmarks, 2].
      shape_gen: PCAModel, a shape generator.
    """
    images = []
    shapes = []
    bbs = []

    reference_shape = PointCloud(build_reference_shape(paths))

    for path in paths:
        if verbose:
            print('Importing data from {}'.format(path))

        for im in mio.import_images(path, verbose=verbose, as_generator=True):
            group = group or im.landmarks[group]._group_label

            bb_root = im.path.parent.relative_to(im.path.parent.parent.parent)
            if 'set' not in str(bb_root):
                bb_root = im.path.parent.relative_to(im.path.parent.parent)

            im.landmarks['bb'] = mio.import_landmark_file(
                str(Path('bbs') / bb_root / (im.path.stem + '.pts')))

            im = im.crop_to_landmarks_proportion(0.3, group='bb')
            im = im.rescale_to_pointcloud(reference_shape, group=group)
            im = grey_to_rgb(im)
            images.append(im.pixels.transpose(1, 2, 0))
            shapes.append(im.landmarks[group].lms)
            bbs.append(im.landmarks['bb'].lms)

    train_dir = Path(FLAGS.train_dir)
    mio.export_pickle(reference_shape.points,
                      train_dir / 'reference_shape.pkl',
                      overwrite=True)
    print('created reference_shape.pkl using the {} group'.format(group))

    pca_model = detect.create_generator(shapes, bbs)

    # Pad images to the maximum height/width found in the set.
    max_shape = np.max([im.shape for im in images], axis=0)
    max_shape = [len(images)] + list(max_shape)
    padded_images = np.random.rand(*max_shape).astype(np.float32)
    print(padded_images.shape)

    for i, im in enumerate(images):
        height, width = im.shape[:2]
        dy = max(int((max_shape[1] - height - 1) / 2), 0)
        dx = max(int((max_shape[2] - width - 1) / 2), 0)
        lms = shapes[i]
        pts = lms.points
        pts[:, 0] += dy
        pts[:, 1] += dx
        lms = lms.from_vector(pts)
        padded_images[i, dy:(height + dy), dx:(width + dx)] = im

    return padded_images, shapes, reference_shape.points, pca_model
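
# Illustrative usage sketch (not called anywhere in this module): how the
# outputs of load_images fit together. The 'databases/lfpw/trainset/' path is a
# hypothetical example directory, not something this repository guarantees, and
# FLAGS.train_dir is assumed to be configured before the call.
def _load_images_usage_example():
    """Minimal sketch of consuming load_images outputs."""
    padded_images, gt_shapes, ref_shape, shape_gen = load_images(
        ['databases/lfpw/trainset/'], verbose=True)
    # padded_images is a float32 array [num_images, max_h, max_w, 3]; every
    # ground-truth PointCloud in gt_shapes has been shifted by the same
    # (dy, dx) offset used to centre its image on the padded canvas.
    print(padded_images.shape, ref_shape.shape, len(gt_shapes))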
def load_images_aflw(paths, group=None, verbose=True, PLOT=True, AFLW=False,
                     PLOT_shape=False):
    """Loads and rescales input images to the diagonal of the reference shape.

    Args:
      paths: a list of strings containing the data directories.
      group: landmark group containing the ground truth landmarks.
      verbose: boolean, print debugging info.
    Returns:
      padded_images: a numpy array of images padded to a common size.
      shapes: a list of the ground truth landmarks.
      reference_shape (mean shape): a numpy array [num_landmarks, 2].
      shape_gen: PCAModel, a shape generator.
      centers: k-means centres of the shape space.
    """
    images = []
    shapes = []
    bbs = []
    shape_space = []
    plot_shape_x = []
    plot_shape_y = []

    # Compute the mean shape.
    if AFLW:
        # reference_shape = PointCloud(mio.import_pickle(
        #     Path('/home/hliu/gmh/RL_FA/mdm_aflw/ckpt/train_aflw') /
        #     'reference_shape.pkl'))
        reference_shape = mio.import_pickle(
            Path('/home/hliu/gmh/RL_FA/mdm_aflw/ckpt/train_aflw') /
            'reference_shape.pkl')
    else:
        reference_shape = PointCloud(build_reference_shape(paths))

    for path in paths:
        if verbose:
            print('Importing data from {}'.format(path))

        for im in mio.import_images(path, verbose=verbose, as_generator=True):
            # group = group or im.landmarks[group]._group_label
            group = group or im.landmarks.keys()[0]

            bb_root = im.path.parent.relative_to(im.path.parent.parent.parent)
            if 'set' not in str(bb_root):
                bb_root = im.path.parent.relative_to(im.path.parent.parent)

            if AFLW:
                im.landmarks['bb'] = im.landmarks['PTS'].lms.bounding_box()
            else:
                im.landmarks['bb'] = mio.import_landmark_file(
                    str(Path('bbs') / bb_root / (im.path.stem + '.pts')))

            im = im.crop_to_landmarks_proportion(0.3, group='bb')
            im = im.rescale_to_pointcloud(reference_shape, group=group)
            im = grey_to_rgb(im)
            # images.append(im.pixels.transpose(1, 2, 0))
            shapes.append(im.landmarks[group].lms)
            shape_space.append(im.landmarks[group].lms.points)
            bbs.append(im.landmarks['bb'].lms)

            if PLOT_shape:
                x_tmp = np.sum(im.landmarks[group].lms.points[:, 0] -
                               reference_shape.points[:, 0])
                y_tmp = np.sum(im.landmarks[group].lms.points[:, 1] -
                               reference_shape.points[:, 1])
                if x_tmp < 0 and y_tmp < 0:
                    plot_shape_x.append(x_tmp)
                    plot_shape_y.append(y_tmp)

    shape_space = np.array(shape_space)
    print('shape_space:', shape_space.shape)

    train_dir = Path(FLAGS.train_dir)

    if PLOT_shape:
        k_nn_plot_x = []
        k_nn_plot_y = []
        centers = utils.k_means(shape_space, 500, num_patches=19)
        centers = np.reshape(centers, [-1, 19, 2])
        for i in range(centers.shape[0]):
            x_tmp = np.sum(centers[i, :, 0] - reference_shape.points[:, 0])
            y_tmp = np.sum(centers[i, :, 1] - reference_shape.points[:, 1])
            if x_tmp < 0 and y_tmp < 0:
                k_nn_plot_x.append(x_tmp)
                k_nn_plot_y.append(y_tmp)
        # plt.scatter(plot_shape_x, plot_shape_y, s=20)
        # plt.scatter(k_nn_plot_x, k_nn_plot_y, s=40)
        # plt.xticks(())
        # plt.yticks(())
        # plt.show()
        # pdb.set_trace()
        np.save(train_dir / 'shape_space_all.npy', shape_space)

    # centers = utils.k_means(shape_space, 100)
    # centers = np.reshape(centers, [-1, 68, 2])
    # np.save(train_dir / 'shape_space_origin.npy', centers)
    # print('created shape_space.npy using the {} group'.format(group))
    # exit(0)

    mio.export_pickle(reference_shape.points,
                      train_dir / 'reference_shape.pkl',
                      overwrite=True)
    print('created reference_shape.pkl using the {} group'.format(group))

    pca_model = detect.create_generator(shapes, bbs)

    # Pad images to max length.
    max_shape = [272, 261, 3]
    padded_images = np.random.rand(*max_shape).astype(np.float32)
    print(padded_images.shape)

    if PLOT:
        # Plot the shape-space centres without padding.
        centers = utils.k_means(shape_space, 500, num_patches=19)
        centers = np.reshape(centers, [-1, 19, 2])
        plot_img = cv2.imread('a.png').transpose(2, 0, 1)
        centers_tmp = np.zeros(centers.shape)
        # menpo_img = mio.import_image('a.png')
        menpo_img = menpo.image.Image(plot_img)
        for i in range(centers.shape[0]):
            menpo_img.view()
            min_y = np.min(centers[i, :, 0])
            min_x = np.min(centers[i, :, 1])
            centers_tmp[i, :, 0] = centers[i, :, 0] - min_y + 20
            centers_tmp[i, :, 1] = centers[i, :, 1] - min_x + 20
            print(centers_tmp[i, :, :])
            menpo_img.landmarks['center'] = PointCloud(centers_tmp[i, :, :])
            menpo_img.view_landmarks(group='center',
                                     marker_face_colour='b',
                                     marker_size=16)
            # menpo_img.landmarks['center'].view(render_legend=True)
            plt.savefig('plot_shape_space_aflw/' + str(i) + '.png')
            plt.close()
        exit(0)

    # NOTE: shape_space is saved without a separate delta; the padding offsets
    # (dy, dx) below are applied to shape_space directly.
    # delta = np.zeros(shape_space.shape)
    for i, im in enumerate(images):
        height, width = im.shape[:2]
        dy = max(int((max_shape[0] - height - 1) / 2), 0)
        dx = max(int((max_shape[1] - width - 1) / 2), 0)
        lms = shapes[i]
        pts = lms.points
        pts[:, 0] += dy
        pts[:, 1] += dx
        shape_space[i, :, 0] += dy
        shape_space[i, :, 1] += dx
        # delta[i][:, 0] = dy
        # delta[i][:, 1] = dx
        lms = lms.from_vector(pts)
        padded_images[i, dy:(height + dy), dx:(width + dx)] = im
    # shape_space = np.concatenate((shape_space, delta), 2)

    centers = utils.k_means(shape_space, 1000, num_patches=19)
    centers = np.reshape(centers, [-1, 19, 2])
    # pdb.set_trace()
    np.save(train_dir / 'shape_space.npy', centers)
    print('created shape_space.npy using the {} group'.format(group))
    exit(0)

    return padded_images, shapes, reference_shape.points, pca_model, centers
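
# Small self-contained sketch of the padding/offset arithmetic used above:
# when an image of size (h, w) is centred on a (max_h, max_w) canvas, its
# landmarks must be shifted by the same (dy, dx). The function and argument
# names are illustrative only; the real pipeline performs this inline in
# load_images / load_images_aflw.
def _pad_and_shift_example(image, points, max_h, max_w):
    """Return (padded_image, shifted_points) for a single image.

    image:  numpy array [h, w, 3]
    points: numpy array [num_landmarks, 2] in (y, x) order, as in menpo.
    """
    h, w = image.shape[:2]
    dy = max(int((max_h - h - 1) / 2), 0)
    dx = max(int((max_w - w - 1) / 2), 0)
    padded = np.zeros((max_h, max_w, image.shape[2]), dtype=image.dtype)
    padded[dy:h + dy, dx:w + dx] = image
    shifted = points.copy()
    shifted[:, 0] += dy  # rows / y
    shifted[:, 1] += dx  # cols / x
    return padded, shifted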
def train(scope=''):
    """Train on dataset for a number of steps."""
    with tf.Graph().as_default(), tf.device('/gpu:0'):
        # Global step.
        tf_global_step = tf.get_variable(
            'GlobalStep', [],
            initializer=tf.constant_initializer(0),
            trainable=False)

        # Learning rate.
        tf_lr = tf.train.exponential_decay(g_config['learning_rate'],
                                           tf_global_step,
                                           g_config['learning_rate_step'],
                                           g_config['learning_rate_decay'],
                                           staircase=True,
                                           name='LearningRate')
        tf.summary.scalar('learning_rate', tf_lr)

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(tf_lr)

        data_provider.prepare_images(g_config['train_dataset'].split(':'),
                                     num_patches=g_config['num_patches'],
                                     verbose=True)
        path_base = Path(g_config['train_dataset'].split(':')[0]).parent.parent
        _mean_shape = mio.import_pickle(path_base / 'reference_shape.pkl')
        with Path(path_base / 'meta.txt').open('r') as ifs:
            _image_shape = [int(x) for x in ifs.read().split(' ')]
        assert isinstance(_mean_shape, np.ndarray)

        _pca_shapes = []
        _pca_bbs = []
        for item in tf.io.tf_record_iterator(str(path_base / 'pca.bin')):
            example = tf.train.Example()
            example.ParseFromString(item)
            _pca_shape = np.array(
                example.features.feature['pca/shape'].float_list.value
            ).reshape((-1, 2))
            _pca_bb = np.array(
                example.features.feature['pca/bb'].float_list.value
            ).reshape((-1, 2))
            _pca_shapes.append(PointCloud(_pca_shape))
            _pca_bbs.append(PointCloud(_pca_bb))
        _pca_model = detect.create_generator(_pca_shapes, _pca_bbs)

        assert _mean_shape.shape[0] == g_config['num_patches']
        tf_mean_shape = tf.constant(_mean_shape, dtype=tf.float32,
                                    name='MeanShape')

        def decode_feature(serialized):
            feature = {
                'train/image': tf.FixedLenFeature([], tf.string),
                'train/shape': tf.VarLenFeature(tf.float32),
            }
            features = tf.parse_single_example(serialized, features=feature)
            decoded_image = tf.decode_raw(features['train/image'], tf.float32)
            decoded_image = tf.reshape(decoded_image, _image_shape)
            decoded_shape = tf.sparse.to_dense(features['train/shape'])
            decoded_shape = tf.reshape(decoded_shape,
                                       (g_config['num_patches'], 2))
            return decoded_image, decoded_shape

        def get_random_sample(image, shape, rotation_stddev=10):
            # Read a random image with landmarks and bounding box.
            image = menpo.image.Image(image.transpose((2, 0, 1)), copy=False)
            image.landmarks['PTS'] = PointCloud(shape)
            if np.random.rand() < .5:
                image = utils.mirror_image(image)
            if np.random.rand() < .5:
                theta = np.random.normal(scale=rotation_stddev)
                rot = menpo.transform.rotate_ccw_about_centre(
                    image.landmarks['PTS'], theta)
                image = image.warp_to_shape(image.shape, rot)
            bb = image.landmarks['PTS'].bounding_box().points
            miny, minx = np.min(bb, 0)
            maxy, maxx = np.max(bb, 0)
            bbsize = max(maxx - minx, maxy - miny)
            center = [(miny + maxy) / 2., (minx + maxx) / 2.]
            image.landmarks['bb'] = PointCloud([
                [center[0] - bbsize * 0.5, center[1] - bbsize * 0.5],
                [center[0] + bbsize * 0.5, center[1] + bbsize * 0.5],
            ]).bounding_box()
            proportion = float(np.random.rand() / 3)
            image = image.crop_to_landmarks_proportion(proportion, group='bb')
            image = image.resize((112, 112))
            random_image = image.pixels.transpose(1, 2, 0).astype('float32')
            random_shape = image.landmarks['PTS'].points.astype('float32')
            return random_image, random_shape

        def get_init_shape(image, shape, mean_shape):
            def norm(x):
                return tf.sqrt(
                    tf.reduce_sum(tf.square(x - tf.reduce_mean(x, 0))))

            with tf.name_scope('align_shape_to_bb', values=[mean_shape]):
                min_xy = tf.reduce_min(mean_shape, 0)
                max_xy = tf.reduce_max(mean_shape, 0)
                min_x, min_y = min_xy[0], min_xy[1]
                max_x, max_y = max_xy[0], max_xy[1]
                mean_shape_bb = tf.stack([[min_x, min_y], [max_x, min_y],
                                          [max_x, max_y], [min_x, max_y]])
                bb = tf.stack([[0.0, 0.0], [112.0, 0.0],
                               [112.0, 112.0], [0.0, 112.0]])
                ratio = norm(bb) / norm(mean_shape_bb)
                initial_shape = tf.add(
                    (mean_shape - tf.reduce_mean(mean_shape_bb, 0)) * ratio,
                    tf.reduce_mean(bb, 0),
                    name='initial_shape')
                initial_shape.set_shape(tf_mean_shape.get_shape())
            return image, shape, initial_shape

        def distort_color(image, shape, init_shape):
            return data_provider.distort_color(image), shape, init_shape

        with tf.name_scope('DataProvider', values=[tf_mean_shape]):
            tf_dataset = tf.data.TFRecordDataset(
                [str(path_base / 'train.bin')])
            tf_dataset = tf_dataset.repeat()
            tf_dataset = tf_dataset.map(decode_feature)
            tf_dataset = tf_dataset.map(lambda x, y: tf.py_func(
                get_random_sample, [x, y], [tf.float32, tf.float32],
                stateful=True, name='RandomSample'))
            tf_dataset = tf_dataset.map(
                partial(get_init_shape, mean_shape=tf_mean_shape))
            tf_dataset = tf_dataset.map(distort_color)
            tf_dataset = tf_dataset.batch(g_config['batch_size'], True)
            tf_dataset = tf_dataset.prefetch(7500)
            tf_iterator = tf_dataset.make_one_shot_iterator()
            tf_images, tf_shapes, tf_initial_shapes = tf_iterator.get_next(
                name='Batch')
            tf_images.set_shape([g_config['batch_size'], 112, 112, 3])
            tf_shapes.set_shape([g_config['batch_size'], 73, 2])
            tf_initial_shapes.set_shape([g_config['batch_size'], 73, 2])

        print('Defining model...')
        with tf.device(g_config['train_device']):
            tf_model = mdm_model.MDMModel(
                tf_images,
                tf_shapes,
                tf_initial_shapes,
                batch_size=g_config['batch_size'],
                num_iterations=g_config['num_iterations'],
                num_patches=g_config['num_patches'],
                patch_shape=(g_config['patch_size'], g_config['patch_size']),
                num_channels=3)
            with tf.name_scope('Losses',
                               values=[tf_model.prediction, tf_shapes]):
                tf_norm_error = tf_model.normalized_rmse(
                    tf_model.prediction, tf_shapes)
                tf_loss = tf.reduce_mean(tf_norm_error)
            tf.summary.scalar('losses/total', tf_loss)
            # Calculate the gradients for the batch of data.
            tf_grads = opt.compute_gradients(tf_loss)
        tf.summary.histogram('dx', tf_model.prediction - tf_shapes)

        bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)

        # Add histograms for gradients.
        for grad, var in tf_grads:
            if grad is not None:
                tf.summary.histogram(var.op.name + '/gradients', grad)

        # Apply the gradients to adjust the shared variables.
        with tf.name_scope('Optimizer', values=[tf_grads, tf_global_step]):
            apply_gradient_op = opt.apply_gradients(
                tf_grads, global_step=tf_global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)

        # Track the moving averages of all trainable variables.
        # Note that we maintain a "double average" of the BatchNormalization
        # global statistics. This is more complicated than it needs to be, but
        # is kept for backward compatibility with our previous models.
        with tf.name_scope('MovingAverage', values=[tf_global_step]):
            variable_averages = tf.train.ExponentialMovingAverage(
                g_config['MOVING_AVERAGE_DECAY'], tf_global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            variables_averages_op = variable_averages.apply(
                variables_to_average)

        # Group all updates into a single train op.
        bn_updates_op = tf.group(*bn_updates, name='BNGroup')
        train_op = tf.group(apply_gradient_op, variables_averages_op,
                            bn_updates_op, name='TrainGroup')

        # Create a saver.
        saver = tf.train.Saver()

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph. allow_soft_placement must be
        # set to True to build towers on GPU, as some of the ops do not have
        # GPU implementations.
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()
        print('Initializing variables...')
        sess.run(init)
        print('Initialized variables.')

        start_step = 0
        ckpt = tf.train.get_checkpoint_state(g_config['train_dir'])
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Assuming model_checkpoint_path looks something like
            # /ckpt/train/model.ckpt-0, extract global_step from it.
            start_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), g_config['train_dir']))

        summary_writer = tf.summary.FileWriter(g_config['train_dir'],
                                               sess.graph)

        print('Starting training...')
        for step in range(start_step, g_config['max_steps']):
            start_time = time.time()
            _, loss_value = sess.run([train_op, tf_loss])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 100 == 0:
                examples_per_sec = g_config['batch_size'] / float(duration)
                format_str = ('%s: step %d, loss = %.4f '
                              '(%.1f examples/sec; %.3f sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, duration))

            if step % 200 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == g_config['max_steps']:
                checkpoint_path = os.path.join(g_config['train_dir'],
                                               'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
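
# NumPy sketch of the alignment performed by get_init_shape inside train():
# the mean shape is scaled so the norm of its (centred) bounding box matches
# that of the 112x112 crop box, then translated to the crop centre. Purely
# illustrative and not called by the pipeline; the in-graph version above uses
# TensorFlow ops on tf_mean_shape.
def _align_mean_shape_to_box_example(mean_shape, box_size=112.0):
    """mean_shape: numpy array [num_landmarks, 2]; returns the initial shape."""
    def norm(x):
        return np.sqrt(np.sum(np.square(x - x.mean(axis=0))))

    min_xy = mean_shape.min(axis=0)
    max_xy = mean_shape.max(axis=0)
    # Bounding box of the mean shape (four corners) and of the crop box.
    mean_shape_bb = np.array([[min_xy[0], min_xy[1]],
                              [max_xy[0], min_xy[1]],
                              [max_xy[0], max_xy[1]],
                              [min_xy[0], max_xy[1]]])
    bb = np.array([[0.0, 0.0], [box_size, 0.0],
                   [box_size, box_size], [0.0, box_size]])
    ratio = norm(bb) / norm(mean_shape_bb)
    # Centre the mean shape, rescale it, and move it to the crop-box centre.
    return (mean_shape - mean_shape_bb.mean(axis=0)) * ratio + bb.mean(axis=0)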