Esempio n. 1
0
def evaluate():
    """Build a single-epoch evaluation graph and run evaluation passes.

    A fresh ``tf.Graph`` is populated with the input pipeline, the inference
    network and its loss, a ``tf.train.Saver`` and the merged summaries.
    Each evaluation pass is delegated to ``eval_once``; passes repeat,
    sleeping ``FLAGS.eval_interval_secs`` seconds in between, until
    ``FLAGS.run_once`` is true (checked after every pass).
    """
    eval_graph = tf.Graph()
    with eval_graph.as_default():
        # Evaluation reads the input data for exactly one epoch.
        images, labels = input_data.input_pipeline(
            FLAGS.data_dir,
            FLAGS.batch_size,
            fake_data=FLAGS.fake_data,
            num_epochs=1,
            read_threads=FLAGS.read_threads,
            shuffle_size=FLAGS.shuffle_size,
            num_expected_examples=FLAGS.num_examples)

        # Network output and the loss op that eval_once reports on.
        logits = topology.inference(images, FLAGS.network_pattern)
        loss = topology.loss(logits, labels)

        saver = tf.train.Saver()

        # Merge every summary registered in the graph's Summaries collection.
        merged_summaries = tf.summary.merge_all()
        writer = tf.summary.FileWriter(FLAGS.log_dir, eval_graph)

        keep_evaluating = True
        while keep_evaluating:
            eval_once(saver, writer, loss, merged_summaries)
            keep_evaluating = not FLAGS.run_once
            if keep_evaluating:
                time.sleep(FLAGS.eval_interval_secs)
