def ckpt_pb(pb_path):
    """Freeze the latest training checkpoint into a serialized GraphDef file.

    Rebuilds the inference graph (batch size 1, a 112x112x3 image input and a
    dummy ground-truth shape input), restores the latest checkpoint found in
    ``g_config['train_dir']``, folds all variables into constants and writes
    the result to ``pb_path``. The frozen graph keeps the output node
    ``Network/Predict/add``.

    Args:
        pb_path: Destination path of the frozen ``.pb`` graph.

    Returns:
        None. Prints a message and returns early if no checkpoint is found.
    """
    tf.reset_default_graph()
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    # Hand the config to the session; previously it was built but unused, so
    # soft placement and allow_growth never took effect.
    with tf.Session(config=config) as sess:
        num_patches = g_config['num_patches']
        path_base = Path(g_config['eval_dataset']).parent.parent
        _mean_shape = mio.import_pickle(path_base / 'reference_shape.pkl')
        _mean_shape = data_provider.align_reference_shape_to_112(_mean_shape)
        assert isinstance(_mean_shape, np.ndarray)
        print(_mean_shape.shape)

        tf_img = tf.placeholder(dtype=tf.float32,
                                shape=(1, 112, 112, 3),
                                name='inputs/input_img')
        # The model constructor takes a ground-truth shape tensor; for export
        # it is only a placeholder with the right shape.
        tf_dummy = tf.placeholder(dtype=tf.float32,
                                  shape=(1, num_patches, 2),
                                  name='inputs/input_shape')
        tf_shape = tf.constant(_mean_shape,
                               dtype=tf.float32,
                               shape=(num_patches, 2),
                               name='MeanShape')

        # The instance itself is not needed: constructing it adds the
        # inference ops (including the exported output node) to the graph.
        mdm_model.MDMModel(tf_img,
                           tf_dummy,
                           tf_shape,
                           batch_size=1,
                           num_patches=num_patches,
                           num_channels=3,
                           is_training=False)

        # NOTE(review): this restores the raw variables, whereas evaluate()
        # restores the ExponentialMovingAverage shadow variables - confirm
        # which weights the exported graph is meant to contain.
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(g_config['train_dir'])
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Checkpoint paths look like .../model.ckpt-<step>.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]
            print('Successfully loaded model from {} at step={}.'.format(
                ckpt.model_checkpoint_path, global_step))
        else:
            print('No checkpoint file found')
            return

        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(), ['Network/Predict/add'])

        with tf.gfile.FastGFile(pb_path, mode='wb') as f:
            f.write(output_graph_def.SerializeToString())
