def train(gitapp: controller.GetInputTargetAndPredictedParameters):
    """Build the training graph for `gitapp` and run the slim training loop.

    The loss comes from `total_loss(gitapp)`, the optimizer is chosen by
    `FLAGS.optimizer`, and the model is optionally warm-started from
    `FLAGS.restore_directory`. Checkpoints and summaries are written to the
    directory returned by `output_directory()`.

    Args:
      gitapp: GetInputTargetAndPredictedParameters forwarded to `total_loss`.

    Raises:
      NotImplementedError: If `FLAGS.optimizer` is not one of the momentum,
        Adagrad or Adam optimizer constants.
    """
    # NOTE(review): this device scope is opened before `graph` becomes the
    # default graph -- confirm that it actually pins the ops built below.
    with tf.device('/cpu:0'):
        graph = tf.Graph()
        with graph.as_default():
            loss_op, _, _ = total_loss(gitapp)

            optimizer_name = FLAGS.optimizer
            if optimizer_name == OPTIMIZER_MOMENTUM:
                # TODO(ericmc): We may want to do weight decay with the other
                # optimizers, too.
                decayed_rate = tf.train.exponential_decay(
                    FLAGS.learning_rate,
                    slim.variables.get_global_step(),
                    FLAGS.learning_decay_steps,
                    0.999,
                    staircase=False)
                tf.summary.scalar('learning_rate', decayed_rate)
                opt = tf.train.MomentumOptimizer(decayed_rate, 0.875)
            elif optimizer_name == OPTIMIZER_ADAGRAD:
                opt = tf.train.AdagradOptimizer(FLAGS.learning_rate)
            elif optimizer_name == OPTIMIZER_ADAM:
                opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
            else:
                raise NotImplementedError(
                    'Unsupported optimizer: %s' % optimizer_name)

            # Wire the loss and the optimizer into a single training op.
            step_op = slim.learning.create_train_op(
                loss_op,
                opt,
                summarize_gradients=True,
                colocate_gradients_with_ops=True)

            # Either start from scratch or restore a previous model.
            restore_from = FLAGS.restore_directory
            if not restore_from:
                logging.info('Training a new model.')
                init_fn = None
            else:
                init_fn = util.restore_model(restore_from,
                                             FLAGS.restore_logits)

            variable_size, _ = slim.model_analyzer.analyze_vars(
                slim.get_variables(), print_info=True)
            logging.info('Total number of variables: %d', variable_size)

            log_entry_points(graph)

            # Device placement is logged and soft placement is disabled.
            session_config = tf.ConfigProto(
                log_device_placement=True, allow_soft_placement=False)

            log_directory = output_directory()
            is_chief = FLAGS.task == 0
            checkpoint_saver = tf.train.Saver(
                keep_checkpoint_every_n_hours=0.25)

            slim.learning.train(
                train_op=step_op,
                logdir=log_directory,
                master=FLAGS.master,
                is_chief=is_chief,
                number_of_steps=None,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                init_fn=init_fn,
                saver=checkpoint_saver,
                session_config=session_config)