Esempio n. 2
0
def evaluate():
    """Run the network over a grid of scan locations and write BOV outputs.

    The scan region is taken from ``FLAGS.scan_size`` (three ints, each >= 1)
    and ``FLAGS.scan_start`` (three ints, each >= 0). For each location
    produced by the location iterator, a sub-block is extracted, ball samples
    are collected from it, and the network's rounded sigmoid output is
    computed. ``scan`` drives the session and returns two arrays, which are
    written via ``writeBOV`` as ``<outname>_scanned_<x>_<y>_<z>_<sx>_<sy>_<sz>``
    (field ``'density'``) and ``<outname>_pred_...`` (field ``'prediction'``).
    """
    # What volume are we scanning?
    edge_len = get_subblock_edge_len()
    # Builtin ``int`` replaces the ``np.int`` alias. It was deprecated in
    # NumPy 1.20 and removed in 1.24, and it was only an alias for ``int``.
    scan_sz_V = np.asarray(parse_int_list(FLAGS.scan_size, 3, low_bound=[1, 1, 1]),
                           dtype=int)
    scan_start_V = np.asarray(parse_int_list(FLAGS.scan_start, 3, low_bound=[0, 0, 0]),
                              dtype=int)
    # Correct for the shift to the center of the ball.
    coord_base_V = scan_start_V - ((edge_len - 1) // 2)

    # Iterator yielding the x/y/z offsets of each batch of scan locations.
    loc_iterator = get_loc_iterator(FLAGS.data_dir, FLAGS.batch_size, coord_base_V,
                                    scan_sz_V)

    x_off_op, y_off_op, z_off_op = loc_iterator.get_next()
    subblock = get_subblock_op(x_off_op, y_off_op, z_off_op, edge_len, get_full_block())
    subblock = tf.dtypes.cast(subblock, tf.dtypes.float64)

    images = collect_ball_samples(subblock, edge_len, read_threads=FLAGS.read_threads)
    # Column 0 of each sample, cast to uint8; scan() records it in scanned_blk.
    ctrpt_op = tf.dtypes.cast(images[:, 0], tf.dtypes.uint8)

    if FLAGS.random_rotation:
        # apply_random_rotation needs a second (label-like) argument, and
        # ctrpt_op is passed to fill it. ctrpt_op is rebound to the returned
        # value. NOTE(review): the original author states it comes back
        # unmodified; confirm against harmonics.apply_random_rotation.
        images = tf.cast(images, tf.float32)
        images, ctrpt_op = harmonics.apply_random_rotation(images, ctrpt_op)

    # Build a Graph that computes predictions from the inference model.
    logits = topology.inference(images, FLAGS.network_pattern)

    # Binary prediction per sample: sigmoid output rounded to 0 or 1.
    predicted_op = tf.round(tf.nn.sigmoid(logits))

    saver = tf.train.Saver()

    scanned_blk, pred_blk = scan(coord_base_V, scan_sz_V, loc_iterator,
                                 x_off_op, y_off_op, z_off_op,
                                 saver, ctrpt_op, predicted_op)

    # Both output names encode the scan origin and the scanned block's shape.
    x_base, y_base, z_base = scan_start_V
    scan_sz = scanned_blk.shape
    name_dims = (x_base, y_base, z_base, scan_sz[0], scan_sz[1], scan_sz[2])
    for tag, blk, field_name in (('scanned', scanned_blk, 'density'),
                                 ('pred', pred_blk, 'prediction')):
        fname_base = '%s_%s_%d_%d_%d_%d_%d_%d' % ((FLAGS.outname, tag) + name_dims)
        writeBOV(fname_base, reorder_array(blk), field_name)
Esempio n. 3
0
def evaluate():
    """Build the binary-classifier evaluation graph and run evaluation passes.

    Constructs, in the default graph, a one-epoch input iterator whose
    shuffling is controlled by a scalar int64 seed placeholder. It then adds
    the inference network, the binary loss, rounded-sigmoid predictions and
    batch accuracy. Inside a single ``tf.Session``, ``eval_once`` is called
    repeatedly, sleeping ``FLAGS.eval_interval_secs`` seconds between calls,
    until ``FLAGS.run_once`` is true (checked after every pass).
    """
    # The value fed to this placeholder controls the shuffling performed
    # while reading input.
    shuffle_seed = tf.placeholder(tf.int64, shape=())

    batch_iter = input_data.input_pipeline_binary(
        FLAGS.data_dir,
        FLAGS.batch_size,
        fake_data=FLAGS.fake_data,
        num_epochs=1,
        read_threads=FLAGS.read_threads,
        shuffle_size=FLAGS.shuffle_size,
        num_expected_examples=FLAGS.num_examples,
        seed=shuffle_seed)
    img_paths, lbl_paths, images, labels = batch_iter.get_next()

    # This op is added to the graph but is not passed to eval_once, so
    # nothing in this function ever runs it.
    _unused_print_op = (
        tf.print("images and labels this batch: ", img_paths, lbl_paths, labels)
        if FLAGS.verbose else tf.constant('No printing'))

    if FLAGS.random_rotation:
        images, labels = harmonics.apply_random_rotation(images, labels)

    logits = topology.inference(images, FLAGS.network_pattern)
    loss = topology.binary_loss(logits, labels)

    # Predictions are the sigmoid output rounded to 0/1; accuracy is the
    # fraction of predictions equal to the labels.
    predicted = tf.round(tf.nn.sigmoid(logits))
    hits = tf.cast(tf.equal(predicted, labels), tf.float32)
    accuracy = tf.reduce_mean(hits)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        while True:
            eval_once(sess, batch_iter, saver, shuffle_seed, labels, loss,
                      accuracy, predicted)
            if FLAGS.run_once:
                return
            time.sleep(FLAGS.eval_interval_secs)
Esempio n. 4
0
def train():
    """Train fish_cubes using the queue-runner input pipeline.

    Builds the input pipeline, inference network, loss, training op and
    merged summaries in a fresh graph. If ``FLAGS.starting_snapshot`` is
    empty (or None), variables are initialised from scratch; otherwise they
    are restored from that checkpoint. ``train_op`` then runs until the
    input raises ``tf.errors.OutOfRangeError`` or the coordinator requests
    a stop.

    Status and summaries are written on the first 10 steps and on every
    100th step. A checkpoint with prefix ``FLAGS.log_dir + 'cnn'`` is saved
    every 1000 steps.
    """
    # A falsy FLAGS.num_epochs (unset or 0) is passed to the pipeline as None.
    num_epochs = FLAGS.num_epochs or None

    # Build the model into a fresh graph.
    with tf.Graph().as_default():
        images, labels = input_data.input_pipeline(
            FLAGS.data_dir,
            FLAGS.batch_size,
            fake_data=FLAGS.fake_data,
            num_epochs=num_epochs,
            read_threads=FLAGS.read_threads,
            shuffle_size=FLAGS.shuffle_size,
            num_expected_examples=FLAGS.num_examples)

        # Build a Graph that computes predictions from the inference model.
        logits = topology.inference(images, FLAGS.network_pattern)

        # Add to the Graph the Ops for loss calculation.
        loss = topology.loss(logits, labels)

        # Add to the Graph the Ops that calculate and apply gradients.
        train_op = topology.training(loss, FLAGS.learning_rate)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver()

        sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))

        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

        # Initialize from scratch, or restore from the snapshot if one is given.
        # Truthiness also covers a None flag; len(None) would crash.
        if not FLAGS.starting_snapshot:
            sess.run(init_op)
        else:
            saver.restore(sess, FLAGS.starting_snapshot)

        # Start input enqueue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        step = 0
        loss_value = -1.0  # avoid a corner case where it is unset on error
        duration = 0.0  # ditto
        num_chk = None  # no numerics-check op in this version; printed as-is
        try:
            while not coord.should_stop():
                start_time = time.time()
                _, loss_value = sess.run([train_op, loss])
                duration = time.time() - start_time

                # Write the summaries and print an overview fairly often.
                if (step + 1) % 100 == 0 or step < 10:
                    print('Step %d: numerics = %s, loss = %.2f (%.3f sec)' %
                          (step, num_chk, loss_value, duration))
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                # Save a checkpoint periodically. If log_dir is /tmp/cnn/,
                # checkpoints go in that directory with the prefix 'cnn'.
                if (step + 1) % 1000 == 0:
                    saver.save(sess, FLAGS.log_dir + 'cnn', global_step=step)

                step += 1

        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            # When done, ask the threads to stop.
            coord.request_stop()
            print('Final Step %d: numerics = %s, loss = %.2f (%.3f sec)' %
                  (step, num_chk, loss_value, duration))
            # The input may already be exhausted here. In that case running
            # summary_op raises OutOfRangeError again, and raising from
            # inside `finally` would skip the join/close below.
            try:
                summary_str = sess.run(summary_op)
            except tf.errors.OutOfRangeError:
                print('No final summary to write')
            else:
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

        # Wait for threads to finish.
        coord.join(threads, stop_grace_period=10)
        sess.close()