# Example #2
# 0
def train(scope=''):
    """Train the MDM model with periodic validation and checkpointing.

    Builds a TF1 graph with two input pipelines: an augmented training
    pipeline (``train.bin``) and a non-augmented validation pipeline
    (``validate.bin``). Both files, and ``reference_shape.pkl``, are looked up
    two directories above the first ``:``-separated entry of
    ``g_config['train_dataset']``.

    If ``g_config['train_dir']`` holds a checkpoint, training resumes from it.
    Otherwise, if ``g_config['ckpt_dir']`` holds one, that checkpoint is
    restored and the global step is reset to 0. Once per "epoch" (see
    ``steps_per_epoch``) the train and validate summaries are written and a
    checkpoint is saved. A final checkpoint is saved at the last step.

    Args:
        scope: Scope filter passed to ``tf.get_collection`` when gathering the
            graph's ``UPDATE_OPS``; these are grouped into the train op.
    """
    with tf.Graph().as_default() as graph, tf.device('/gpu:0'):
        num_patches = g_config['num_patches']
        batch_size = g_config['batch_size']
        validate_batch_size = 50

        # Global steps
        tf_global_step = tf.get_variable(
            'GlobalStep', [],
            initializer=tf.constant_initializer(0),
            trainable=False)

        # Learning rate: staircase exponential decay driven by the global step.
        tf_lr = tf.train.exponential_decay(g_config['learning_rate'],
                                           tf_global_step,
                                           g_config['learning_rate_step'],
                                           g_config['learning_rate_decay'],
                                           staircase=True,
                                           name='LearningRate')
        tf.summary.scalar('learning_rate', tf_lr, collections=['train'])

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(tf_lr)

        data_provider.prepare_images(g_config['train_dataset'].split(':'),
                                     num_patches=num_patches,
                                     verbose=True)
        path_base = Path(g_config['train_dataset'].split(':')[0]).parent.parent
        _mean_shape = mio.import_pickle(path_base / 'reference_shape.pkl')
        _mean_shape = data_provider.align_reference_shape_to_112(_mean_shape)
        assert isinstance(_mean_shape, np.ndarray)
        assert _mean_shape.shape[0] == num_patches

        tf_mean_shape = tf.constant(_mean_shape,
                                    dtype=tf.float32,
                                    name='MeanShape')

        def get_random_sample(image, shape, rotation_stddev=10):
            """Randomly augment one numpy image (H, W, C) and its landmarks.

            With probability 0.5 each, the image is mirrored
            (``utils.mirror_image``) and rotated. A jittered square box around
            the landmarks is then cropped and resized to 112x112. Returns the
            float32 pair ``(image (112, 112, C), landmarks)``.
            """
            image = menpo.image.Image(image.transpose((2, 0, 1)), copy=False)
            image.landmarks['PTS'] = PointCloud(shape)

            if np.random.rand() < .5:
                image = utils.mirror_image(image)
            if np.random.rand() < .5:
                # theta ~ N(0, rotation_stddev). Its unit is whatever menpo's
                # rotate_ccw_about_centre defaults to (presumably degrees).
                theta = np.random.normal(scale=rotation_stddev)
                rot = menpo.transform.rotate_ccw_about_centre(
                    image.landmarks['PTS'], theta)
                image = image.warp_to_shape(image.shape, rot)
            bb = image.landmarks['PTS'].bounding_box().points
            miny, minx = np.min(bb, 0)
            maxy, maxx = np.max(bb, 0)
            # Square box whose side is the longer landmark extent. Its centre
            # is shifted by up to +/-30% of that side along each axis.
            bbsize = max(maxx - minx, maxy - miny)
            center = [(miny + maxy) / 2., (minx + maxx) / 2.]
            shift = (np.random.rand(2) - 0.5) * 0.6 * bbsize
            image.landmarks['bb'] = PointCloud([
                [
                    center[0] - bbsize * 0.5 + shift[0],
                    center[1] - bbsize * 0.5 + shift[1]
                ],
                [
                    center[0] + bbsize * 0.5 + shift[0],
                    center[1] + bbsize * 0.5 + shift[1]
                ],
            ]).bounding_box()
            # Crop margin proportion drawn uniformly from [1/12, 1/4).
            proportion = 1.0 / 6.0 + float(np.random.rand() - 0.5) / 6.
            image = image.crop_to_landmarks_proportion(proportion, group='bb')
            image = image.resize((112, 112))

            random_image = image.pixels.transpose(1, 2, 0).astype('float32')
            random_shape = image.landmarks['PTS'].points.astype('float32')
            return random_image, random_shape

        def decode_feature_and_augment(serialized):
            """Parse a training record (raw 448x448x3 float32 image) and augment it."""
            feature = {
                'train/image': tf.FixedLenFeature([], tf.string),
                'train/shape': tf.VarLenFeature(tf.float32),
            }
            features = tf.parse_single_example(serialized, features=feature)
            decoded_image = tf.decode_raw(features['train/image'], tf.float32)
            decoded_image = tf.reshape(decoded_image, (448, 448, 3))
            decoded_shape = tf.sparse.to_dense(features['train/shape'])
            decoded_shape = tf.reshape(decoded_shape, (num_patches, 2))

            # The geometric augmentation runs in Python (menpo) via py_func.
            random_image, random_shape = tf.py_func(
                get_random_sample, [decoded_image, decoded_shape],
                [tf.float32, tf.float32],
                stateful=True,
                name='RandomSample')
            return data_provider.distort_color(random_image), random_shape

        def decode_feature(serialized):
            """Parse a validation record (raw 112x112x3 float32 image)."""
            feature = {
                'validate/image': tf.FixedLenFeature([], tf.string),
                'validate/shape': tf.VarLenFeature(tf.float32),
            }
            features = tf.parse_single_example(serialized, features=feature)
            decoded_image = tf.decode_raw(features['validate/image'],
                                          tf.float32)
            decoded_image = tf.reshape(decoded_image, (112, 112, 3))
            decoded_shape = tf.sparse.to_dense(features['validate/shape'])
            decoded_shape = tf.reshape(decoded_shape, (num_patches, 2))
            return decoded_image, decoded_shape

        with tf.name_scope('DataProvider'):
            tf_dataset = tf.data.TFRecordDataset(
                [str(path_base / 'train.bin')])
            tf_dataset = tf_dataset.repeat()
            tf_dataset = tf_dataset.map(decode_feature_and_augment,
                                        num_parallel_calls=5)
            tf_dataset = tf_dataset.batch(batch_size, True)
            tf_dataset = tf_dataset.prefetch(1)
            tf_iterator = tf_dataset.make_one_shot_iterator()
            tf_images, tf_shapes = tf_iterator.get_next(name='Batch')
            tf_images.set_shape([batch_size, 112, 112, 3])
            tf_shapes.set_shape([batch_size, num_patches, 2])

            tf_dataset_v = tf.data.TFRecordDataset(
                [str(path_base / 'validate.bin')])
            tf_dataset_v = tf_dataset_v.repeat()
            tf_dataset_v = tf_dataset_v.map(decode_feature,
                                            num_parallel_calls=5)
            tf_dataset_v = tf_dataset_v.batch(validate_batch_size, True)
            tf_dataset_v = tf_dataset_v.prefetch(1)
            tf_iterator_v = tf_dataset_v.make_one_shot_iterator()
            tf_images_v, tf_shapes_v = tf_iterator_v.get_next(
                name='ValidateBatch')
            tf_images_v.set_shape([validate_batch_size, 112, 112, 3])
            tf_shapes_v.set_shape([validate_batch_size, num_patches, 2])

        print('Defining model...')
        with tf.device(g_config['train_device']):
            tf_model = mdm_model.MDMModel(tf_images,
                                          tf_shapes,
                                          tf_mean_shape,
                                          batch_size=batch_size,
                                          num_patches=num_patches,
                                          num_channels=3)
            tf_grads = opt.compute_gradients(tf_model.nme)
            with tf.name_scope('Validate'):
                tf_model_v = mdm_model.MDMModel(
                    tf_images_v,
                    tf_shapes_v,
                    tf_mean_shape,
                    batch_size=validate_batch_size,
                    num_patches=num_patches,
                    num_channels=3,
                    is_training=False)
        tf.summary.histogram('dx',
                             tf_model.prediction - tf_shapes,
                             collections=['train'])

        bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)

        # Add histograms for gradients.
        for grad, var in tf_grads:
            if grad is not None:
                tf.summary.histogram(var.op.name + '/gradients',
                                     grad,
                                     collections=['train'])

        # Apply the gradients to adjust the shared variables.
        with tf.name_scope('Optimizer', values=[tf_grads, tf_global_step]):
            apply_gradient_op = opt.apply_gradients(tf_grads,
                                                    global_step=tf_global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var, collections=['train'])

        with tf.name_scope('MovingAverage', values=[tf_global_step]):
            variable_averages = tf.train.ExponentialMovingAverage(
                g_config['MOVING_AVERAGE_DECAY'], tf_global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            variables_averages_op = variable_averages.apply(
                variables_to_average)

        # Group all updates into a single train op.
        bn_updates_op = tf.group(*bn_updates, name='BNGroup')
        train_op = tf.group(apply_gradient_op,
                            variables_averages_op,
                            bn_updates_op,
                            name='TrainGroup')

        # Create a saver.
        saver = tf.train.Saver()

        train_summary_op = tf.summary.merge_all('train')
        # NOTE(review): merge_all returns None when the 'validate' collection
        # is empty. Nothing in this function adds to it, so this relies on
        # MDMModel registering validation summaries - confirm.
        validate_summary_op = tf.summary.merge_all('validate')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(graph=graph, config=config)
        init = tf.global_variables_initializer()
        print('Initializing variables...')
        sess.run(init)
        print('Initialized variables.')

        # Assuming model_checkpoint_path looks something like:
        #   /ckpt/train/model.ckpt-0,
        # extract global_step from it.
        start_step = 0
        ckpt = tf.train.get_checkpoint_state(g_config['train_dir'])
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            start_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            print('%s: Restart from %s' %
                  (datetime.now(), g_config['train_dir']))
        else:
            ckpt = tf.train.get_checkpoint_state(g_config['ckpt_dir'])
            if ckpt and ckpt.model_checkpoint_path:
                # Warm start from a pre-trained model; restart the step count.
                saver.restore(sess, ckpt.model_checkpoint_path)
                tf_global_step_op = tf_global_step.assign(0)
                sess.run(tf_global_step_op)
                print('%s: Pre-trained model restored from %s' %
                      (datetime.now(), g_config['ckpt_dir']))

        train_writer = tf.summary.FileWriter(g_config['train_dir'] + '/train',
                                             sess.graph)
        validate_writer = tf.summary.FileWriter(
            g_config['train_dir'] + '/validate', sess.graph)

        print('Starting training...')
        # NOTE(review): 15000 is presumably the number of training examples
        # per epoch - confirm.
        # Integer division keeps `step % steps_per_epoch` exact. With true
        # division, a batch size that does not divide 15000 (e.g. 64) gives a
        # float period, and validation/checkpointing almost never fire.
        # max(1, ...) avoids a modulo-by-zero for very large batches.
        steps_per_epoch = max(1, 15000 // batch_size)
        try:
            for step in range(start_step, g_config['max_steps']):
                is_epoch_boundary = step % steps_per_epoch == 0
                if is_epoch_boundary:
                    start_time = time.time()
                    _, train_loss, train_summary = sess.run(
                        [train_op, tf_model.nme, train_summary_op])
                    duration = time.time() - start_time
                    validate_loss, validate_summary = sess.run(
                        [tf_model_v.nme, validate_summary_op])
                    train_writer.add_summary(train_summary, step)
                    validate_writer.add_summary(validate_summary, step)

                    print('%s: step %d, loss = %.4f (%.3f sec/batch)' %
                          (datetime.now(), step, train_loss, duration))
                    print('%s: step %d, validate loss = %.4f' %
                          (datetime.now(), step, validate_loss))
                else:
                    start_time = time.time()
                    _, train_loss = sess.run([train_op, tf_model.nme])
                    duration = time.time() - start_time
                    if step % 100 == 0:
                        print('%s: step %d, loss = %.4f (%.3f sec/batch)' %
                              (datetime.now(), step, train_loss, duration))

                assert not np.isnan(train_loss), 'Model diverged with loss = NaN'

                if is_epoch_boundary or (step + 1) == g_config['max_steps']:
                    checkpoint_path = os.path.join(g_config['train_dir'],
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
        finally:
            # Flush pending summaries and release the session's resources.
            train_writer.close()
            validate_writer.close()
            sess.close()
# Example #3
# 0
def evaluate():
    """Evaluate the latest checkpoint in ``g_config['train_dir']`` on the test set.

    Restores the ExponentialMovingAverage shadow copies of the model variables
    and runs ``g_config['num_examples']`` single-image batches from
    ``test.bin``, which lives two directories above
    ``g_config['eval_dataset']``. For each example:

    * the output image is saved under ``Evaluate/err{k}/`` with
      ``k = min(9, int(100 * nme))``;
    * per-landmark and mean errors are collected.

    The collected errors are written to ``Evaluate/errors.txt``, summary
    statistics are printed, and a CED plot summary is written to
    ``g_config['eval_dir']``.

    Returns:
        None. Prints a message and returns early if no checkpoint is found.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        num_patches = g_config['num_patches']
        path_base = Path(g_config['eval_dataset']).parent.parent
        _mean_shape = mio.import_pickle(path_base / 'reference_shape.pkl')
        _mean_shape = data_provider.align_reference_shape_to_112(_mean_shape)
        tf_mean_shape = tf.constant(_mean_shape,
                                    dtype=tf.float32,
                                    name='MeanShape')

        def decode_feature(serialized):
            """Parse one test record into (image (112, 112, 3), landmarks)."""
            feature = {
                'test/image': tf.FixedLenFeature([], tf.string),
                'test/shape': tf.VarLenFeature(tf.float32),
                # Parsed but currently unused.
                'test/init': tf.VarLenFeature(tf.float32),
            }
            features = tf.parse_single_example(serialized, features=feature)
            decoded_image = tf.decode_raw(features['test/image'], tf.float32)
            decoded_image = tf.reshape(decoded_image, (112, 112, 3))
            decoded_shape = tf.sparse.to_dense(features['test/shape'])
            decoded_shape = tf.reshape(decoded_shape, (num_patches, 2))
            return decoded_image, decoded_shape

        with tf.name_scope('DataProvider', values=[]):
            tf_dataset = tf.data.TFRecordDataset([str(path_base / 'test.bin')])
            tf_dataset = tf_dataset.map(decode_feature)
            tf_dataset = tf_dataset.batch(1)
            tf_dataset = tf_dataset.prefetch(1000)
            tf_iterator = tf_dataset.make_one_shot_iterator()
            tf_images, tf_shapes = tf_iterator.get_next(name='batch')
            tf_images.set_shape((1, 112, 112, 3))
            tf_shapes.set_shape((1, num_patches, 2))

        print('Loading model...')
        with tf.device(g_config['eval_device']):
            model = mdm_model.MDMModel(tf_images,
                                       tf_shapes,
                                       tf_mean_shape,
                                       batch_size=1,
                                       num_patches=num_patches,
                                       num_channels=3,
                                       is_training=False)

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            g_config['MOVING_AVERAGE_DECAY'])
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        graph_def = tf.get_default_graph().as_graph_def()
        summary_writer = tf.summary.FileWriter(g_config['eval_dir'],
                                               graph_def=graph_def)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        try:
            with tf.Session(config=config) as sess:
                ckpt = tf.train.get_checkpoint_state(g_config['train_dir'])
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # Assuming model_checkpoint_path looks something like:
                    #   /ckpt/train/model.ckpt-0,
                    # extract global_step from it.
                    global_step = ckpt.model_checkpoint_path.split(
                        '/')[-1].split('-')[-1]
                    print('Successfully loaded model from {} at step={}.'.
                          format(ckpt.model_checkpoint_path, global_step))
                else:
                    print('No checkpoint file found')
                    return

                eval_base = Path('Evaluate')
                for i in range(10):
                    eval_path = eval_base / 'err{}'.format(i)
                    if not eval_path.exists():
                        eval_path.mkdir(parents=True)

                num_iter = g_config['num_examples']
                errors = []       # per-example per-landmark errors (batch_ne)
                mean_errors = []  # per-example mean errors (batch_nme)

                print('%s: starting evaluation on (%s).' %
                      (datetime.now(), g_config['eval_dataset']))
                start_time = time.time()
                for step in range(num_iter):
                    nme, ne, img = sess.run(
                        [model.batch_nme, model.batch_ne, model.out_images])
                    # Bucket the output image by its error, in percent, capped at 9.
                    error_level = min(9, int(nme[0] * 100))
                    plt.imsave(
                        'Evaluate/err{}/step{}.png'.format(error_level, step),
                        img[0])
                    errors.append(ne)
                    mean_errors.append(nme)
                    done = step + 1
                    if done % 20 == 0:
                        duration = time.time() - start_time
                        sec_per_batch = duration / 20.0
                        examples_per_sec = 1. / sec_per_batch
                        log_str = '{}: [{:d} batches out of {:d}] ({:.1f} examples/sec; {:.3f} sec/batch)'
                        print(
                            log_str.format(datetime.now(), done, num_iter,
                                           examples_per_sec, sec_per_batch))
                        start_time = time.time()

                errors = np.array(errors)
                errors = np.reshape(errors, (-1, num_patches))
                print(errors.shape)
                mean_errors = np.vstack(mean_errors).ravel()
                mean_rse = np.mean(errors, 0)
                mean_rmse = mean_errors.mean()
                # One row per example: landmark errors followed by the mean.
                # The final row holds the per-landmark averages and the
                # overall mean.
                with open('Evaluate/errors.txt', 'w') as ofs:
                    for row, avg in zip(errors, mean_errors):
                        cells = ['%.4f' % col for col in row] + ['%.4f' % avg]
                        ofs.write(', '.join(cells))
                        ofs.write('\n')
                    cells = ['%.4f' % col for col in mean_rse]
                    cells.append('%.4f' % mean_rmse)
                    ofs.write(', '.join(cells))
                    ofs.write('\n')
                # Fraction of examples whose mean error is below each threshold.
                auc_at_08 = (mean_errors < .08).mean()
                auc_at_05 = (mean_errors < .05).mean()
                print('Errors', mean_errors.shape)
                print(
                    '%s: mean_rmse = %.4f, auc @ 0.05 = %.4f, auc @ 0.08 = %.4f [%d examples]'
                    % (datetime.now(), mean_rmse, auc_at_05, auc_at_08,
                       num_iter))

                ced_image = plot_ced([mean_errors.tolist()], ['MDM'])
                ced_plot = sess.run(
                    tf.summary.merge(
                        [tf.summary.image('ced_plot', ced_image[None, ...])]))
                summary_writer.add_summary(ced_plot, global_step)
        finally:
            # Flush pending events; the CED summary is added right before this.
            summary_writer.close()
def train(scope=''):
    """Train the MDM model on ``train.bin`` with on-the-fly augmentation.

    ``train.bin`` and ``reference_shape.pkl`` are looked up two directories
    above the first ``:``-separated entry of ``g_config['train_dataset']``.

    If ``g_config['train_dir']`` holds a checkpoint, training resumes from it.
    Otherwise, when the global ``TUNE`` flag (defined elsewhere) is set, each
    variable listed in ``tf_model.var_map`` is initialised by assigning from
    the mapped graph tensor.

    Losses are logged every 100 steps and summaries are written every 200
    steps. A checkpoint is saved every 1000 steps and at the final step.

    Args:
        scope: Scope filter passed to ``tf.get_collection`` when gathering the
            graph's ``UPDATE_OPS``; these are grouped into the train op.
    """
    with tf.Graph().as_default() as graph, tf.device('/gpu:0'):
        num_patches = g_config['num_patches']
        batch_size = g_config['batch_size']

        # Global steps
        tf_global_step = tf.get_variable(
            'GlobalStep', [],
            initializer=tf.constant_initializer(0),
            trainable=False)

        # Learning rate: staircase exponential decay driven by the global step.
        tf_lr = tf.train.exponential_decay(g_config['learning_rate'],
                                           tf_global_step,
                                           g_config['learning_rate_step'],
                                           g_config['learning_rate_decay'],
                                           staircase=True,
                                           name='LearningRate')
        tf.summary.scalar('learning_rate', tf_lr)

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(tf_lr)

        data_provider.prepare_images(g_config['train_dataset'].split(':'),
                                     num_patches=num_patches,
                                     verbose=True)
        path_base = Path(g_config['train_dataset'].split(':')[0]).parent.parent
        _mean_shape = mio.import_pickle(path_base / 'reference_shape.pkl')
        _mean_shape = data_provider.align_reference_shape_to_112(_mean_shape)
        assert isinstance(_mean_shape, np.ndarray)
        assert _mean_shape.shape[0] == num_patches

        tf_mean_shape = tf.constant(_mean_shape,
                                    dtype=tf.float32,
                                    name='MeanShape')

        def decode_feature(serialized):
            """Parse a training record (raw 336x336x3 float32 image)."""
            feature = {
                'train/image': tf.FixedLenFeature([], tf.string),
                'train/shape': tf.VarLenFeature(tf.float32),
            }
            features = tf.parse_single_example(serialized, features=feature)
            decoded_image = tf.decode_raw(features['train/image'], tf.float32)
            decoded_image = tf.reshape(decoded_image, (336, 336, 3))
            decoded_shape = tf.sparse.to_dense(features['train/shape'])
            decoded_shape = tf.reshape(decoded_shape, (num_patches, 2))
            return decoded_image, decoded_shape

        def get_random_sample(image, shape, rotation_stddev=10):
            """Randomly augment one numpy image (H, W, C) and its landmarks.

            With probability 0.5 each, the image is mirrored
            (``utils.mirror_image``) and rotated. A jittered square box around
            the landmarks is then cropped and resized to 112x112. Returns the
            float32 pair ``(image (112, 112, C), landmarks)``.
            """
            image = menpo.image.Image(image.transpose((2, 0, 1)), copy=False)
            image.landmarks['PTS'] = PointCloud(shape)

            if np.random.rand() < .5:
                image = utils.mirror_image(image)
            if np.random.rand() < .5:
                # theta ~ N(0, rotation_stddev). Its unit is whatever menpo's
                # rotate_ccw_about_centre defaults to (presumably degrees).
                theta = np.random.normal(scale=rotation_stddev)
                rot = menpo.transform.rotate_ccw_about_centre(
                    image.landmarks['PTS'], theta)
                image = image.warp_to_shape(image.shape, rot)
            bb = image.landmarks['PTS'].bounding_box().points
            miny, minx = np.min(bb, 0)
            maxy, maxx = np.max(bb, 0)
            # Square box whose side is the longer landmark extent. Its centre
            # is shifted by up to +/-1/6 of that side along each axis.
            bbsize = max(maxx - minx, maxy - miny)
            center = [(miny + maxy) / 2., (minx + maxx) / 2.]
            shift = (np.random.rand(2) - 0.5) / 3. * bbsize
            image.landmarks['bb'] = PointCloud([
                [
                    center[0] - bbsize * 0.5 + shift[0],
                    center[1] - bbsize * 0.5 + shift[1]
                ],
                [
                    center[0] + bbsize * 0.5 + shift[0],
                    center[1] + bbsize * 0.5 + shift[1]
                ],
            ]).bounding_box()
            # Crop margin proportion: 1/6 +/- 0.01.
            proportion = 1.0 / 6.0 + float(np.random.rand() - 0.5) / 50.0
            image = image.crop_to_landmarks_proportion(proportion, group='bb')
            image = image.resize((112, 112))

            random_image = image.pixels.transpose(1, 2, 0).astype('float32')
            random_shape = image.landmarks['PTS'].points.astype('float32')
            return random_image, random_shape

        def get_init_shape(image, shape):
            """Attach the mean shape as every sample's initial shape."""
            return image, shape, tf_mean_shape

        def distort_color(image, shape, init_shape):
            """Apply colour distortion to the image only."""
            return data_provider.distort_color(image), shape, init_shape

        with tf.name_scope('DataProvider', values=[tf_mean_shape]):
            tf_dataset = tf.data.TFRecordDataset(
                [str(path_base / 'train.bin')])
            tf_dataset = tf_dataset.repeat()
            tf_dataset = tf_dataset.map(decode_feature)
            # The geometric augmentation runs in Python (menpo) via py_func.
            tf_dataset = tf_dataset.map(lambda x, y: tf.py_func(
                get_random_sample, [x, y], [tf.float32, tf.float32],
                stateful=True,
                name='RandomSample'))
            tf_dataset = tf_dataset.map(get_init_shape)
            tf_dataset = tf_dataset.map(distort_color)
            tf_dataset = tf_dataset.batch(batch_size, True)
            # NOTE(review): prefetch comes after batch, so this buffers up to
            # 7500 *batches*, which can be a lot of host memory. Confirm it
            # was not meant to be 7500 elements.
            tf_dataset = tf_dataset.prefetch(7500)
            tf_iterator = tf_dataset.make_one_shot_iterator()
            tf_images, tf_shapes, tf_initial_shapes = tf_iterator.get_next(
                name='Batch')
            tf_images.set_shape([batch_size, 112, 112, 3])
            tf_shapes.set_shape([batch_size, num_patches, 2])
            tf_initial_shapes.set_shape([batch_size, num_patches, 2])

        print('Defining model...')
        with tf.device(g_config['train_device']):
            tf_model = mdm_model.MDMModel(
                tf_images,
                tf_shapes,
                tf_initial_shapes,
                batch_size=batch_size,
                num_iterations=g_config['num_iterations'],
                num_patches=num_patches,
                patch_shape=(g_config['patch_size'], g_config['patch_size']),
                num_channels=3)
            with tf.name_scope('Losses',
                               values=[tf_model.prediction, tf_shapes]):
                tf_norm_error = tf_model.normalized_rmse(
                    tf_model.prediction, tf_shapes)
                tf_loss = tf.reduce_mean(tf_norm_error)
            tf.summary.scalar('losses/total', tf_loss)
            # Calculate the gradients for the batch of data
            tf_grads = opt.compute_gradients(tf_loss)
        tf.summary.histogram('dx', tf_model.prediction - tf_shapes)

        bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)

        # Add histograms for gradients.
        for grad, var in tf_grads:
            if grad is not None:
                tf.summary.histogram(var.op.name + '/gradients', grad)

        # Apply the gradients to adjust the shared variables.
        with tf.name_scope('Optimizer', values=[tf_grads, tf_global_step]):
            apply_gradient_op = opt.apply_gradients(tf_grads,
                                                    global_step=tf_global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)

        with tf.name_scope('MovingAverage', values=[tf_global_step]):
            variable_averages = tf.train.ExponentialMovingAverage(
                g_config['MOVING_AVERAGE_DECAY'], tf_global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            variables_averages_op = variable_averages.apply(
                variables_to_average)

        # Group all updates into a single train op.
        bn_updates_op = tf.group(*bn_updates, name='BNGroup')
        train_op = tf.group(apply_gradient_op,
                            variables_averages_op,
                            bn_updates_op,
                            name='TrainGroup')

        # Create a saver.
        saver = tf.train.Saver()

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge_all()
        # Start running operations on the Graph. allow_soft_placement must be
        # set to True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()
        print('Initializing variables...')
        sess.run(init)
        print('Initialized variables.')

        start_step = 0
        ckpt = tf.train.get_checkpoint_state(g_config['train_dir'])
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Assuming model_checkpoint_path looks something like:
            #   /ckpt/train/model.ckpt-0,
            # extract global_step from it.
            start_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), g_config['train_dir']))
        elif TUNE:
            # Initialise each mapped variable from the graph tensor it maps to;
            # unmapped variable names are printed.
            assign_op = []
            vvv = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            cnt = 0
            for v in vvv:
                if v.name in tf_model.var_map:
                    assign_op.append(
                        v.assign(
                            graph.get_tensor_by_name(
                                tf_model.var_map[v.name])))
                    cnt += 1
                else:
                    print(v.name)
            sess.run(assign_op)
            print('%s: Pre-trained model restored from graph.pb %d/%d' %
                  (datetime.now(), cnt, len(vvv)))

        summary_writer = tf.summary.FileWriter(g_config['train_dir'],
                                               sess.graph)

        print('Starting training...')
        try:
            for step in range(start_step, g_config['max_steps']):
                start_time = time.time()
                _, loss_value = sess.run([train_op, tf_loss])
                duration = time.time() - start_time

                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                if step % 100 == 0:
                    examples_per_sec = batch_size / float(duration)
                    format_str = (
                        '%s: step %d, loss = %.4f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), step, loss_value,
                                        examples_per_sec, duration))

                if step % 200 == 0:
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)

                # Save the model checkpoint periodically.
                if step % 1000 == 0 or (step + 1) == g_config['max_steps']:
                    checkpoint_path = os.path.join(g_config['train_dir'],
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
        finally:
            # Flush pending summaries and release the session's resources.
            summary_writer.close()
            sess.close()
def train(scope=''):
    """Train the MDM landmark model on the TFRecord training shards.

    Builds two input pipelines rooted two directories above the first entry of
    ``g_config['train_dataset']``: training shards ``train_0.bin`` ..
    ``train_3.bin`` (repeated, colour-distorted, shuffled) and ``validate.bin``
    (repeated, no augmentation). On top of them it builds a training
    ``MDMModel`` and a validation ``MDMModel`` (``is_training=False``), both fed
    the same aligned mean shape. It then runs Adam with a staircase
    exponentially decaying learning rate for ``g_config['max_steps']`` steps.

    If a checkpoint exists in ``g_config['train_dir']`` training resumes from
    it. Otherwise, if one exists in ``g_config['ckpt_dir']``, it is restored as
    a warm start and the global step is reset to 0.

    Args:
        scope: scope filter passed to ``tf.get_collection`` when collecting
            ``tf.GraphKeys.UPDATE_OPS``; ``''`` collects all of them.
    """
    with tf.Graph().as_default() as graph, tf.device('/gpu:0'):
        num_patches = g_config['num_patches']
        batch_size = g_config['batch_size']
        # Fixed batch size of the validation pipeline and validation model.
        validate_batch_size = 50

        # Global steps
        tf_global_step = tf.get_variable(
            'GlobalStep', [],
            initializer=tf.constant_initializer(0),
            trainable=False
        )

        # Learning rate, multiplied by `learning_rate_decay` once every
        # `learning_rate_step` steps (staircase=True).
        tf_lr = tf.train.exponential_decay(
            g_config['learning_rate'],
            tf_global_step,
            g_config['learning_rate_step'],
            g_config['learning_rate_decay'],
            staircase=True,
            name='LearningRate'
        )
        tf.summary.scalar('learning_rate', tf_lr, collections=['train'])

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(tf_lr)

        data_provider.prepare_images(
            g_config['train_dataset'].split(':'),
            num_patches=num_patches, verbose=True
        )
        path_base = Path(g_config['train_dataset'].split(':')[0]).parent.parent
        _mean_shape = mio.import_pickle(path_base / 'mean_shape.pkl')
        _mean_shape = data_provider.align_reference_shape_to_112(_mean_shape)
        assert(isinstance(_mean_shape, np.ndarray))
        assert(_mean_shape.shape[0] == num_patches)

        # Background images consumed by `get_random_sample` (occlusion
        # augmentation). NOTE(review): that augmentation is currently disabled
        # (see the commented-out py_func in decode_feature_and_augment), so
        # these images are loaded but not used.
        _negatives = [
            mp_image.pixels.transpose(1, 2, 0).astype(np.float32)
            for mp_image in mio.import_images('Dataset/Neg/*.png', verbose=True)
        ]
        _num_negatives = len(_negatives)

        tf_mean_shape = tf.constant(_mean_shape, dtype=tf.float32, name='MeanShape')

        def get_random_sample(image, shape):
            """With probability 0.3, paste a random crop of a negative image.

            The pasted rectangle covers about 15% of the 112x112 image area, at
            the same location in both the target and the negative image. Pasted
            pixel values are clipped to at most 1.0. `image` is modified in
            place and returned together with the unchanged `shape`.
            """
            # Occlude
            _O_AREA = 0.15
            _O_MIN_H = 0.15
            _O_MAX_H = 1.0
            if np.random.rand() < .3:
                rh = min(112, int((np.random.rand() * (_O_MAX_H - _O_MIN_H) + _O_MIN_H) * 112))
                # 12544 == 112 * 112, so rh * rw is roughly _O_AREA of the image.
                rw = min(112, int(12544 * _O_AREA / rh))
                dy = int(np.random.rand() * (112 - rh))
                dx = int(np.random.rand() * (112 - rw))
                idx = int(np.random.rand() * _num_negatives)
                image[dy:dy+rh, dx:dx+rw] = np.minimum(
                    1.0,
                    _negatives[idx][dy:dy+rh, dx:dx+rw]
                )
            return image, shape

        def _decode_example(serialized, prefix):
            """Parse one serialized example into (image, shape) tensors.

            Reads ``<prefix>/image`` as raw float32 bytes reshaped to
            (112, 112, 3), and the variable-length float ``<prefix>/shape``
            reshaped to (num_patches, 2).
            """
            features = tf.parse_single_example(serialized, features={
                prefix + '/image': tf.FixedLenFeature([], tf.string),
                prefix + '/shape': tf.VarLenFeature(tf.float32),
            })
            decoded_image = tf.decode_raw(features[prefix + '/image'], tf.float32)
            decoded_image = tf.reshape(decoded_image, (112, 112, 3))
            decoded_shape = tf.sparse.to_dense(features[prefix + '/shape'])
            decoded_shape = tf.reshape(decoded_shape, (num_patches, 2))
            return decoded_image, decoded_shape

        def decode_feature_and_augment(serialized):
            """Decode a 'train/...' example and apply colour distortion."""
            decoded_image, decoded_shape = _decode_example(serialized, 'train')
            # Occlusion augmentation, currently disabled:
            # decoded_image, decoded_shape = tf.py_func(
            #     get_random_sample, [decoded_image, decoded_shape],
            #     [tf.float32, tf.float32], stateful=True, name='RandomSample')
            return data_provider.distort_color(decoded_image), decoded_shape

        def decode_feature(serialized):
            """Decode a 'validate/...' example without augmentation."""
            return _decode_example(serialized, 'validate')

        with tf.name_scope('DataProvider'):
            tf_dataset = tf.data.TFRecordDataset([
                str(path_base / 'train_0.bin'),
                str(path_base / 'train_1.bin'),
                str(path_base / 'train_2.bin'),
                str(path_base / 'train_3.bin')
            ])
            tf_dataset = tf_dataset.repeat()
            tf_dataset = tf_dataset.map(decode_feature_and_augment, num_parallel_calls=5)
            tf_dataset = tf_dataset.shuffle(480)
            tf_dataset = tf_dataset.batch(batch_size, True)
            tf_dataset = tf_dataset.prefetch(1)
            tf_iterator = tf_dataset.make_one_shot_iterator()
            tf_images, tf_shapes = tf_iterator.get_next(name='Batch')
            tf_images.set_shape([batch_size, 112, 112, 3])
            # The landmark count must match the (num_patches, 2) reshape in
            # _decode_example, not a hard-coded value.
            tf_shapes.set_shape([batch_size, num_patches, 2])

            tf_dataset_v = tf.data.TFRecordDataset([str(path_base / 'validate.bin')])
            tf_dataset_v = tf_dataset_v.repeat()
            tf_dataset_v = tf_dataset_v.map(decode_feature, num_parallel_calls=5)
            tf_dataset_v = tf_dataset_v.batch(validate_batch_size, True)
            tf_dataset_v = tf_dataset_v.prefetch(1)
            tf_iterator_v = tf_dataset_v.make_one_shot_iterator()
            tf_images_v, tf_shapes_v = tf_iterator_v.get_next(name='ValidateBatch')
            tf_images_v.set_shape([validate_batch_size, 112, 112, 3])
            tf_shapes_v.set_shape([validate_batch_size, num_patches, 2])

        print('Defining model...')
        with tf.device(g_config['train_device']):
            tf_model = mdm_model.MDMModel(
                tf_images,
                tf_shapes,
                tf_mean_shape,
                batch_size=batch_size,
                num_patches=num_patches,
                num_channels=3,
                multiplier=g_config['multiplier']
            )
            tf_grads = opt.compute_gradients(tf_model.nme)
            with tf.name_scope('Validate'):
                tf_model_v = mdm_model.MDMModel(
                    tf_images_v,
                    tf_shapes_v,
                    tf_mean_shape,
                    batch_size=validate_batch_size,
                    num_patches=num_patches,
                    num_channels=3,
                    multiplier=g_config['multiplier'],
                    is_training=False
                )
        tf.summary.histogram('dx', tf_model.prediction - tf_shapes, collections=['train'])

        bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)

        # Add histograms for gradients.
        for grad, var in tf_grads:
            if grad is not None:
                tf.summary.histogram(var.op.name + '/gradients', grad, collections=['train'])

        # Apply the gradients to adjust the shared variables.
        with tf.name_scope('Optimizer', values=[tf_grads, tf_global_step]):
            apply_gradient_op = opt.apply_gradients(tf_grads, global_step=tf_global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var, collections=['train'])

        with tf.name_scope('MovingAverage', values=[tf_global_step]):
            variable_averages = tf.train.ExponentialMovingAverage(g_config['MOVING_AVERAGE_DECAY'], tf_global_step)
            variables_to_average = (tf.trainable_variables() + tf.moving_average_variables())
            variables_averages_op = variable_averages.apply(variables_to_average)

        # Group all updates to into a single train op.
        bn_updates_op = tf.group(*bn_updates, name='BNGroup')
        train_op = tf.group(
            apply_gradient_op, variables_averages_op, bn_updates_op,
            name='TrainGroup'
        )

        # Create a saver.
        saver = tf.train.Saver()

        train_summary_op = tf.summary.merge_all('train')
        validate_summary_op = tf.summary.merge_all('validate')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        # Context manager guarantees the session is released when training
        # finishes or fails.
        with tf.Session(graph=graph, config=config) as sess:
            init = tf.global_variables_initializer()
            print('Initializing variables...')
            sess.run(init)
            print('Initialized variables.')

            # Assuming model_checkpoint_path looks something like:
            #   /ckpt/train/model.ckpt-0,
            # extract global_step from it.
            start_step = 0
            ckpt = tf.train.get_checkpoint_state(g_config['train_dir'])
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                start_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
                print('%s: Restart from %s' % (datetime.now(), g_config['train_dir']))
            else:
                ckpt = tf.train.get_checkpoint_state(g_config['ckpt_dir'])
                if ckpt and ckpt.model_checkpoint_path:
                    # Warm start: take the weights but restart the step count.
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    tf_global_step_op = tf_global_step.assign(0)
                    sess.run(tf_global_step_op)
                    print('%s: Pre-trained model restored from %s' % (datetime.now(), g_config['ckpt_dir']))

            train_writer = tf.summary.FileWriter(g_config['train_dir'] + '/train', sess.graph)
            validate_writer = tf.summary.FileWriter(g_config['train_dir'] + '/validate', sess.graph)

            print('Starting training...')
            # Integer period (~15000 samples per epoch). True division would
            # give a float (e.g. 15000 / 32 == 468.75) for which
            # `step % steps_per_epoch == 0` only holds at multiples of 1875,
            # silently thinning validation and checkpointing. max(1, ...)
            # guards against a zero period when batch_size > 15000.
            steps_per_epoch = max(1, 15000 // batch_size)
            try:
                for step in range(start_step, g_config['max_steps']):
                    if step % steps_per_epoch == 0:
                        # Epoch boundary: train step + summaries + validation.
                        start_time = time.time()
                        _, train_loss, train_summary = sess.run([train_op, tf_model.nme, train_summary_op])
                        duration = time.time() - start_time
                        validate_loss, validate_summary = sess.run([tf_model_v.nme, validate_summary_op])
                        train_writer.add_summary(train_summary, step)
                        validate_writer.add_summary(validate_summary, step)

                        print(
                            '%s: step %d, loss = %.4f (%.3f sec/batch)' % (
                                datetime.now(), step, train_loss, duration
                            )
                        )
                        print(
                            '%s: step %d, validate loss = %.4f' % (
                                datetime.now(), step, validate_loss
                            )
                        )
                    else:
                        start_time = time.time()
                        _, train_loss = sess.run([train_op, tf_model.nme])
                        duration = time.time() - start_time
                        if step % 100 == 0:
                            print(
                                '%s: step %d, loss = %.4f (%.3f sec/batch)' % (
                                    datetime.now(), step, train_loss, duration
                                )
                            )

                    assert not np.isnan(train_loss), 'Model diverged with loss = NaN'

                    if step % steps_per_epoch == 0 or (step + 1) == g_config['max_steps']:
                        checkpoint_path = os.path.join(g_config['train_dir'], 'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
            finally:
                # Flush pending summary events to disk.
                train_writer.close()
                validate_writer.close()
def ckpt_pb(pb_path, lite_path):
    """Freeze the latest training checkpoint to a .pb, then convert it to TFLite.

    Step 1 builds a batch-1 inference ``MDMModel`` (``is_training=False``) fed
    by the placeholder ``Inputs/InputImage`` and the aligned reference shape
    constant ``Inputs/MeanShape``. It restores the latest checkpoint from
    ``g_config['train_dir']`` and writes a frozen GraphDef (output node
    ``Network/Predict/add``) to ``pb_path``. If no checkpoint is found it
    prints a message and returns without writing anything.

    Step 2 converts the frozen graph to a TFLite model at ``lite_path``, with
    input ``Inputs/InputImage`` and outputs ``Network/Predict/Reshape`` and
    ``Inputs/MeanShape``. It then loads the model with the TFLite interpreter
    and prints the input/output tensor details. As a side effect it also writes
    the imported graph as a TensorBoard event file into the current directory.

    Args:
        pb_path: destination path of the frozen GraphDef.
        lite_path: destination path of the TFLite model.
    """
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    # to pb
    tf.reset_default_graph()
    with tf.Graph().as_default() as graph, tf.Session(graph=graph, config=config) as sess:
        path_base = Path(g_config['eval_dataset']).parent.parent
        _mean_shape = mio.import_pickle(path_base / 'reference_shape.pkl')
        _mean_shape = data_provider.align_reference_shape_to_112(_mean_shape)
        assert isinstance(_mean_shape, np.ndarray)
        print(_mean_shape.shape)

        # Take the landmark count from the loaded shape instead of a hard-coded
        # 73. With a fixed `shape=`, TF1 tf.constant silently pads a shorter
        # value by repeating its last element.
        num_points = _mean_shape.shape[0]

        tf_img = tf.placeholder(dtype=tf.float32, shape=(1, 112, 112, 3), name='Inputs/InputImage')
        # Ground-truth input required by the MDMModel constructor; unused at
        # inference time.
        tf_dummy = tf.placeholder(dtype=tf.float32, shape=(1, num_points, 2), name='dummy')
        tf_shape = tf.constant(_mean_shape, dtype=tf.float32, shape=(num_points, 2), name='Inputs/MeanShape')

        mdm_model.MDMModel(
            tf_img,
            tf_dummy,
            tf_shape,
            batch_size=1,
            num_patches=g_config['num_patches'],
            num_channels=3,
            is_training=False
        )

        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(g_config['train_dir'])
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            print('Successfully loaded model from {} at step={}.'.format(ckpt.model_checkpoint_path, global_step))
        else:
            print('No checkpoint file found')
            return

        # Bake variable values into constants, keeping only what the output
        # node depends on.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, tf.get_default_graph().as_graph_def(), ['Network/Predict/add']
        )

        with tf.gfile.GFile(pb_path, mode='wb') as f:
            f.write(output_graph_def.SerializeToString())

    # to tflite
    tf.reset_default_graph()
    with tf.Graph().as_default() as graph:
        with tf.gfile.GFile(pb_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name='prefix',
            op_dict=None,
            producer_op_list=None
        )
        # Dump the imported graph for TensorBoard; close() flushes the event
        # file, which the background writer might otherwise never write.
        tf.summary.FileWriter('.', graph=graph).close()

        converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
            pb_path,
            ['Inputs/InputImage'],
            ['Network/Predict/Reshape', 'Inputs/MeanShape']
        )
        with tf.gfile.GFile(lite_path, mode='wb') as f:
            f.write(converter.convert())
        # Load TFLite model and allocate tensors.
        interpreter = tf.contrib.lite.Interpreter(model_path=lite_path)
        interpreter.allocate_tensors()

        # Get input and output tensors.
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(input_details)
        print(output_details)