def train(scope=''):
    """Train on dataset for a number of steps."""
    with tf.Graph().as_default() as graph, tf.device('/gpu:0'):
        # Global steps
        tf_global_step = tf.get_variable(
            'GlobalStep', [],
            initializer=tf.constant_initializer(0),
            trainable=False)
        # Learning rate
        tf_lr = tf.train.exponential_decay(g_config['learning_rate'],
                                           tf_global_step,
                                           g_config['learning_rate_step'],
                                           g_config['learning_rate_decay'],
                                           staircase=True,
                                           name='LearningRate')
        tf.summary.scalar('learning_rate', tf_lr, collections=['train'])

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(tf_lr)

        data_provider.prepare_images(g_config['train_dataset'].split(':'),
                                     num_patches=g_config['num_patches'],
                                     verbose=True)
        path_base = Path(g_config['train_dataset'].split(':')[0]).parent.parent
        _mean_shape = mio.import_pickle(path_base / 'reference_shape.pkl')
        _mean_shape = data_provider.align_reference_shape_to_112(_mean_shape)
        assert isinstance(_mean_shape, np.ndarray)
        assert _mean_shape.shape[0] == g_config['num_patches']
        tf_mean_shape = tf.constant(_mean_shape, dtype=tf.float32,
                                    name='MeanShape')

        def get_random_sample(image, shape, rotation_stddev=10):
            # Read a random image with landmarks and bb
            image = menpo.image.Image(image.transpose((2, 0, 1)), copy=False)
            image.landmarks['PTS'] = PointCloud(shape)
            if np.random.rand() < .5:
                image = utils.mirror_image(image)
            if np.random.rand() < .5:
                theta = np.random.normal(scale=rotation_stddev)
                rot = menpo.transform.rotate_ccw_about_centre(
                    image.landmarks['PTS'], theta)
                image = image.warp_to_shape(image.shape, rot)
            bb = image.landmarks['PTS'].bounding_box().points
            miny, minx = np.min(bb, 0)
            maxy, maxx = np.max(bb, 0)
            bbsize = max(maxx - minx, maxy - miny)
            center = [(miny + maxy) / 2., (minx + maxx) / 2.]
            shift = (np.random.rand(2) - 0.5) * 0.6 * bbsize
            image.landmarks['bb'] = PointCloud([
                [
                    center[0] - bbsize * 0.5 + shift[0],
                    center[1] - bbsize * 0.5 + shift[1]
                ],
                [
                    center[0] + bbsize * 0.5 + shift[0],
                    center[1] + bbsize * 0.5 + shift[1]
                ],
            ]).bounding_box()
            proportion = 1.0 / 6.0 + float(np.random.rand() - 0.5) / 6.
            image = image.crop_to_landmarks_proportion(proportion, group='bb')
            image = image.resize((112, 112))
            random_image = image.pixels.transpose(1, 2, 0).astype('float32')
            random_shape = image.landmarks['PTS'].points.astype('float32')
            return random_image, random_shape

        def decode_feature_and_augment(serialized):
            feature = {
                'train/image': tf.FixedLenFeature([], tf.string),
                'train/shape': tf.VarLenFeature(tf.float32),
            }
            features = tf.parse_single_example(serialized, features=feature)
            decoded_image = tf.decode_raw(features['train/image'], tf.float32)
            decoded_image = tf.reshape(decoded_image, (448, 448, 3))
            decoded_shape = tf.sparse.to_dense(features['train/shape'])
            decoded_shape = tf.reshape(decoded_shape,
                                       (g_config['num_patches'], 2))
            random_image, random_shape = tf.py_func(
                get_random_sample, [decoded_image, decoded_shape],
                [tf.float32, tf.float32],
                stateful=True,
                name='RandomSample')
            return data_provider.distort_color(random_image), random_shape

        def decode_feature(serialized):
            feature = {
                'validate/image': tf.FixedLenFeature([], tf.string),
                'validate/shape': tf.VarLenFeature(tf.float32),
            }
            features = tf.parse_single_example(serialized, features=feature)
            decoded_image = tf.decode_raw(features['validate/image'], tf.float32)
            decoded_image = tf.reshape(decoded_image, (112, 112, 3))
            decoded_shape = tf.sparse.to_dense(features['validate/shape'])
            decoded_shape = tf.reshape(decoded_shape,
                                       (g_config['num_patches'], 2))
            return decoded_image, decoded_shape

        with tf.name_scope('DataProvider'):
            tf_dataset = tf.data.TFRecordDataset(
                [str(path_base / 'train.bin')])
            tf_dataset = tf_dataset.repeat()
            tf_dataset = tf_dataset.map(decode_feature_and_augment,
                                        num_parallel_calls=5)
            tf_dataset = tf_dataset.batch(g_config['batch_size'], True)
            tf_dataset = tf_dataset.prefetch(1)
            tf_iterator = tf_dataset.make_one_shot_iterator()
            tf_images, tf_shapes = tf_iterator.get_next(name='Batch')
            tf_images.set_shape([g_config['batch_size'], 112, 112, 3])
            tf_shapes.set_shape([g_config['batch_size'], 73, 2])

            tf_dataset_v = tf.data.TFRecordDataset(
                [str(path_base / 'validate.bin')])
            tf_dataset_v = tf_dataset_v.repeat()
            tf_dataset_v = tf_dataset_v.map(decode_feature,
                                            num_parallel_calls=5)
            tf_dataset_v = tf_dataset_v.batch(50, True)
            tf_dataset_v = tf_dataset_v.prefetch(1)
            tf_iterator_v = tf_dataset_v.make_one_shot_iterator()
            tf_images_v, tf_shapes_v = tf_iterator_v.get_next(
                name='ValidateBatch')
            tf_images_v.set_shape([50, 112, 112, 3])
            tf_shapes_v.set_shape([50, 73, 2])

        print('Defining model...')
        with tf.device(g_config['train_device']):
            tf_model = mdm_model.MDMModel(tf_images, tf_shapes, tf_mean_shape,
                                          batch_size=g_config['batch_size'],
                                          num_patches=g_config['num_patches'],
                                          num_channels=3)
            tf_grads = opt.compute_gradients(tf_model.nme)
            with tf.name_scope('Validate'):
                tf_model_v = mdm_model.MDMModel(
                    tf_images_v, tf_shapes_v, tf_mean_shape,
                    batch_size=50,
                    num_patches=g_config['num_patches'],
                    num_channels=3,
                    is_training=False)
        tf.summary.histogram('dx', tf_model.prediction - tf_shapes,
                             collections=['train'])

        bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)

        # Add histograms for gradients.
        for grad, var in tf_grads:
            if grad is not None:
                tf.summary.histogram(var.op.name + '/gradients', grad,
                                     collections=['train'])

        # Apply the gradients to adjust the shared variables.
        with tf.name_scope('Optimizer', values=[tf_grads, tf_global_step]):
            apply_gradient_op = opt.apply_gradients(tf_grads,
                                                    global_step=tf_global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var, collections=['train'])

        with tf.name_scope('MovingAverage', values=[tf_global_step]):
            variable_averages = tf.train.ExponentialMovingAverage(
                g_config['MOVING_AVERAGE_DECAY'], tf_global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            variables_averages_op = variable_averages.apply(
                variables_to_average)

        # Group all updates into a single train op.
        bn_updates_op = tf.group(*bn_updates, name='BNGroup')
        train_op = tf.group(apply_gradient_op, variables_averages_op,
                            bn_updates_op, name='TrainGroup')

        # Create a saver.
        saver = tf.train.Saver()

        train_summary_op = tf.summary.merge_all('train')
        validate_summary_op = tf.summary.merge_all('validate')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(graph=graph, config=config)
        init = tf.global_variables_initializer()
        print('Initializing variables...')
        sess.run(init)
        print('Initialized variables.')

        # Assuming model_checkpoint_path looks something like:
        #   /ckpt/train/model.ckpt-0,
        # extract global_step from it.
        start_step = 0
        ckpt = tf.train.get_checkpoint_state(g_config['train_dir'])
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            start_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            print('%s: Restart from %s' %
                  (datetime.now(), g_config['train_dir']))
        else:
            ckpt = tf.train.get_checkpoint_state(g_config['ckpt_dir'])
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                tf_global_step_op = tf_global_step.assign(0)
                sess.run(tf_global_step_op)
                print('%s: Pre-trained model restored from %s' %
                      (datetime.now(), g_config['ckpt_dir']))

        train_writer = tf.summary.FileWriter(g_config['train_dir'] + '/train',
                                             sess.graph)
        validate_writer = tf.summary.FileWriter(
            g_config['train_dir'] + '/validate', sess.graph)

        print('Starting training...')
        # Integer division keeps the epoch-boundary modulo check below exact.
        steps_per_epoch = 15000 // g_config['batch_size']
        for step in range(start_step, g_config['max_steps']):
            if step % steps_per_epoch == 0:
                start_time = time.time()
                _, train_loss, train_summary = sess.run(
                    [train_op, tf_model.nme, train_summary_op])
                duration = time.time() - start_time
                validate_loss, validate_summary = sess.run(
                    [tf_model_v.nme, validate_summary_op])
                train_writer.add_summary(train_summary, step)
                validate_writer.add_summary(validate_summary, step)
                print('%s: step %d, loss = %.4f (%.3f sec/batch)' %
                      (datetime.now(), step, train_loss, duration))
                print('%s: step %d, validate loss = %.4f' %
                      (datetime.now(), step, validate_loss))
            else:
                start_time = time.time()
                _, train_loss = sess.run([train_op, tf_model.nme])
                duration = time.time() - start_time
                if step % 100 == 0:
                    print('%s: step %d, loss = %.4f (%.3f sec/batch)' %
                          (datetime.now(), step, train_loss, duration))
            assert not np.isnan(train_loss), 'Model diverged with loss = NaN'

            if step % steps_per_epoch == 0 or (step + 1) == g_config['max_steps']:
                checkpoint_path = os.path.join(g_config['train_dir'],
                                               'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
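
# Sketch only, not the repository's data_provider code: decode_feature_and_augment
# above expects each 'train/*' record to hold the raw float32 bytes of a
# 448 x 448 x 3 image and a flat float list of landmarks.  A minimal,
# hypothetical writer that produces compatible records could look like this
# (write_train_record is an illustrative helper, not an existing function).
def write_train_record(path, images, shapes):
    """images: iterable of (448, 448, 3) float32 arrays;
    shapes: iterable of (num_patches, 2) float32 landmark arrays."""
    with tf.python_io.TFRecordWriter(path) as writer:
        for image, shape in zip(images, shapes):
            example = tf.train.Example(features=tf.train.Features(feature={
                # Raw float32 bytes, matching tf.decode_raw(..., tf.float32).
                'train/image': tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[image.astype(np.float32).tobytes()])),
                # Flat float list, matching tf.VarLenFeature(tf.float32).
                'train/shape': tf.train.Feature(
                    float_list=tf.train.FloatList(
                        value=shape.astype(np.float32).ravel().tolist())),
            }))
            writer.write(example.SerializeToString())
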
def train(scope=''):
    """Train on dataset for a number of steps."""
    with tf.Graph().as_default() as graph, tf.device('/gpu:0'):
        # Global steps
        tf_global_step = tf.get_variable(
            'GlobalStep', [],
            initializer=tf.constant_initializer(0),
            trainable=False
        )
        # Learning rate
        tf_lr = tf.train.exponential_decay(
            g_config['learning_rate'],
            tf_global_step,
            g_config['learning_rate_step'],
            g_config['learning_rate_decay'],
            staircase=True,
            name='LearningRate'
        )
        tf.summary.scalar('learning_rate', tf_lr, collections=['train'])

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(tf_lr)

        data_provider.prepare_images(
            g_config['train_dataset'].split(':'),
            num_patches=g_config['num_patches'],
            verbose=True
        )
        path_base = Path(g_config['train_dataset'].split(':')[0]).parent.parent
        _mean_shape = mio.import_pickle(path_base / 'mean_shape.pkl')
        _mean_shape = data_provider.align_reference_shape_to_112(_mean_shape)
        assert isinstance(_mean_shape, np.ndarray)
        assert _mean_shape.shape[0] == g_config['num_patches']

        _negatives = []
        for mp_image in mio.import_images('Dataset/Neg/*.png', verbose=True):
            _negatives.append(
                mp_image.pixels.transpose(1, 2, 0).astype(np.float32))
        _num_negatives = len(_negatives)

        tf_mean_shape = tf.constant(_mean_shape, dtype=tf.float32,
                                    name='MeanShape')

        def get_random_sample(image, shape):
            # Occlude
            _O_AREA = 0.15
            _O_MIN_H = 0.15
            _O_MAX_H = 1.0
            if np.random.rand() < .3:
                rh = min(112, int((np.random.rand() * (_O_MAX_H - _O_MIN_H)
                                   + _O_MIN_H) * 112))
                rw = min(112, int(12544 * _O_AREA / rh))
                dy = int(np.random.rand() * (112 - rh))
                dx = int(np.random.rand() * (112 - rw))
                idx = int(np.random.rand() * _num_negatives)
                image[dy:dy + rh, dx:dx + rw] = np.minimum(
                    1.0, _negatives[idx][dy:dy + rh, dx:dx + rw]
                )
            return image, shape

        def decode_feature_and_augment(serialized):
            feature = {
                'train/image': tf.FixedLenFeature([], tf.string),
                'train/shape': tf.VarLenFeature(tf.float32),
            }
            features = tf.parse_single_example(serialized, features=feature)
            decoded_image = tf.decode_raw(features['train/image'], tf.float32)
            decoded_image = tf.reshape(decoded_image, (112, 112, 3))
            decoded_shape = tf.sparse.to_dense(features['train/shape'])
            decoded_shape = tf.reshape(decoded_shape,
                                       (g_config['num_patches'], 2))
            # decoded_image, decoded_shape = tf.py_func(
            #     get_random_sample, [decoded_image, decoded_shape],
            #     [tf.float32, tf.float32],
            #     stateful=True,
            #     name='RandomSample'
            # )
            return data_provider.distort_color(decoded_image), decoded_shape

        def decode_feature(serialized):
            feature = {
                'validate/image': tf.FixedLenFeature([], tf.string),
                'validate/shape': tf.VarLenFeature(tf.float32),
            }
            features = tf.parse_single_example(serialized, features=feature)
            decoded_image = tf.decode_raw(features['validate/image'], tf.float32)
            decoded_image = tf.reshape(decoded_image, (112, 112, 3))
            decoded_shape = tf.sparse.to_dense(features['validate/shape'])
            decoded_shape = tf.reshape(decoded_shape,
                                       (g_config['num_patches'], 2))
            return decoded_image, decoded_shape

        with tf.name_scope('DataProvider'):
            tf_dataset = tf.data.TFRecordDataset([
                str(path_base / 'train_0.bin'),
                str(path_base / 'train_1.bin'),
                str(path_base / 'train_2.bin'),
                str(path_base / 'train_3.bin')
            ])
            tf_dataset = tf_dataset.repeat()
            tf_dataset = tf_dataset.map(decode_feature_and_augment,
                                        num_parallel_calls=5)
            tf_dataset = tf_dataset.shuffle(480)
            tf_dataset = tf_dataset.batch(g_config['batch_size'], True)
            tf_dataset = tf_dataset.prefetch(1)
            tf_iterator = tf_dataset.make_one_shot_iterator()
            tf_images, tf_shapes = tf_iterator.get_next(name='Batch')
            tf_images.set_shape([g_config['batch_size'], 112, 112, 3])
            tf_shapes.set_shape([g_config['batch_size'], 75, 2])

            tf_dataset_v = tf.data.TFRecordDataset(
                [str(path_base / 'validate.bin')])
            tf_dataset_v = tf_dataset_v.repeat()
            tf_dataset_v = tf_dataset_v.map(decode_feature,
                                            num_parallel_calls=5)
            tf_dataset_v = tf_dataset_v.batch(50, True)
            tf_dataset_v = tf_dataset_v.prefetch(1)
            tf_iterator_v = tf_dataset_v.make_one_shot_iterator()
            tf_images_v, tf_shapes_v = tf_iterator_v.get_next(
                name='ValidateBatch')
            tf_images_v.set_shape([50, 112, 112, 3])
            tf_shapes_v.set_shape([50, 75, 2])

        print('Defining model...')
        with tf.device(g_config['train_device']):
            tf_model = mdm_model.MDMModel(
                tf_images, tf_shapes, tf_mean_shape,
                batch_size=g_config['batch_size'],
                num_patches=g_config['num_patches'],
                num_channels=3,
                multiplier=g_config['multiplier']
            )
            tf_grads = opt.compute_gradients(tf_model.nme)
            with tf.name_scope('Validate'):
                tf_model_v = mdm_model.MDMModel(
                    tf_images_v, tf_shapes_v, tf_mean_shape,
                    batch_size=50,
                    num_patches=g_config['num_patches'],
                    num_channels=3,
                    multiplier=g_config['multiplier'],
                    is_training=False
                )
        tf.summary.histogram('dx', tf_model.prediction - tf_shapes,
                             collections=['train'])

        bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)

        # Add histograms for gradients.
        for grad, var in tf_grads:
            if grad is not None:
                tf.summary.histogram(var.op.name + '/gradients', grad,
                                     collections=['train'])

        # Apply the gradients to adjust the shared variables.
        with tf.name_scope('Optimizer', values=[tf_grads, tf_global_step]):
            apply_gradient_op = opt.apply_gradients(tf_grads,
                                                    global_step=tf_global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var, collections=['train'])

        with tf.name_scope('MovingAverage', values=[tf_global_step]):
            variable_averages = tf.train.ExponentialMovingAverage(
                g_config['MOVING_AVERAGE_DECAY'], tf_global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            variables_averages_op = variable_averages.apply(variables_to_average)

        # Group all updates into a single train op.
        bn_updates_op = tf.group(*bn_updates, name='BNGroup')
        train_op = tf.group(
            apply_gradient_op, variables_averages_op, bn_updates_op,
            name='TrainGroup'
        )

        # Create a saver.
        saver = tf.train.Saver()

        train_summary_op = tf.summary.merge_all('train')
        validate_summary_op = tf.summary.merge_all('validate')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(graph=graph, config=config)
        init = tf.global_variables_initializer()
        print('Initializing variables...')
        sess.run(init)
        print('Initialized variables.')

        # Assuming model_checkpoint_path looks something like:
        #   /ckpt/train/model.ckpt-0,
        # extract global_step from it.
        start_step = 0
        ckpt = tf.train.get_checkpoint_state(g_config['train_dir'])
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            start_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            print('%s: Restart from %s' %
                  (datetime.now(), g_config['train_dir']))
        else:
            ckpt = tf.train.get_checkpoint_state(g_config['ckpt_dir'])
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                tf_global_step_op = tf_global_step.assign(0)
                sess.run(tf_global_step_op)
                print('%s: Pre-trained model restored from %s' %
                      (datetime.now(), g_config['ckpt_dir']))

        train_writer = tf.summary.FileWriter(g_config['train_dir'] + '/train',
                                             sess.graph)
        validate_writer = tf.summary.FileWriter(
            g_config['train_dir'] + '/validate', sess.graph)

        print('Starting training...')
        # Integer division keeps the epoch-boundary modulo check below exact.
        steps_per_epoch = 15000 // g_config['batch_size']
        for step in range(start_step, g_config['max_steps']):
            if step % steps_per_epoch == 0:
                start_time = time.time()
                _, train_loss, train_summary = sess.run(
                    [train_op, tf_model.nme, train_summary_op])
                duration = time.time() - start_time
                validate_loss, validate_summary = sess.run(
                    [tf_model_v.nme, validate_summary_op])
                train_writer.add_summary(train_summary, step)
                validate_writer.add_summary(validate_summary, step)
                print('%s: step %d, loss = %.4f (%.3f sec/batch)' %
                      (datetime.now(), step, train_loss, duration))
                print('%s: step %d, validate loss = %.4f' %
                      (datetime.now(), step, validate_loss))
            else:
                start_time = time.time()
                _, train_loss = sess.run([train_op, tf_model.nme])
                duration = time.time() - start_time
                if step % 100 == 0:
                    print('%s: step %d, loss = %.4f (%.3f sec/batch)' %
                          (datetime.now(), step, train_loss, duration))
            assert not np.isnan(train_loss), 'Model diverged with loss = NaN'

            if step % steps_per_epoch == 0 or (step + 1) == g_config['max_steps']:
                checkpoint_path = os.path.join(g_config['train_dir'],
                                               'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
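
# Sketch only, not the repository's code: if the commented-out occlusion
# augmentation in decode_feature_and_augment above were re-enabled, the
# tf.data pipeline would need a wrapper like this, because tf.py_func erases
# the static shapes that the later set_shape calls and the model expect.
# make_py_augment is a hypothetical helper; augment_fn stands in for the
# nested get_random_sample defined inside train().
def make_py_augment(augment_fn, num_patches):
    def _map_fn(image, shape):
        image, shape = tf.py_func(
            augment_fn, [image, shape], [tf.float32, tf.float32],
            stateful=True, name='RandomSample')
        # Restore the static shape information lost by tf.py_func.
        image.set_shape([112, 112, 3])
        shape.set_shape([num_patches, 2])
        return image, shape
    return _map_fn

# Hypothetical use inside train(), next to the existing map calls:
# tf_dataset = tf_dataset.map(make_py_augment(get_random_sample,
#                                             g_config['num_patches']),
#                             num_parallel_calls=5)
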
def train(scope=''):
    """Train on dataset for a number of steps."""
    with tf.Graph().as_default(), tf.device('/gpu:0'):
        # Global steps
        tf_global_step = tf.get_variable(
            'GlobalStep', [],
            initializer=tf.constant_initializer(0),
            trainable=False)
        # Learning rate
        tf_lr = tf.train.exponential_decay(g_config['learning_rate'],
                                           tf_global_step,
                                           g_config['learning_rate_step'],
                                           g_config['learning_rate_decay'],
                                           staircase=True,
                                           name='LearningRate')
        tf.summary.scalar('learning_rate', tf_lr)

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(tf_lr)

        data_provider.prepare_images(g_config['train_dataset'].split(':'),
                                     num_patches=g_config['num_patches'],
                                     verbose=True)
        path_base = Path(g_config['train_dataset'].split(':')[0]).parent.parent
        _mean_shape = mio.import_pickle(path_base / 'reference_shape.pkl')
        with Path(path_base / 'meta.txt').open('r') as ifs:
            _image_shape = [int(x) for x in ifs.read().split(' ')]
        assert isinstance(_mean_shape, np.ndarray)

        _pca_shapes = []
        _pca_bbs = []
        for item in tf.io.tf_record_iterator(str(path_base / 'pca.bin')):
            example = tf.train.Example()
            example.ParseFromString(item)
            _pca_shape = np.array(
                example.features.feature['pca/shape'].float_list.value
            ).reshape((-1, 2))
            _pca_bb = np.array(
                example.features.feature['pca/bb'].float_list.value
            ).reshape((-1, 2))
            _pca_shapes.append(PointCloud(_pca_shape))
            _pca_bbs.append(PointCloud(_pca_bb))
        _pca_model = detect.create_generator(_pca_shapes, _pca_bbs)

        assert _mean_shape.shape[0] == g_config['num_patches']
        tf_mean_shape = tf.constant(_mean_shape, dtype=tf.float32,
                                    name='MeanShape')

        def decode_feature(serialized):
            feature = {
                'train/image': tf.FixedLenFeature([], tf.string),
                'train/shape': tf.VarLenFeature(tf.float32),
            }
            features = tf.parse_single_example(serialized, features=feature)
            decoded_image = tf.decode_raw(features['train/image'], tf.float32)
            decoded_image = tf.reshape(decoded_image, _image_shape)
            decoded_shape = tf.sparse.to_dense(features['train/shape'])
            decoded_shape = tf.reshape(decoded_shape,
                                       (g_config['num_patches'], 2))
            return decoded_image, decoded_shape

        def get_random_sample(image, shape, rotation_stddev=10):
            # Read a random image with landmarks and bb
            image = menpo.image.Image(image.transpose((2, 0, 1)), copy=False)
            image.landmarks['PTS'] = PointCloud(shape)
            if np.random.rand() < .5:
                image = utils.mirror_image(image)
            if np.random.rand() < .5:
                theta = np.random.normal(scale=rotation_stddev)
                rot = menpo.transform.rotate_ccw_about_centre(
                    image.landmarks['PTS'], theta)
                image = image.warp_to_shape(image.shape, rot)
            bb = image.landmarks['PTS'].bounding_box().points
            miny, minx = np.min(bb, 0)
            maxy, maxx = np.max(bb, 0)
            bbsize = max(maxx - minx, maxy - miny)
            center = [(miny + maxy) / 2., (minx + maxx) / 2.]
            image.landmarks['bb'] = PointCloud([
                [center[0] - bbsize * 0.5, center[1] - bbsize * 0.5],
                [center[0] + bbsize * 0.5, center[1] + bbsize * 0.5],
            ]).bounding_box()
            proportion = float(np.random.rand() / 3)
            image = image.crop_to_landmarks_proportion(proportion, group='bb')
            image = image.resize((112, 112))
            random_image = image.pixels.transpose(1, 2, 0).astype('float32')
            random_shape = image.landmarks['PTS'].points.astype('float32')
            return random_image, random_shape

        def get_init_shape(image, shape, mean_shape):
            def norm(x):
                return tf.sqrt(
                    tf.reduce_sum(tf.square(x - tf.reduce_mean(x, 0))))

            with tf.name_scope('align_shape_to_bb', values=[mean_shape]):
                min_xy = tf.reduce_min(mean_shape, 0)
                max_xy = tf.reduce_max(mean_shape, 0)
                min_x, min_y = min_xy[0], min_xy[1]
                max_x, max_y = max_xy[0], max_xy[1]
                mean_shape_bb = tf.stack([[min_x, min_y], [max_x, min_y],
                                          [max_x, max_y], [min_x, max_y]])
                bb = tf.stack([[0.0, 0.0], [112.0, 0.0],
                               [112.0, 112.0], [0.0, 112.0]])
                ratio = norm(bb) / norm(mean_shape_bb)
                initial_shape = tf.add(
                    (mean_shape - tf.reduce_mean(mean_shape_bb, 0)) * ratio,
                    tf.reduce_mean(bb, 0),
                    name='initial_shape')
                initial_shape.set_shape(tf_mean_shape.get_shape())
            return image, shape, initial_shape

        def distort_color(image, shape, init_shape):
            return data_provider.distort_color(image), shape, init_shape

        with tf.name_scope('DataProvider', values=[tf_mean_shape]):
            tf_dataset = tf.data.TFRecordDataset(
                [str(path_base / 'train.bin')])
            tf_dataset = tf_dataset.repeat()
            tf_dataset = tf_dataset.map(decode_feature)
            tf_dataset = tf_dataset.map(lambda x, y: tf.py_func(
                get_random_sample, [x, y], [tf.float32, tf.float32],
                stateful=True,
                name='RandomSample'))
            tf_dataset = tf_dataset.map(
                partial(get_init_shape, mean_shape=tf_mean_shape))
            tf_dataset = tf_dataset.map(distort_color)
            tf_dataset = tf_dataset.batch(g_config['batch_size'], True)
            tf_dataset = tf_dataset.prefetch(7500)
            tf_iterator = tf_dataset.make_one_shot_iterator()
            tf_images, tf_shapes, tf_initial_shapes = tf_iterator.get_next(
                name='Batch')
            tf_images.set_shape([g_config['batch_size'], 112, 112, 3])
            tf_shapes.set_shape([g_config['batch_size'], 73, 2])
            tf_initial_shapes.set_shape([g_config['batch_size'], 73, 2])

        print('Defining model...')
        with tf.device(g_config['train_device']):
            tf_model = mdm_model.MDMModel(
                tf_images, tf_shapes, tf_initial_shapes,
                batch_size=g_config['batch_size'],
                num_iterations=g_config['num_iterations'],
                num_patches=g_config['num_patches'],
                patch_shape=(g_config['patch_size'], g_config['patch_size']),
                num_channels=3)
            with tf.name_scope('Losses',
                               values=[tf_model.prediction, tf_shapes]):
                tf_norm_error = tf_model.normalized_rmse(
                    tf_model.prediction, tf_shapes)
                tf_loss = tf.reduce_mean(tf_norm_error)
                tf.summary.scalar('losses/total', tf_loss)
            # Calculate the gradients for the batch of data
            tf_grads = opt.compute_gradients(tf_loss)
        tf.summary.histogram('dx', tf_model.prediction - tf_shapes)

        bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)

        # Add histograms for gradients.
        for grad, var in tf_grads:
            if grad is not None:
                tf.summary.histogram(var.op.name + '/gradients', grad)

        # Apply the gradients to adjust the shared variables.
        with tf.name_scope('Optimizer', values=[tf_grads, tf_global_step]):
            apply_gradient_op = opt.apply_gradients(tf_grads,
                                                    global_step=tf_global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)

        # Track the moving averages of all trainable variables.
        # Note that we maintain a "double-average" of the BatchNormalization
        # global statistics. This is more complicated than it needs to be, but
        # we keep it for backward-compatibility with our previous models.
        with tf.name_scope('MovingAverage', values=[tf_global_step]):
            variable_averages = tf.train.ExponentialMovingAverage(
                g_config['MOVING_AVERAGE_DECAY'], tf_global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            variables_averages_op = variable_averages.apply(
                variables_to_average)

        # Group all updates into a single train op.
        bn_updates_op = tf.group(*bn_updates, name='BNGroup')
        train_op = tf.group(apply_gradient_op, variables_averages_op,
                            bn_updates_op, name='TrainGroup')

        # Create a saver.
        saver = tf.train.Saver()

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph. allow_soft_placement must be
        # set to True to build towers on GPU, as some of the ops do not have
        # GPU implementations.
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()
        print('Initializing variables...')
        sess.run(init)
        print('Initialized variables.')

        start_step = 0
        ckpt = tf.train.get_checkpoint_state(g_config['train_dir'])
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Assuming model_checkpoint_path looks something like:
            #   /ckpt/train/model.ckpt-0,
            # extract global_step from it.
            start_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), g_config['train_dir']))

        summary_writer = tf.summary.FileWriter(g_config['train_dir'],
                                               sess.graph)

        print('Starting training...')
        for step in range(start_step, g_config['max_steps']):
            start_time = time.time()
            _, loss_value = sess.run([train_op, tf_loss])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 100 == 0:
                examples_per_sec = g_config['batch_size'] / float(duration)
                format_str = ('%s: step %d, loss = %.4f (%.1f examples/sec; '
                              '%.3f sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, duration))

            if step % 200 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == g_config['max_steps']:
                checkpoint_path = os.path.join(g_config['train_dir'],
                                               'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
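
# Sketch only: a stand-alone NumPy restatement of the geometry computed by
# get_init_shape above, useful for checking the alignment offline.  It is not
# part of the training graph; init_shape_numpy is an illustrative helper.
def init_shape_numpy(mean_shape):
    """Scale and translate mean_shape so its bounding box matches the
    112 x 112 crop, mirroring the TF ops in get_init_shape."""
    def norm(x):
        return np.sqrt(np.sum(np.square(x - x.mean(axis=0))))

    min_xy = mean_shape.min(axis=0)
    max_xy = mean_shape.max(axis=0)
    mean_shape_bb = np.array([[min_xy[0], min_xy[1]],
                              [max_xy[0], min_xy[1]],
                              [max_xy[0], max_xy[1]],
                              [min_xy[0], max_xy[1]]])
    bb = np.array([[0.0, 0.0], [112.0, 0.0], [112.0, 112.0], [0.0, 112.0]])
    ratio = norm(bb) / norm(mean_shape_bb)
    return (mean_shape - mean_shape_bb.mean(axis=0)) * ratio + bb.mean(axis=0)
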
def train(scope=''):
    """Train on dataset for a number of steps."""
    with tf.Graph().as_default() as graph, tf.device('/gpu:0'):
        # Global steps
        tf_global_step = tf.get_variable(
            'GlobalStep', [],
            initializer=tf.constant_initializer(0),
            trainable=False)
        # Learning rate
        tf_lr = tf.train.exponential_decay(g_config['learning_rate'],
                                           tf_global_step,
                                           g_config['learning_rate_step'],
                                           g_config['learning_rate_decay'],
                                           staircase=True,
                                           name='LearningRate')
        tf.summary.scalar('learning_rate', tf_lr)

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(tf_lr)

        data_provider.prepare_images(g_config['train_dataset'].split(':'),
                                     num_patches=g_config['num_patches'],
                                     verbose=True)
        path_base = Path(g_config['train_dataset'].split(':')[0]).parent.parent
        _mean_shape = mio.import_pickle(path_base / 'reference_shape.pkl')
        _mean_shape = data_provider.align_reference_shape_to_112(_mean_shape)
        assert isinstance(_mean_shape, np.ndarray)
        assert _mean_shape.shape[0] == g_config['num_patches']
        tf_mean_shape = tf.constant(_mean_shape, dtype=tf.float32,
                                    name='MeanShape')

        def decode_feature(serialized):
            feature = {
                'train/image': tf.FixedLenFeature([], tf.string),
                'train/shape': tf.VarLenFeature(tf.float32),
            }
            features = tf.parse_single_example(serialized, features=feature)
            decoded_image = tf.decode_raw(features['train/image'], tf.float32)
            decoded_image = tf.reshape(decoded_image, (336, 336, 3))
            decoded_shape = tf.sparse.to_dense(features['train/shape'])
            decoded_shape = tf.reshape(decoded_shape,
                                       (g_config['num_patches'], 2))
            return decoded_image, decoded_shape

        def get_random_sample(image, shape, rotation_stddev=10):
            # Read a random image with landmarks and bb
            image = menpo.image.Image(image.transpose((2, 0, 1)), copy=False)
            image.landmarks['PTS'] = PointCloud(shape)
            if np.random.rand() < .5:
                image = utils.mirror_image(image)
            if np.random.rand() < .5:
                theta = np.random.normal(scale=rotation_stddev)
                rot = menpo.transform.rotate_ccw_about_centre(
                    image.landmarks['PTS'], theta)
                image = image.warp_to_shape(image.shape, rot)
            bb = image.landmarks['PTS'].bounding_box().points
            miny, minx = np.min(bb, 0)
            maxy, maxx = np.max(bb, 0)
            bbsize = max(maxx - minx, maxy - miny)
            center = [(miny + maxy) / 2., (minx + maxx) / 2.]
            shift = (np.random.rand(2) - 0.5) / 6. * bbsize
            image.landmarks['bb'] = PointCloud([
                [
                    center[0] - bbsize * 0.5 + shift[0],
                    center[1] - bbsize * 0.5 + shift[1]
                ],
                [
                    center[0] + bbsize * 0.5 + shift[0],
                    center[1] + bbsize * 0.5 + shift[1]
                ],
            ]).bounding_box()
            proportion = 1.0 / 6.0 + float(np.random.rand() - 0.5) / 10.0
            image = image.crop_to_landmarks_proportion(proportion, group='bb')
            image = image.resize((112, 112))
            random_image = image.pixels.transpose(1, 2, 0).astype('float32')
            random_shape = image.landmarks['PTS'].points.astype('float32')
            return random_image, random_shape

        def distort_color(image, shape):
            return data_provider.distort_color(image), shape

        with tf.name_scope('DataProvider'):
            tf_dataset = tf.data.TFRecordDataset(
                [str(path_base / 'train.bin')])
            tf_dataset = tf_dataset.repeat()
            tf_dataset = tf_dataset.map(decode_feature)
            tf_dataset = tf_dataset.map(lambda x, y: tf.py_func(
                get_random_sample, [x, y], [tf.float32, tf.float32],
                stateful=True,
                name='RandomSample'))
            tf_dataset = tf_dataset.map(distort_color)
            tf_dataset = tf_dataset.batch(g_config['batch_size'], True)
            tf_dataset = tf_dataset.prefetch(7500)
            tf_iterator = tf_dataset.make_one_shot_iterator()
            tf_images, tf_shapes = tf_iterator.get_next(name='Batch')
            tf_images.set_shape([g_config['batch_size'], 112, 112, 3])
            tf_shapes.set_shape([g_config['batch_size'], 73, 2])

        print('Defining model...')
        with tf.device(g_config['train_device']):
            tf_model = mdm_model.MDMModel(tf_images, tf_shapes, tf_mean_shape,
                                          batch_size=g_config['batch_size'],
                                          num_patches=g_config['num_patches'],
                                          num_channels=3)
            with tf.name_scope('Losses',
                               values=[tf_model.prediction, tf_shapes]):
                tf_norm_error = tf_model.normalized_rmse(
                    tf_model.prediction, tf_shapes)
                tf_loss = tf.reduce_mean(tf_norm_error)
                tf.summary.scalar('losses/total', tf_loss)
            # Calculate the gradients for the batch of data
            tf_grads = opt.compute_gradients(tf_loss)
        tf.summary.histogram('dx', tf_model.prediction - tf_shapes)

        bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)

        # Add histograms for gradients.
        for grad, var in tf_grads:
            if grad is not None:
                tf.summary.histogram(var.op.name + '/gradients', grad)

        # Apply the gradients to adjust the shared variables.
        with tf.name_scope('Optimizer', values=[tf_grads, tf_global_step]):
            apply_gradient_op = opt.apply_gradients(tf_grads,
                                                    global_step=tf_global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)

        with tf.name_scope('MovingAverage', values=[tf_global_step]):
            variable_averages = tf.train.ExponentialMovingAverage(
                g_config['MOVING_AVERAGE_DECAY'], tf_global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            variables_averages_op = variable_averages.apply(
                variables_to_average)

        # Group all updates into a single train op.
        bn_updates_op = tf.group(*bn_updates, name='BNGroup')
        train_op = tf.group(apply_gradient_op, variables_averages_op,
                            bn_updates_op, name='TrainGroup')

        # Create a saver.
        saver = tf.train.Saver()

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph. allow_soft_placement must be
        # set to True to build towers on GPU, as some of the ops do not have
        # GPU implementations.
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()
        print('Initializing variables...')
        sess.run(init)
        print('Initialized variables.')

        # Assuming model_checkpoint_path looks something like:
        #   /ckpt/train/model.ckpt-0,
        # extract global_step from it.
        start_step = 0
        ckpt = tf.train.get_checkpoint_state(g_config['ckpt_dir'])
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), g_config['ckpt_dir']))
        else:
            ckpt = tf.train.get_checkpoint_state(g_config['train_dir'])
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                start_step = int(
                    ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
                print('%s: Restart from %s' %
                      (datetime.now(), g_config['train_dir']))
            elif TUNE:
                assign_op = []
                vvv = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
                cnt = 0
                for v in vvv:
                    if v.name in tf_model.var_map:
                        assign_op.append(
                            v.assign(
                                graph.get_tensor_by_name(
                                    tf_model.var_map[v.name])))
                        cnt += 1
                    else:
                        print(v.name)
                sess.run(assign_op)
                print('%s: Tune from graph.pb %d/%d' %
                      (datetime.now(), cnt, len(vvv)))

        summary_writer = tf.summary.FileWriter(g_config['train_dir'],
                                               sess.graph)

        print('Starting training...')
        for step in range(start_step, g_config['max_steps']):
            start_time = time.time()
            _, loss_value = sess.run([train_op, tf_loss])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 100 == 0:
                examples_per_sec = g_config['batch_size'] / float(duration)
                format_str = ('%s: step %d, loss = %.4f (%.1f examples/sec; '
                              '%.3f sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, duration))

            if step % 200 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == g_config['max_steps']:
                checkpoint_path = os.path.join(g_config['train_dir'],
                                               'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
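
# Evaluation-side sketch, not part of this training script: the MovingAverage
# scope above maintains exponential moving averages of every trainable
# variable, and at evaluation time those shadow values are typically restored
# in place of the raw weights.  build_eval_saver is a hypothetical helper; the
# decay argument must match g_config['MOVING_AVERAGE_DECAY'] used in training.
def build_eval_saver(moving_average_decay):
    variable_averages = tf.train.ExponentialMovingAverage(moving_average_decay)
    # Map EMA shadow names in the checkpoint back onto the graph's variables.
    variables_to_restore = variable_averages.variables_to_restore()
    return tf.train.Saver(variables_to_restore)

# Hypothetical use, assuming the evaluation graph and session already exist:
# saver = build_eval_saver(g_config['MOVING_AVERAGE_DECAY'])
# ckpt = tf.train.get_checkpoint_state(g_config['train_dir'])
# saver.restore(sess, ckpt.model_checkpoint_path)
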