Esempio n. 5
0
def train():
    """Train the binary classifier, optionally warm-starting from a snapshot.

    Builds the input iterator, network, binary loss and optimizer in the
    default graph.

    - If ``FLAGS.starting_snapshot`` is set, the variables selected by
      ``FLAGS.snapshot_load`` (comma-separated keys 'cnn' / 'classifier',
      or 'all' / unset for every global variable) are initialised from it,
      together with their optimizer slots. Only variables actually present
      in the checkpoint are loaded.
    - Variables selected by ``FLAGS.hold_constant`` are passed to
      ``topology.training`` as ``exclude``.

    Training runs for ``FLAGS.num_epochs`` epochs, or without an epoch limit
    when that flag is unset/0. The iterator is re-initialised every epoch
    with the epoch index as its shuffle seed.

    Raises:
        ValueError: if ``FLAGS.snapshot_load`` or ``FLAGS.hold_constant``
            names an unknown key.
        RuntimeError: if ``FLAGS.optimizer`` is neither 'Adam' nor 'SGD'.
        SystemExit: if both ``FLAGS.check_numerics`` and
            ``FLAGS.random_rotation`` are set.
    """
    import itertools

    # A falsy FLAGS.num_epochs (unset or 0) means "no epoch limit".
    num_epochs = FLAGS.num_epochs or None

    # Global step shared across epochs; fetched each training step.
    # NOTE(review): presumably incremented by topology.training's optimizer.
    with tf.variable_scope('control'):
        global_step = tf.get_variable('global_step',
                                      dtype=tf.int32,
                                      initializer=0,
                                      trainable=False)

    # The seed fed to this placeholder controls the shuffling of the input.
    seed = tf.placeholder(tf.int64, shape=())

    iterator = input_data.input_pipeline_binary(
        FLAGS.data_dir,
        FLAGS.batch_size,
        fake_data=FLAGS.fake_data,
        num_epochs=num_epochs,
        read_threads=FLAGS.read_threads,
        shuffle_size=FLAGS.shuffle_size,
        num_expected_examples=FLAGS.num_examples,
        seed=seed)
    image_path, label_path, images, labels = iterator.get_next()

    if FLAGS.verbose:
        print_op = tf.print("images and labels this batch: ", image_path,
                            label_path, labels)
    else:
        print_op = tf.constant('No printing')

    if FLAGS.random_rotation:
        images, labels = harmonics.apply_random_rotation(images, labels)

    # Build a Graph that computes predictions from the inference model.
    logits = topology.inference(images, FLAGS.network_pattern)

    # Add to the Graph the Ops for loss calculation.
    loss = topology.binary_loss(logits, labels)
    print('loss: ', loss)

    if FLAGS.check_numerics:
        if FLAGS.random_rotation:
            sys.exit('check_numerics is not compatible with random_rotation')
        check_numerics_op = tf.add_check_numerics_ops()
    else:
        check_numerics_op = tf.constant('not checked')

    # User-facing key -> variable-name prefix.
    var_pfx_map = {'cnn': 'cnn/', 'classifier': 'image_binary_classifier/'}

    def _check_keys(keys, purpose):
        # Reject keys that have no entry in var_pfx_map.
        unknown = [k for k in keys if k not in var_pfx_map]
        if unknown:
            raise ValueError('unknown key to %s: %s' % (purpose, ', '.join(unknown)))

    def _vars_for_keys(collection_key, keys):
        # Variables in the collection whose names start with a key's prefix,
        # grouped in key order.
        candidates = tf.get_collection(collection_key)
        return [v for k in keys for v in candidates
                if v.name.startswith(var_pfx_map[k])]

    starting_snapshot = FLAGS.starting_snapshot
    if starting_snapshot:
        if FLAGS.snapshot_load:
            keys = [k.strip() for k in FLAGS.snapshot_load.split(',')]
        else:
            keys = ['all']
        if 'all' in keys:
            vars_to_load = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        else:
            _check_keys(keys, 'load')
            vars_to_load = [global_step] + _vars_for_keys(
                tf.GraphKeys.GLOBAL_VARIABLES, keys)
        if FLAGS.reset_global_step:
            vars_to_load.remove(global_step)
    else:
        vars_to_load = []

    vars_to_hold_constant = []  # empty list means hold nothing constant
    if FLAGS.hold_constant is not None:
        keys = [k.strip() for k in FLAGS.hold_constant.split(',')]
        _check_keys(keys, 'hold constant')
        vars_to_hold_constant = _vars_for_keys(
            tf.GraphKeys.TRAINABLE_VARIABLES, keys)
    print('not subject to training: %s' %
          [v.name for v in vars_to_hold_constant])

    if starting_snapshot:
        reader = pywrap_tensorflow.NewCheckpointReader(starting_snapshot)
        vars_in_snapshot = set(reader.get_variable_to_shape_map())
    else:
        vars_in_snapshot = set()
    print('vars in snapshot: %s' % vars_in_snapshot)

    if FLAGS.optimizer == 'Adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
                                           epsilon=0.1)
    elif FLAGS.optimizer == 'SGD':
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=FLAGS.learning_rate)
    else:
        raise RuntimeError('Unimplemented optimizer %s was requested' %
                           FLAGS.optimizer)
    train_op = topology.training(loss,
                                 FLAGS.learning_rate,
                                 exclude=vars_to_hold_constant,
                                 optimizer=optimizer)

    # Also load any slot variables the optimizer created for the vars we load.
    vars_to_load.extend([
        optimizer.get_slot(var, name) for name in optimizer.get_slot_names()
        for var in vars_to_load
    ])
    vars_to_load = list({var for var in vars_to_load if var is not None})

    # Only variables actually present in the checkpoint can be loaded.
    in_vars = [var for var in vars_to_load if get_cpt_name(var) in vars_in_snapshot]
    out_vars = [var for var in vars_to_load if get_cpt_name(var) not in vars_in_snapshot]
    if out_vars:
        print(
            'WARNING: cannot load the following vars because they are not in the snapshot: %s'
            % [var.name for var in out_vars])
    if in_vars:
        print('loading from checkpoint: %s' % [var.name for var in in_vars])
        # Re-points these variables' initializers at the checkpoint, so the
        # global initializer below loads them.
        tf.train.init_from_checkpoint(
            starting_snapshot,
            {get_cpt_name(var): var for var in in_vars})

    # Histograms of every trainable variable under the two model scopes.
    for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        if var.name.startswith('cnn') or var.name.startswith(
                'image_binary_classifier'):
            tf.summary.histogram(var.name, var)

    saver = tf.train.Saver(max_to_keep=10)
    summary_op = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.verbose))

    # Variables to warm-start were redirected to the checkpoint above, so a
    # single initializer covers both the fresh and the snapshot cases.
    init_op = tf.variables_initializer(
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
    sess.run(init_op)

    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

    loss_value = -1.0  # avoid a corner case where it is unset on error
    duration = 0.0  # ditto
    num_chk = None  # ditto
    gstp = sess.run(global_step)  # valid even if no training step runs

    # range(None) would raise TypeError; None means "no epoch limit".
    epochs = itertools.count() if num_epochs is None else range(num_epochs)
    for epoch in epochs:
        try:
            sess.run(iterator.initializer, feed_dict={seed: epoch})
            saver.save(sess, FLAGS.log_dir + 'cnn', global_step=global_step)
            last_save_epoch = 0

            while True:
                start_time = time.time()
                _, loss_value, num_chk, _, gstp = sess.run(
                    [train_op, loss, check_numerics_op, print_op, global_step])
                duration = time.time() - start_time

                # Write the summaries and print an overview fairly often.
                if (gstp + 1) % 100 == 0 or gstp < 10:
                    print(
                        'Global step %d epoch %d: numerics = %s, batch mean loss = %.2f (%.3f sec)'
                        % (gstp, epoch, num_chk, loss_value.mean(), duration))
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, gstp)
                    summary_writer.flush()

                # Extra checkpoint at most once in every 100th epoch. If
                # log_dir is /tmp/cnn/, checkpoints go in that directory with
                # the prefix 'cnn'.
                if (epoch + 1) % 100 == 0 and epoch != last_save_epoch:
                    print('saving checkpoint at global step %d, epoch %s' %
                          (gstp, epoch))
                    saver.save(sess,
                               FLAGS.log_dir + 'cnn',
                               global_step=global_step)
                    last_save_epoch = epoch

        except tf.errors.OutOfRangeError:
            print('Finished epoch {}'.format(epoch))

    # loss_value is still the float sentinel if no training step ran.
    mean_loss = loss_value.mean() if hasattr(loss_value, 'mean') else loss_value
    print('Final Step %d: numerics = %s, batch mean loss = %.2f (%.3f sec)' %
          (gstp, num_chk, mean_loss, duration))
    try:
        summary_str = sess.run(summary_op)
    except tf.errors.OutOfRangeError:
        print('No final summary to write')
    else:
        # The original used an undefined `step` here; gstp is the global step.
        summary_writer.add_summary(summary_str, gstp)
        summary_writer.flush()

    sess.close()