def infer(
    gitapp: controller.GetInputTargetAndPredictedParameters,
    restore_directory: str,
    output_directory: str,
    extract_patch_size: int,
    stitch_stride: int,
    infer_size: int,
    channel_whitelist: Optional[List[str]],
    simplify_error_panels: bool,
):
    """Runs inference on an image.

    The image is processed in blocks of `infer_size` x `infer_size` patches.
    The per-block outputs are stitched together with numpy and the stitched
    arrays are fed back into the graph to render the error panels.

    Side effects: the stitched arrays are pickled to
    '/home/ubuntu/resinfer.pk', and 'input_error_panel.png' and
    'target_error_panel.png' are written to a sub-directory of
    `output_directory` named after the restored global step.

    Args:
      gitapp: GetInputTargetAndPredictedParameters.
      restore_directory: Where to restore the model from.
      output_directory: Where to write the generated images.
      extract_patch_size: The size of input to the model.
      stitch_stride: The stride size when running model inference.
        Equivalently, the output size of the model.
      infer_size: The number of simultaneous inferences to perform in the
        row and column dimensions.
        For example, if this is 8, inference will be performed in 8 x 8 blocks
        for a batch size of 64.
      channel_whitelist: If provided, only images for the given channels will
        be produced.
        This can be used to create simpler error panels.
      simplify_error_panels: Whether to create simplified error panels.

    Raises:
      ValueError: If
        1) The DataParameters don't contain a ReadPNGsParameters.
        2) The (cropped) image is smaller than the input to the network, or
           too small to fit a single block of inferences.
        3) The graph contains queues.
    """
    rpp = gitapp.dp.io_parameters
    if not isinstance(rpp, data_provider.ReadPNGsParameters):
        # `gitapp.dp` is wrapped in a 1-tuple so that a (named)tuple value is
        # not unpacked as several format arguments.
        raise ValueError(
            'Data provider must contain a ReadPNGsParameter, but was: %r' %
            (gitapp.dp,))

    original_crop_size = rpp.crop_size
    image_num_rows, image_num_columns = util.image_size(rpp.directory)
    logging.info('Uncropped image size is %d x %d', image_num_rows,
                 image_num_columns)

    image_num_rows = min(image_num_rows, original_crop_size)
    if image_num_rows < extract_patch_size:
        raise ValueError(
            'Image is too small for inference to be performed: %d vs %d' %
            (image_num_rows, extract_patch_size))
    image_num_columns = min(image_num_columns, original_crop_size)
    if image_num_columns < extract_patch_size:
        raise ValueError(
            'Image is too small for inference to be performed: %d vs %d' %
            (image_num_columns, extract_patch_size))
    logging.info('After cropping, input image size is (%d, %d)',
                 image_num_rows, image_num_columns)

    # Each sess.run below covers `infer_size * stitch_stride` output pixels
    # per dimension.
    block_stride = stitch_stride * infer_size
    num_row_inferences = (image_num_rows - extract_patch_size) // block_stride
    num_column_inferences = (
        image_num_columns - extract_patch_size) // block_stride
    logging.info('Running %d x %d inferences', num_row_inferences,
                 num_column_inferences)
    if num_row_inferences < 1 or num_column_inferences < 1:
        # Without this check the failure would only surface after the model is
        # restored, as an opaque "need at least one array to concatenate"
        # ValueError from np.concatenate.
        raise ValueError(
            'Image is too small for inference to be performed: '
            '%d x %d inferences fit in a (%d, %d) image' %
            (num_row_inferences, num_column_inferences, image_num_rows,
             image_num_columns))

    num_output_rows = num_row_inferences * infer_size * stitch_stride
    num_output_columns = num_column_inferences * infer_size * stitch_stride
    logging.info('Output image size is (%d, %d)', num_output_rows,
                 num_output_columns)

    g = tf.Graph()
    with g.as_default():
        row_start = tf.placeholder(dtype=np.int32, shape=[])
        column_start = tf.placeholder(dtype=np.int32, shape=[])
        # Replace the parameters with a new set, which will cause the network
        # to run inference in just a local region.
        gitapp = gitapp._replace(
            dp=gitapp.dp._replace(
                io_parameters=rpp._replace(
                    row_start=row_start,
                    column_start=column_start,
                    crop_size=(
                        (infer_size - 1) * stitch_stride + extract_patch_size),
                )))

        visualization_lts = controller.setup_stitch(gitapp)

        def get_statistics(tensor):
            # Collapse every axis except the last into a single 'batch' axis,
            # apply ops.distribution_statistics, then restore the axes.
            rc = lt.ReshapeCoder(list(tensor.axes.keys())[:-1], ['batch'])
            return rc.decode(ops.distribution_statistics(rc.encode(tensor)))

        visualize_input_lt = visualization_lts['input']
        visualize_predict_input_lt = get_statistics(
            visualization_lts['predict_input'])
        visualize_target_lt = visualization_lts['target']
        visualize_predict_target_lt = get_statistics(
            visualization_lts['predict_target'])

        # Placeholders for the full stitched images; they are fed with the
        # numpy arrays assembled from the per-block results below.
        input_lt = lt.LabeledTensor(
            tf.placeholder(
                dtype=np.float32,
                shape=[
                    1,
                    num_output_rows,
                    num_output_columns,
                    len(gitapp.dp.input_z_values),
                    1,
                    2,
                ]),
            axes=[
                'batch',
                'row',
                'column',
                ('z', gitapp.dp.input_z_values),
                ('channel', ['TRANSMISSION']),
                ('mask', [False, True]),
            ])
        predict_input_lt = lt.LabeledTensor(
            tf.placeholder(
                dtype=np.float32,
                shape=[
                    1,
                    num_output_rows,
                    num_output_columns,
                    len(gitapp.dp.input_z_values),
                    1,
                    len(visualize_predict_input_lt.axes['statistic']),
                ]),
            axes=[
                'batch',
                'row',
                'column',
                ('z', gitapp.dp.input_z_values),
                ('channel', ['TRANSMISSION']),
                visualize_predict_input_lt.axes['statistic'],
            ])
        input_error_panel_lt = visualize.error_panel_from_statistics(
            input_lt, predict_input_lt, simplify_error_panels)

        target_lt = lt.LabeledTensor(
            tf.placeholder(
                dtype=np.float32,
                shape=[
                    1,
                    num_output_rows,
                    num_output_columns,
                    len(gitapp.dp.target_z_values),
                    len(gitapp.dp.target_channel_values) + 1,
                    2,
                ]),
            axes=[
                'batch',
                'row',
                'column',
                ('z', gitapp.dp.target_z_values),
                ('channel',
                 gitapp.dp.target_channel_values + ['NEURITE_CONFOCAL']),
                ('mask', [False, True]),
            ])
        predict_target_lt = lt.LabeledTensor(
            tf.placeholder(
                dtype=np.float32,
                shape=[
                    1,
                    num_output_rows,
                    num_output_columns,
                    len(gitapp.dp.target_z_values),
                    len(gitapp.dp.target_channel_values) + 1,
                    len(visualize_predict_target_lt.axes['statistic']),
                ]),
            axes=[
                'batch',
                'row',
                'column',
                ('z', gitapp.dp.target_z_values),
                ('channel',
                 gitapp.dp.target_channel_values + ['NEURITE_CONFOCAL']),
                visualize_predict_target_lt.axes['statistic'],
            ])

        logging.info('input_lt: %r', input_lt)
        logging.info('predict_input_lt: %r', predict_input_lt)
        logging.info('target_lt: %r', target_lt)
        logging.info('predict_target_lt: %r', predict_target_lt)

        def select_channels(tensor):
            # Restrict the panel to the whitelisted channels, if any.
            if channel_whitelist is not None:
                return lt.select(tensor, {'channel': channel_whitelist})
            return tensor

        target_error_panel_lt = visualize.error_panel_from_statistics(
            select_channels(target_lt), select_channels(predict_target_lt),
            simplify_error_panels)

        # There shouldn't be any queues in this configuration.
        queue_runners = g.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
        if queue_runners:
            raise ValueError(
                'Graph must not have queues, but had: %r' % (queue_runners,))

        logging.info('Attempting to find restore checkpoint in %s',
                     restore_directory)
        init_fn = util.restore_model(
            restore_directory, restore_logits=True, restore_global_step=True)

        with tf.Session() as sess:
            logging.info('Generating images')
            init_fn(sess)

            input_rows = []
            predict_input_rows = []
            target_rows = []
            predict_target_rows = []
            for infer_row in range(num_row_inferences):
                input_row = []
                predict_input_row = []
                target_row = []
                predict_target_row = []
                for infer_column in range(num_column_inferences):
                    rs = infer_row * infer_size * stitch_stride
                    cs = infer_column * infer_size * stitch_stride
                    logging.info('Running inference at offset: (%d, %d)', rs,
                                 cs)
                    [inpt, predict_input, target, predict_target] = sess.run(
                        [
                            visualize_input_lt,
                            visualize_predict_input_lt,
                            visualize_target_lt,
                            visualize_predict_target_lt,
                        ],
                        feed_dict={
                            row_start: rs,
                            column_start: cs,
                        })

                    input_row.append(inpt)
                    predict_input_row.append(predict_input)
                    target_row.append(target)
                    predict_target_row.append(predict_target)
                # Axis 2 is 'column': join the blocks of one row side by side.
                input_rows.append(np.concatenate(input_row, axis=2))
                predict_input_rows.append(
                    np.concatenate(predict_input_row, axis=2))
                target_rows.append(np.concatenate(target_row, axis=2))
                predict_target_rows.append(
                    np.concatenate(predict_target_row, axis=2))

            logging.info('Stitching')
            # Axis 1 is 'row': stack the joined rows on top of each other.
            stitched_input = np.concatenate(input_rows, axis=1)
            stitched_predict_input = np.concatenate(predict_input_rows, axis=1)
            stitched_target = np.concatenate(target_rows, axis=1)
            stitched_predict_target = np.concatenate(
                predict_target_rows, axis=1)

            logging.info('Creating error panels')

            # Dump the raw stitched arrays for offline inspection.
            # NOTE(review): the destination path is hard-coded; the directory
            # must exist or inference aborts here.
            import pickle
            stitched_arrays = {
                'input_lt': stitched_input,
                'predict_input_lt': stitched_predict_input,
                'target_lt': stitched_target,
                'predict_target_lt': stitched_predict_target,
            }
            # The context manager guarantees the dump is flushed and the file
            # handle is released.
            with open('/home/ubuntu/resinfer.pk', 'wb') as f:
                pickle.dump(stitched_arrays, f)

            [input_error_panel, target_error_panel, global_step] = sess.run(
                [
                    input_error_panel_lt,
                    target_error_panel_lt,
                    tf.train.get_global_step(),
                ],
                feed_dict={
                    input_lt: stitched_input,
                    predict_input_lt: stitched_predict_input,
                    target_lt: stitched_target,
                    predict_target_lt: stitched_predict_target,
                })

            # Images go into a sub-directory named after the zero-padded
            # global step of the restored checkpoint.
            step_directory = os.path.join(output_directory,
                                          '%.8d' % global_step)
            if not gfile.Exists(step_directory):
                gfile.MakeDirs(step_directory)

            util.write_image(
                os.path.join(step_directory, 'input_error_panel.png'),
                input_error_panel[0, :, :, :])
            util.write_image(
                os.path.join(step_directory, 'target_error_panel.png'),
                target_error_panel[0, :, :, :])

            logging.info('Done generating images')