Esempio n. 6
0
def train():
    """Train fish_cubes with a re-initialisable, seeded input iterator.

    Builds the input iterator, inference network, loss and training op
    (which minimises the batch-mean loss) in the default graph. If
    ``FLAGS.starting_snapshot`` is empty (or None), variables are
    initialised from scratch; otherwise they are restored from that
    checkpoint.

    Training runs for ``FLAGS.num_epochs`` epochs, or without an epoch limit
    when that flag is unset/0. The iterator is re-initialised every epoch
    with the epoch index as its shuffle seed. When ``FLAGS.verbose`` is set,
    the image/label paths of every batch are printed via ``tf.print``.
    """
    import itertools

    # A falsy FLAGS.num_epochs (unset or 0) means "no epoch limit".
    num_epochs = FLAGS.num_epochs or None

    # The seed fed to this placeholder controls the shuffling of the input.
    seed = tf.placeholder(tf.int64, shape=())

    iterator = input_data.input_pipeline(
        FLAGS.data_dir,
        FLAGS.batch_size,
        fake_data=FLAGS.fake_data,
        num_epochs=num_epochs,
        read_threads=FLAGS.read_threads,
        shuffle_size=FLAGS.shuffle_size,
        num_expected_examples=FLAGS.num_examples,
        seed=seed)
    image_path, label_path, images, labels = iterator.get_next()

    if FLAGS.verbose:
        print_op = tf.print("images and labels this batch: ", image_path,
                            label_path)
    else:
        print_op = tf.constant('No printing')

    # Build a Graph that computes predictions from the inference model.
    logits = topology.inference(images, FLAGS.network_pattern)

    # Loss op; it is reduced to its mean before being handed to the optimizer.
    loss = topology.loss(logits, labels)
    train_op = topology.training(tf.reduce_mean(loss), FLAGS.learning_rate)
    if FLAGS.check_numerics:
        check_numerics_op = tf.add_check_numerics_ops()
    else:
        check_numerics_op = tf.constant('not checked')

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.summary.merge_all()

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))

    # Instantiate a SummaryWriter to output summaries and the Graph.
    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

    # Initialize from scratch, or restore from the snapshot if one is given.
    # Truthiness also covers a None flag; len(None) would crash.
    if not FLAGS.starting_snapshot:
        sess.run(init_op)
    else:
        saver.restore(sess, FLAGS.starting_snapshot)

    step = 0
    loss_value = -1.0  # avoid a corner case where it is unset on error
    duration = 0.0  # ditto
    num_chk = None  # ditto

    # range(None) would raise TypeError; None means "no epoch limit".
    epochs = itertools.count() if num_epochs is None else range(num_epochs)
    for epoch in epochs:
        try:
            sess.run(iterator.initializer, feed_dict={seed: epoch})

            while True:
                start_time = time.time()
                # print_op is fetched so FLAGS.verbose actually prints each
                # batch. It shares the same get_next() tensors as train_op,
                # so it reports the batch being trained on.
                _, loss_value, num_chk, _ = sess.run(
                    [train_op, loss, check_numerics_op, print_op])
                duration = time.time() - start_time

                # Write the summaries and print an overview fairly often.
                if (step + 1) % 100 == 0 or step < 10:
                    print(
                        'Step %d: numerics = %s, batch mean loss = %.2f (%.3f sec)'
                        % (step, num_chk, loss_value.mean(), duration))
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                # Save a checkpoint periodically. If log_dir is /tmp/cnn/,
                # checkpoints go in that directory with the prefix 'cnn'.
                # NOTE(review): saving every 4 steps writes a full checkpoint
                # very frequently; confirm this interval is intentional.
                if (step + 1) % 4 == 0:
                    saver.save(sess, FLAGS.log_dir + 'cnn', global_step=step)

                step += 1

        except tf.errors.OutOfRangeError:
            print('Finished epoch {}'.format(epoch))

    # loss_value is still the float sentinel if no training step ran.
    mean_loss = loss_value.mean() if hasattr(loss_value, 'mean') else loss_value
    print('Final Step %d: numerics = %s, batch mean loss = %.2f (%.3f sec)' %
          (step, num_chk, mean_loss, duration))
    try:
        summary_str = sess.run(summary_op)
    except tf.errors.OutOfRangeError:
        print('No final summary to write')
    else:
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()

    sess.close()