def write_labels(flag, im_lists, config):
    """Write the file list for one split plus the shared class-name file.

    Args:
        flag: Key used to select the entry of ``im_lists`` to write.
        im_lists: Container indexed by ``flag``; the selected entry is passed
            to ``write_label_list``.
        config: GEDIconfig-style object; reads ``processed_image_patch_dir``,
            ``image_prefixes``, ``label_directories`` and ``tfrecord_dir``.

    Returns:
        Path of the ``list_of_<prefixes>_labels.txt`` file that was written.
    """
    patch_dir = config.processed_image_patch_dir
    make_dir(patch_dir)

    # The list file name joins all image prefixes with underscores.
    prefix_tag = '_'.join(config.image_prefixes)
    list_path = os.path.join(
        patch_dir, 'list_of_' + prefix_tag + '_labels.txt')
    write_label_list(im_lists[flag], list_path)

    # Class ids follow the order of config.label_directories.
    id_to_name = dict(enumerate(config.label_directories))
    write_label_file(id_to_name, config.tfrecord_dir)
    return list_path
示例#2
0
def run():
    """Organise processed image patches into train/val/test TFRecords.

    All settings come from ``GEDIconfig``. The function does the following:

    1. Creates the output directories.
    2. Writes the list of labelled patch files.
    3. Splits the files with ``split_files``.
    4. Moves each split into its directory and serialises it with
       ``simple_tf_records``.
    5. Writes the label-id -> class-name file into ``config.tfrecord_dir``.

    Raises:
        RuntimeError: If ``get_file_list`` finds no files. The image shape is
            read from the first file, so there is nothing to process.
    """
    print('Organizing files into tf records')

    # Make dirs if they do not exist
    config = GEDIconfig()
    dir_list = [
        config.train_directory, config.validation_directory,
        config.tfrecord_dir, config.train_checkpoint
    ]
    if 'test' in config.tvt_flags:
        # The test split is moved into test_directory below, so it must
        # exist too (extract_tf_records_from_GEDI_tiffs does the same).
        dir_list.append(config.test_directory)
    for d in dir_list:
        make_dir(d)

    # Prepare lists with file pointers
    files = get_file_list(config.processed_image_patch_dir,
                          config.label_directories, config.im_ext)
    if len(files) == 0:
        raise RuntimeError('No %s files found in %s.' % (
            config.im_ext, config.processed_image_patch_dir))
    label_list = os.path.join(
        config.processed_image_patch_dir,
        'list_of_' + '_'.join(config.image_prefixes) + '_labels.txt')
    write_label_list(files, label_list)

    # Every split is serialised with the shape of the first image.
    hw = misc.imread(files[0]).shape
    new_files = split_files(files, config.train_proportion, config.tvt_flags)

    # Copy data into the appropriate training/testing directories. The
    # training split is always written; val/test only when requested.
    splits = [('train', config.train_directory)]
    if 'val' in config.tvt_flags:
        splits.append(('val', config.validation_directory))
    if 'test' in config.tvt_flags:
        splits.append(('test', config.test_directory))
    for split, split_dir in splits:
        move_files(new_files[split], split_dir)
        # NOTE(review): every split uses config.train_shards; confirm
        # whether val/test should use their own shard counts.
        simple_tf_records(split, new_files, config.tfrecord_dir,
                          config.im_ext, config.train_shards, hw,
                          config.normalize)

    # Finally, write the labels file:
    labels_to_class_names = dict(
        zip(range(len(config.label_directories)), config.label_directories))
    write_label_file(labels_to_class_names, config.tfrecord_dir)
示例#3
0
def save_images(
        y,
        yhat,
        viz,
        files,
        output_folder,
        target,
        label_dict,
        ext='.png'):
    """Save TP/FP/TN/FN images in separate folders.

    For every key ``k`` of ``label_dict``, the folders
    ``<output_folder>/<k>_true`` and ``<output_folder>/<k>_false`` are
    created. Each example is plotted with its predicted and true label. It is
    saved as ``<source basename without extension><ext>`` in the folder of
    its outcome category.

    Args:
        y: True labels, one per example.
        yhat: Predicted labels, one per example.
        viz: Images to plot; each is squeezed before ``plt.imshow``.
        files: Source file paths, used only to name the saved images.
        output_folder: Root directory for the category folders.
        target: Label value treated as the positive class.
        label_dict: Mapping whose keys name the folders. At least two keys
            are required.
        ext: Extension (including the dot) of the saved images.
    """
    quality = ['true', 'false']
    folders = [
        [os.path.join(output_folder, '%s_%s' % (k, q)) for q in quality]
        for k in label_dict.keys()]
    for folder in flatten_list(folders):
        make_dir(folder)
    for iy, iyhat, iviz, ifiles in zip(y, yhat, viz, files):
        correct = iy == iyhat
        target_label = iy == target
        # Row 0/1: prediction correct/wrong. Column 0/1: true label is /
        # is not the target, giving [0][0] TP, [0][1] TN, [1][0] FN and
        # [1][1] FP.
        # NOTE(review): the rows were built from label_dict keys, so e.g.
        # FN images land in "<second key>_true"; confirm this is intended.
        it_folder = folders[0 if correct else 1][0 if target_label else 1]
        # str.split is literal, so split('\.') never stripped the
        # extension; splitext removes it properly.
        it_f = os.path.splitext(os.path.basename(ifiles))[0]
        f = plt.figure()
        try:
            plt.imshow(iviz.squeeze())
            plt.title('Predicted label=%s, true label=%s' % (iyhat, iy))
            plt.savefig(os.path.join(it_folder, '%s%s' % (it_f, ext)))
        finally:
            # Always release the figure, even if saving fails.
            plt.close(f)
def train_vgg16(train_dir=None, validation_dir=None):
    """Train the VGG16-based GEDI model from TFRecords.

    All hyper-parameters come from ``GEDIconfig``. Every step runs one
    optimisation step and prints its loss. Every
    ``config.validation_steps`` steps, the validation accuracy is evaluated
    (without training on it), summaries are written and a checkpoint is
    saved. Training stops when the input queues raise ``OutOfRangeError``.

    Args:
        train_dir: Directory holding the training TFRecord
            (``config.tf_record_names['train']``) and its
            ``<tvt_flags[0]>_<max_file>`` metadata file. ``None`` uses
            ``config.tfrecord_dir``.
        validation_dir: Directory holding the validation TFRecord
            (``config.tf_record_names['val']``). ``None`` uses
            ``config.tfrecord_dir``; ``False`` trains without validation.

    Raises:
        RuntimeError: If ``config.ordinal_classification`` is neither
            ``None`` nor ``'regression'``, or if the loss becomes NaN.
    """
    config = GEDIconfig()
    if train_dir is None:  # Use globals
        train_dir = config.tfrecord_dir
    # Defined for both default and explicit directories (an explicit
    # train_dir used to leave train_data undefined).
    train_data = os.path.join(train_dir, config.tf_record_names['train'])
    meta_data = np.load(
        os.path.join(
            train_dir,
            '%s_%s' % (config.tvt_flags[0], config.max_file)))

    # Prepare image normalization values
    if config.max_gedi is None:
        max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
        if max_value == 0:
            max_value = None
            print('Derived max value is 0')
        else:
            print('Normalizing with empirical max.')
        if 'min_array' in meta_data.keys():
            min_value = np.min(meta_data['min_array']).astype(np.float32)
            print('Normalizing with empirical min.')
        else:
            min_value = None
            print('Not normalizing with a min.')
    else:
        max_value = config.max_gedi
        min_value = config.min_gedi
    ratio = meta_data['ratio']
    if config.encode_time_of_death:
        tod = pd.read_csv(config.encode_time_of_death)
        # .values replaces as_matrix(), which was removed in pandas 1.0.
        tod_data = tod['dead_tp'].values
        mask = np.isnan(tod_data).astype(int) + (
            tod['plate_well_neuron'] == 'empty').values.astype(int)
        tod_data = tod_data[mask == 0]
        tod_data = tod_data[tod_data > config.mask_timepoint_value]
        config.output_shape = len(np.unique(tod_data))
        # Keyword arguments are required by newer scikit-learn releases.
        ratio = class_weight.compute_class_weight(
            class_weight='balanced',
            classes=np.sort(np.unique(tod_data)),
            y=tod_data)
    print('Ratio is: %s' % ratio)

    if validation_dir is False:
        validation_data = False  # Do not use validation data during training
    else:
        if validation_dir is None:  # Use globals
            validation_dir = config.tfrecord_dir
        validation_data = os.path.join(
            validation_dir, config.tf_record_names['val'])
    use_validation = validation_data is not False

    # Make output directories if they do not exist
    dt_stamp = re.split(r'\.', str(datetime.now()))[0].replace(
        ' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(
        config.train_checkpoint, dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    dir_list = [
        config.train_checkpoint, config.train_summaries,
        config.results, out_dir]
    for d in dir_list:
        make_dir(d)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Prepare data on CPU
    assert os.path.exists(train_data), train_data
    if use_validation:
        assert os.path.exists(validation_data), validation_data
    assert os.path.exists(config.vgg16_weight_path)
    if config.include_GEDI_in_tfrecords:
        extra_im_name = 'GEDI at current timepoint'
    else:
        extra_im_name = 'next gfp timepoint'
    with tf.device('/cpu:0'):
        train_images, train_labels, train_gedi_images = inputs(
            train_data,
            config.train_batch,
            im_shape,
            config.model_image_size[:2],
            max_value=max_value,
            min_value=min_value,
            train=config.data_augmentations,
            num_epochs=config.epochs,
            normalize=config.normalize,
            return_gedi=config.include_GEDI_in_tfrecords,
            return_extra_gfp=config.extra_image,
            return_GEDI_derivative=True)
        tf.summary.image('train images', train_images)
        tf.summary.image('train %s' % extra_im_name, train_gedi_images)
        if use_validation:
            val_images, val_labels, val_gedi_images = inputs(
                validation_data,
                config.validation_batch,
                im_shape,
                config.model_image_size[:2],
                max_value=max_value,
                min_value=min_value,
                num_epochs=config.epochs,
                normalize=config.normalize,
                return_gedi=config.include_GEDI_in_tfrecords,
                return_extra_gfp=config.extra_image,
                return_GEDI_derivative=True)
            tf.summary.image('validation images', val_images)
            tf.summary.image(
                'validation %s' % extra_im_name, val_gedi_images)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn') as scope:
            if config.ordinal_classification is None:
                vgg_output = 2  # Sign of derivative (inf norm)
                train_labels = tf.cast(tf.sign(train_labels), tf.int32)
                if use_validation:
                    val_labels = tf.cast(tf.sign(val_labels), tf.int32)
            elif config.ordinal_classification == 'regression':
                vgg_output = 1
            else:
                raise RuntimeError(
                    'config.ordinal_classification must be None (sign) '
                    'or regression.')
            vgg = vgg16.model_struct()
            train_mode = tf.get_variable(name='training', initializer=True)

            # Mask NaN images out of the loss and zero them so they cannot
            # propagate NaNs through the network.
            image_nan = tf.reduce_sum(
                tf.cast(tf.is_nan(train_images), tf.float32),
                reduction_indices=[1, 2, 3])
            image_mask = tf.cast(tf.equal(image_nan, 0.), tf.float32)
            train_images = tf.where(
                tf.is_nan(train_images),
                tf.zeros_like(train_images),
                train_images)
            # Tile the channel axis 3x for the VGG16 input (presumably the
            # records are single-channel -- confirm).
            train_images = tf.concat(
                [train_images, train_images, train_images], axis=3)
            vgg.build(
                train_images, output_shape=vgg_output,
                train_mode=train_mode, batchnorm=config.batchnorm_layers)

            # Prepare the cost function
            if config.ordinal_classification is None:
                cost = softmax_cost(
                    vgg.fc8,
                    train_labels,
                    mask=image_mask)
            else:  # 'regression' (anything else raised above)
                cost = tf.nn.l2_loss(tf.squeeze(vgg.fc8) - train_labels)
            tf.summary.scalar("cce cost", cost)

            # Weight decay
            if config.wd_layers is not None:
                _, l2_wd_layers = fine_tune_prepare_layers(
                    tf.trainable_variables(), config.wd_layers)
                l2_wd_layers = [
                    x for x in l2_wd_layers if 'biases' not in x.name]
                if len(l2_wd_layers) > 0:
                    cost += (config.wd_penalty * tf.add_n(
                        [tf.nn.l2_loss(x) for x in l2_wd_layers]))

            # Optimize
            train_op = tf.train.AdamOptimizer(config.new_lr).minimize(cost)
            if config.ordinal_classification is None:
                train_accuracy = class_accuracy(
                    vgg.prob, train_labels)  # training accuracy
            else:
                train_accuracy = tf.nn.l2_loss(
                    tf.squeeze(vgg.fc8) - train_labels)
            tf.summary.scalar("training accuracy", train_accuracy)

            # Setup validation op
            if use_validation:
                scope.reuse_variables()
                val_images = tf.concat(
                    [val_images, val_images, val_images], axis=3)
                # Validation graph is the same as training except no batchnorm
                val_vgg = vgg16.model_struct(
                    fine_tune_layers=config.fine_tune_layers)
                val_vgg.build(val_images, output_shape=vgg_output)
                # Calculate validation accuracy
                if config.ordinal_classification is None:
                    val_accuracy = class_accuracy(val_vgg.prob, val_labels)
                else:
                    val_accuracy = tf.nn.l2_loss(
                        tf.squeeze(val_vgg.fc8) - val_labels)
                tf.summary.scalar("validation accuracy", val_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(
        tf.global_variables(), max_to_keep=config.keep_checkpoints)
    summary_op = tf.summary.merge_all()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Need to initialize both of these if supplying num_epochs to inputs
    sess.run(tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer()))
    summary_dir = os.path.join(
        config.train_summaries, config.which_dataset + '_' + dt_stamp)
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Restore model if requested
    if config.restore_path is not None:
        print('-' * 60)
        print('Restoring from a previous model: %s' % config.restore_path)
        print('-' * 60)
        saver.restore(sess, config.restore_path)

    # Start training loop
    np.save(out_dir + 'meta_info', config)
    step, losses = 0, []

    try:
        while not coord.should_stop():
            start_time = time.time()
            _, loss_value, train_acc = sess.run(
                [train_op, cost, train_accuracy])
            losses.append(loss_value)
            duration = time.time() - start_time
            if np.isnan(loss_value).any():
                raise RuntimeError('Model diverged with loss = NaN')
            if step % config.validation_steps == 0:
                if use_validation:
                    # Evaluate only. Running train_op here would add an
                    # extra, unlogged optimisation step.
                    val_acc = sess.run(val_accuracy)
                else:
                    val_acc = -1  # No validation; store every checkpoint

                # Summaries
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Training status and validation accuracy
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; '
                    '%.3f sec/batch) | Training accuracy = %s | '
                    'Training %s = %s | Training class loss = %s | '
                    'Validation accuracy = %s | Validation %s = %s | '
                    'logdir = %s')
                print(format_str % (
                    datetime.now(), step, loss_value,
                    config.train_batch / duration, float(duration),
                    train_acc, extra_im_name, 0.,
                    0., val_acc, extra_im_name,
                    0., summary_dir))

                # Checkpoint at every validation step.
                saver.save(
                    sess, os.path.join(
                        config.train_checkpoint,
                        'model_' + str(step) + '.ckpt'), global_step=step)
            else:
                # Training status
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Training %s = %s | Training class loss = %s')
                print(format_str % (datetime.now(), step, loss_value,
                                    config.train_batch / duration,
                                    float(duration), train_acc,
                                    extra_im_name, 0.,
                                    0.))
            step += 1

    except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
    finally:
        coord.request_stop()
        np.save(os.path.join(config.tfrecord_dir, 'training_loss'), losses)
    coord.join(threads)
    sess.close()
示例#5
0
def extract_tf_records_from_GEDI_tiffs():
    """Extracts data directly from GEDI tiffs and
    inserts them into tf records. This allows us to
    offload the normalization procedure to either
    right before training (via sampling) or during
    training (via normalization with a batch's max).

    The function does the following:

    1. Creates the output directories.
    2. Gathers every raw image from ``config.raw_im_dirs``. When a test
       split is requested, it reads the ``<dir>_train`` and ``<dir>_test``
       variants instead.
    3. Writes the label list and the label file.
    4. Splits the images with ``split_files``.
    5. Writes one ``<split>.tfrecords`` per entry of ``config.tvt_flags``,
       which may be a single split name or a sequence of names.
    """
    # Grab the global config
    config = GEDIconfig()

    # Make dirs if they do not exist. The label list below is written to
    # processed_image_patch_dir, so ensure that one exists too.
    dir_list = [
        config.train_directory, config.validation_directory,
        config.tfrecord_dir, config.train_checkpoint,
        config.processed_image_patch_dir,
    ]
    if 'test' in config.tvt_flags:
        dir_list.append(config.test_directory)
        # Pair every raw directory with its "_train" and "_test" variant.
        # Deriving the test name from the original name (rather than
        # splitting "<dir>_train" on "_train") also works for directory
        # names that already contain "_train".
        base_dirs = list(config.raw_im_dirs)
        config.raw_im_dirs = (
            [x + '_train' for x in base_dirs] +
            [x + '_test' for x in base_dirs])
    for d in dir_list:
        make_dir(d)

    print('raw_im_dirs', config.raw_im_dirs)
    # gather file names of images to process
    im_lists = flatten_list([
        glob(os.path.join(config.home_dir, r, '*' + config.raw_im_ext))
        for r in config.raw_im_dirs
    ])
    print('im_lists', im_lists[:3])

    # Write labels list
    label_list_path = os.path.join(
        config.processed_image_patch_dir,
        'list_of_' + '_'.join(config.image_prefixes) + '_labels.txt')
    print('label_list', label_list_path)
    write_label_list(im_lists, label_list_path)

    # Write the labels file:
    labels_to_class_names = dict(
        zip(range(len(config.label_directories)), config.label_directories))
    write_label_file(labels_to_class_names, config.tfrecord_dir)

    # Split data into the train/val/test sets
    new_files = split_files(im_lists, config.train_proportion,
                            config.tvt_flags)

    # tvt_flags may be a single split name or a sequence of names.
    if isinstance(config.tvt_flags, str):
        flags = [config.tvt_flags]
    else:
        flags = config.tvt_flags
    for k in flags:
        files = new_files[k]
        split_labels = new_files[k + '_labels']
        print('GEDI label list', split_labels)
        output_pointer = os.path.join(config.tfrecord_dir, k + '.tfrecords')
        extract_to_tf_records(files=files,
                              label_list=split_labels,
                              output_pointer=output_pointer,
                              ratio_list=None,
                              config=config,
                              k=k)
示例#6
0
def train_vgg16(train_dir=None, validation_dir=None):
    """Work-in-progress VGG16 + LSTM training graph over config.channel timepoints.

    Loads normalisation metadata and builds the train/validation input
    pipelines. It then tries to build a tf.while_loop that runs VGG16
    followed by an LSTM cell once per timepoint.

    NOTE(review): this block does not run as written.
      * Four lines inside ``body`` are over-indented (IndentationError).
      * ``loop_vars`` references names never defined in this function:
        ``images``, ``label``, ``state``, ``vgg``, ``output_shape`` and
        ``batchnorm_layers``.
      * ``cond`` takes 6 parameters while ``loop_vars`` has 12 entries.
        ``body`` returns 11 values (``cost_fun`` is missing).
      * ``train_data`` / ``validation_data`` are undefined when
        ``train_dir`` or ``validation_dir`` is passed.
      * The function stops after building the validation op; there is no
        session or training loop.
    """
    config = GEDIconfig()
    if train_dir is None:  # Use globals
        train_data = os.path.join(config.tfrecord_dir, 'train.tfrecords')
        meta_data = np.load(
            os.path.join(
                config.tfrecord_dir, config.tvt_flags[0] + '_' +
                config.max_file))
    else:
        # NOTE(review): train_data is not set on this path.
        meta_data = np.load(
            os.path.join(
                train_dir, config.tvt_flags[0] + '_' + config.max_file))

    # Prepare image normalization values
    max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
    if max_value == 0:
        max_value = None
        print 'Derived max value is 0'
    else:
        print 'Normalizing with empirical max.'
    if 'min_array' in meta_data.keys():
        min_value = np.min(meta_data['min_array']).astype(np.float32)
        print 'Normalizing with empirical min.'
    else:
        min_value = None
        print 'Not normalizing with a min.'
    # Class ratio from the metadata; used by cost_fun when balance_cost is set.
    ratio = meta_data['ratio']
    print 'Ratio is: %s' % ratio

    if validation_dir is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir, 'val.tfrecords')
    elif validation_dir is False:
        # NOTE(review): validation_data stays undefined here (and when a
        # path is passed), but it is used below.
        pass  # Do not use validation data during training

    # Make output directories if they do not exist
    dt_stamp = re.split(
        '\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(
        config.train_checkpoint, dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    dir_list = [
        config.train_checkpoint, config.train_summaries,
        config.results, out_dir]
    [make_dir(d) for d in dir_list]
    # im_shape = get_image_size(config)
    im_shape = config.gedi_image_size

    print '-'*60
    print('Training model:' + dt_dataset)
    print '-'*60

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        train_images, train_labels = inputs(
            train_data,
            config.train_batch,
            im_shape,
            config.model_image_size[:2],
            max_value=max_value,
            min_value=min_value,
            train=config.data_augmentations,
            num_epochs=config.epochs,
            normalize=config.normalize)
        val_images, val_labels = inputs(
            validation_data,
            config.validation_batch,
            im_shape,
            config.model_image_size[:2],
            max_value=max_value,
            min_value=min_value,
            num_epochs=config.epochs,
            normalize=config.normalize)
        tf.summary.image('train images', train_images)
        tf.summary.image('validation images', val_images)

    # Prepare model on GPU
    # One LSTM step per entry of config.channel.
    num_timepoints = len(config.channel)
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn') as scope:

            # Prepare the loss function
            if config.balance_cost:
                cost_fun = lambda yhat, y: softmax_cost(yhat, y, ratio=ratio)
            else:
                cost_fun = lambda yhat, y: softmax_cost(yhat, y)

            # Loop condition: continue while i < num_timepoints.
            def cond(i, state, images, output, loss, num_timepoints):  # NOT CORRECT: signature must mirror loop_vars
                return tf.less(i, num_timepoints)

            # One loop iteration: run VGG16 on timepoint i, feed score_layer
            # into the LSTM cell and record its output and loss.
            # NOTE(review): the lines after vgg.build are over-indented.
            # cost_fun receives the TensorArray ``output`` instead of
            # ``it_output``. The output_shape/batchnorm_layers parameters are
            # ignored in favour of config values. ``vgg[score_layer]``
            # assumes the model supports item access -- confirm.
            def body(
                    i,
                    images,
                    label,
                    cell,
                    state,
                    output,
                    loss,
                    vgg,
                    train_mode,
                    output_shape,
                    batchnorm_layers,
                    cost_fun,
                    score_layer='fc7'):
                vgg.build(
                    images[i],
                    output_shape=config.output_shape,
                    train_mode=train_mode,
                    batchnorm=config.batchnorm_layers)
                    it_output, state = cell(vgg[score_layer], state)
                    output= output.write(i, it_output)
                    it_loss = cost_fun(output, label)
                    loss = loss.write(i, it_loss)
                return (i+1, images, label, cell, state, output, loss, vgg, train_mode, output_shape, batchnorm_layers) 

            # Prepare LSTM loop
            train_mode = tf.get_variable(name='training', initializer=True)
            cell = lstm_layers(layer_units=config.lstm_units)
            output = tf.TensorArray(tf.float32, num_timepoints) # output array for LSTM loop -- one slot per timepoint
            loss = tf.TensorArray(tf.float32, num_timepoints)  # loss for LSTM loop -- one slot per timepoint
            i = tf.constant(0)  # iterator for LSTM loop
            # NOTE(review): images, label, state, vgg, output_shape and
            # batchnorm_layers are not defined before this line.
            loop_vars = [i, images, label, cell, state, output, loss, vgg, train_mode, output_shape, batchnorm_layers, cost_fun]

            # Run the lstm
            processed_list = tf.while_loop(
                cond=cond,
                body=body,
                loop_vars=loop_vars,
                back_prop=True,
                swap_memory=False)
            # Indices follow the loop_vars order.
            output_cell = processed_list[3]
            output_state = processed_list[4]
            output_activity = processed_list[5]
            cost = processed_list[6]

            # Optimize
            # NOTE(review): cost is a TensorArray here; presumably it needs
            # cost.stack() before reduce_sum. The optimiser below is given
            # ``cost`` rather than ``combined_cost``.
            combined_cost = tf.reduce_sum(cost)
            tf.summary.scalar('cost', combined_cost)
            import ipdb;ipdb.set_trace()  # Need to make sure lstm weights are being trained
            other_opt_vars, ft_opt_vars = fine_tune_prepare_layers(
                tf.trainable_variables(), config.fine_tune_layers)
            if config.optimizer == 'adam':
                train_op = ft_non_optimized(
                    cost, other_opt_vars, ft_opt_vars,
                    tf.train.AdamOptimizer, config.hold_lr, config.new_lr)
            elif config.optimizer == 'sgd':
                train_op = ft_non_optimized(
                    cost, other_opt_vars, ft_opt_vars,
                    tf.train.GradientDescentOptimizer,
                    config.hold_lr, config.new_lr)

            # NOTE(review): ``vgg`` is not defined in this scope.
            train_accuracy = class_accuracy(
                vgg.prob, train_labels)  # training accuracy
            tf.summary.scalar("training accuracy", train_accuracy)

            # Setup validation op
            if validation_data is not False:
                scope.reuse_variables()
                # Validation graph is the same as training except no batchnorm
                val_vgg = vgg16.Vgg16(
                    vgg16_npy_path=config.vgg16_weight_path,
                    fine_tune_layers=config.fine_tune_layers)
                val_vgg.build(val_images, output_shape=config.output_shape)
                # Calculate validation accuracy
                val_accuracy = class_accuracy(val_vgg.prob, val_labels)
                tf.summary.scalar("validation accuracy", val_accuracy)
示例#7
0
def test_vgg16(validation_data, model_dir, label_file, selected_ckpts=-1):
    """Fit pairwise linear SVMs on VGG16 ``fc7`` features of LINCS images.

    Procedure:

    1. Restores the checkpoint ``selected_ckpts`` from ``model_dir``,
       indexed after sorting the checkpoints by step.
    2. Extracts ``fc7`` activations for the images matching the plate
       layout in ``LINCSproject_platelayout_trans.csv``.
    3. Caches the activations in ``svm_scores.npz``.
    4. Reports 5-fold cross-validated accuracy of a linear SVM for every
       pair of sample ids.
    5. Saves the results to ``svm_models.npz`` in the output directory.

    ``validation_data`` and ``label_file`` are currently unused.

    NOTE(review): ``tf_dir`` and ``image_dir`` are not defined in this
    function (presumably module-level globals -- confirm). The image-loading
    loop is still a debugger placeholder, so ``image_array`` is never
    filled.
    """
    config = GEDIconfig()

    # Load metas (max_value is currently unused below).
    meta_data = np.load(os.path.join(tf_dir, 'val_maximum_value.npz'))
    max_value = np.max(meta_data['max_array']).astype(np.float32)

    # Find model checkpoints
    ckpts, ckpt_names = find_ckpts(config, model_dir)
    # NOTE(review): output directory is hard-coded to one training run.
    out_dir = os.path.join(config.results, 'gfp_2017_02_19_17_41_19' + '/')
    try:
        # meta_info.npy holds the config object pickled by np.save in
        # train_vgg16, so pickle loading must be enabled. Only load it from
        # trusted result directories.
        config = np.load(
            os.path.join(out_dir, 'meta_info.npy'), allow_pickle=True).item()
        # Override the stored batch size for feature extraction.
        config.validation_batch = 64
        print('-' * 60)
        print('Loading config meta data for:%s' % out_dir)
        print('-' * 60)
    except Exception:
        # Best effort: any failure falls back to the default GEDIconfig.
        print('-' * 60)
        print('Using config from gedi_config.py for model:%s' % out_dir)
        print('-' * 60)

    sorted_index = np.argsort(np.asarray([int(x) for x in ckpt_names]))
    ckpts = ckpts[sorted_index]
    ckpt_names = ckpt_names[sorted_index]

    # Build the label vector from the plate layout. No cached copy is read:
    # this file name is also used for the SVM results saved at the end.
    svm_image_file = os.path.join(out_dir, 'svm_models.npz')
    labels = pd.read_csv(
        os.path.join(config.processed_image_patch_dir,
                     'LINCSproject_platelayout_trans.csv'))
    label_vec = []
    image_array = []
    for _, row in tqdm(labels.iterrows(), total=len(labels)):
        path_wd = '*%s_%s*' % (row['Plate'], row['Sci_WellID'])
        path_pointer = glob(os.path.join(image_dir, path_wd))
        for p in path_pointer:
            # NOTE(review): placeholder -- images are never appended to
            # image_array; the breakpoint stays until loading is written.
            import ipdb
            ipdb.set_trace()
            label_vec.append(row['Sci_SampleID'])
    label_vec = np.asarray(label_vec)
    le = preprocessing.LabelEncoder()
    tr_label_vec = le.fit_transform(label_vec)
    np.savez(svm_image_file,
             image_array=image_array,
             label_vec=label_vec,
             tr_label_vec=tr_label_vec)

    # Make output directories if they do not exist
    for d in [config.results, out_dir]:
        make_dir(d)

    # Make placeholder
    val_images = tf.placeholder(tf.float32,
                                shape=[None] + config.model_image_size)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn'):
            vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                              fine_tune_layers=config.fine_tune_layers)
            validation_mode = tf.Variable(False, name='training')
            # No batchnorms during testing
            vgg.build(val_images,
                      output_shape=config.output_shape,
                      train_mode=validation_mode)

    svm_feature_file = os.path.join(out_dir, 'svm_scores.npz')
    if os.path.exists(svm_feature_file):
        # Self-written cache. Pickle is allowed so caches that stored a list
        # of per-batch arrays still load.
        svm_features = np.load(svm_feature_file, allow_pickle=True)
        dec_scores = svm_features['dec_scores']
        label_vec = svm_features['label_vec']
        if dec_scores.ndim != 2:
            # Older caches stored per-batch score arrays.
            dec_scores = np.concatenate(list(dec_scores), axis=0)
    else:
        saver = tf.train.Saver(tf.global_variables())
        ckpts = [ckpts[selected_ckpts]]
        image_array = np.asarray(image_array)
        for c in ckpts:
            dec_scores = []
            # Initialize the graph
            sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
            try:
                sess.run(
                    tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer()))
                saver.restore(sess, c)
                # Float division so the last partial batch is counted, then
                # give each run of validation_batch images one batch id.
                num_batches = int(np.ceil(
                    len(image_array) / float(config.validation_batch)))
                batch_idx = np.arange(num_batches).repeat(
                    config.validation_batch)[:len(image_array)]
                for bi in np.unique(batch_idx):
                    batch_images = image_array[batch_idx == bi] / 255.
                    start_time = time.time()
                    sc = sess.run(vgg.fc7, feed_dict={val_images: batch_images})
                    dec_scores.append(sc)
                    print('Batch %d took %.1f seconds' % (
                        bi, time.time() - start_time))
            finally:
                sess.close()
        # Stack per-batch scores (assumes fc7 is (batch, features) --
        # confirm) so the cache holds a single regular array.
        dec_scores = np.concatenate(dec_scores, axis=0)
        np.savez(svm_feature_file, dec_scores=dec_scores, label_vec=label_vec)

    # Build one SVM per pair of sample ids.
    model_array, score_array, combo_array, masked_label_array = [], [], [], []
    for combo in itertools.combinations(np.unique(label_vec), 2):
        combo_array.append(combo)
        mask = np.logical_or(label_vec == combo[0], label_vec == combo[1])
        masked_labels = label_vec[mask]
        masked_scores = dec_scores[mask, :]
        clf = SVC(kernel='linear', C=1)
        # cross_val_score fits clones, so the stored clf stays unfitted.
        scores = cross_val_score(clf, masked_scores, masked_labels, cv=5)
        model_array.append(clf)
        score_array.append(scores)
        masked_label_array.append(masked_labels)
        print("%s vs %s -- Accuracy: %0.2f (+/- %0.2f)" % (
            combo[0], combo[1], scores.mean(), scores.std() * 2))

    # Save everything
    np.savez(os.path.join(out_dir, 'svm_models'),
             combo_array=combo_array,
             model_array=model_array,
             score_array=score_array,
             masked_label_array=masked_label_array)
sys.path.insert(
    0,
    re.split(__file__,
             os.path.realpath(__file__))[0])  # puts this experiment into path
from gedi_config import GEDIconfig
from glob import glob
from exp_ops.tf_fun import make_dir
from exp_ops.preprocessing_GEDI_images import produce_patches_parallel, produce_patches

# Grab the global config
config = GEDIconfig()
USE_PARALLEL = False

# Make directories for processed images (a plain loop: make_dir is called
# only for its side effect, so no throwaway list is built).
for output_dir in config.output_dirs:
    make_dir(os.path.join(config.home_dir, output_dir))

# Gather the file names of the raw images to process: one list per raw image
# directory, matching files that end with config.raw_im_ext.
im_lists = [
    glob(os.path.join(config.home_dir, r, '*' + config.raw_im_ext))
    for r in config.raw_im_dirs
]

if USE_PARALLEL:
    produce_patches_parallel(config, im_lists)  # multithread processing
else:
    [
        produce_patches(p,
                        config.channel,
                        config.panel,
def test_vgg16(live_ims,
               dead_ims,
               model_file,
               svm_model='svm_model',
               output_csv='prediction_file',
               training_max=None,
               C=1e-3,
               k_folds=10):
    """Train an SVM for your dataset on GEDI-model encodings.

    Every image in ``live_ims`` (label 0) and ``dead_ims`` (label 1) is
    encoded with the ``fc7`` layer of the VGG16 model restored from
    ``model_file``. A standardized ``LinearSVC`` is evaluated with
    ``k_folds``-fold cross-validation, then fit on all encodings and pickled.

    Args:
        live_ims: Directory with the live-cell images.
        dead_ims: Directory with the dead-cell images.
        model_file: Checkpoint to restore. Its directory must also contain
            ``train_maximum_value.npz``; that directory's name is used as the
            results sub-directory under ``config.results``.
        svm_model: Path stem of the pickled classifier (``.pkl`` appended).
        output_csv: Name stem of the per-file prediction CSV written to the
            results directory (``.csv`` appended).
        training_max: Optional normalization maximum; defaults to the maximum
            stored in ``train_maximum_value.npz``.
        C: Inverse regularization strength of the ``LinearSVC``.
        k_folds: Number of cross-validation folds.

    Raises:
        RuntimeError: If an image directory is missing, no images are found
            or the training meta file does not exist.
    """
    config = GEDIconfig()
    if live_ims is None:
        raise RuntimeError(
            'You need to supply a directory path to the live images.')
    if dead_ims is None:
        raise RuntimeError(
            'You need to supply a directory path to the dead images.')

    live_files = glob(os.path.join(live_ims, '*%s' % config.raw_im_ext))
    dead_files = glob(os.path.join(dead_ims, '*%s' % config.raw_im_ext))
    # Live images get label 0, dead images label 1.
    combined_labels = np.concatenate(
        (np.zeros(len(live_files)), np.ones(len(dead_files))))
    combined_files = np.concatenate((live_files, dead_files))
    if len(combined_files) == 0:
        raise RuntimeError('Could not find any files. Check your image path.')

    # The normalization meta file is expected next to the checkpoint.
    model_file_path = os.path.sep.join(model_file.split(os.path.sep)[:-1])
    meta_file_pointer = os.path.join(model_file_path,
                                     'train_maximum_value.npz')
    if not os.path.exists(meta_file_pointer):
        raise RuntimeError(
            'Cannot find the training data meta file: '
            'train_maximum_value.npz. '
            'Closest I could find from directory %s was %s. '
            'Download this from the link described in the README.md.' %
            (model_file_path, glob(os.path.join(model_file_path, '*.npz'))))
    meta_data = np.load(meta_file_pointer)

    # Prepare image normalization values
    if training_max is None:
        training_max = np.max(meta_data['max_array']).astype(np.float32)
    training_min = np.min(meta_data['min_array']).astype(np.float32)

    # Results go to a directory named after the checkpoint's directory
    ds_dt_stamp = re.split('/', model_file)[-2]
    out_dir = os.path.join(config.results, ds_dt_stamp)

    # Make output directories if they do not exist
    for d in (config.results, out_dir):
        make_dir(d)

    # Prepare data on CPU
    images = tf.placeholder(tf.float32,
                            shape=[None] + config.model_image_size,
                            name='images')

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn'):
            vgg = vgg16.model_struct(vgg16_npy_path=config.vgg16_weight_path,
                                     fine_tune_layers=config.fine_tune_layers)
            vgg.build(images, output_shape=config.output_shape)

        # fc7 activations are the SVM features; argmax(prob) is the network's
        # own class prediction.
        scores = vgg.fc7
        preds = tf.argmax(vgg.prob, 1)

    # Set up saver
    saver = tf.train.Saver(tf.global_variables())

    # Loop through each checkpoint then test the entire validation set
    ckpts = [model_file]
    ckpt_yhat, ckpt_y, ckpt_scores, ckpt_file_array = [], [], [], []
    print('-' * 60)
    print('Beginning evaluation')
    print('-' * 60)

    if config.validation_batch > len(combined_files):
        print('Trimming validation_batch size to %s (same as # of files).' %
              len(combined_files))
        config.validation_batch = len(combined_files)

    for idx, c in tqdm(enumerate(ckpts), desc='Running checkpoints'):
        dec_scores, yhat, y, file_array = [], [], [], []
        # The context manager closes the session of every checkpoint, not
        # only the last one.
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer()))
            saver.restore(sess, c)
            start_time = time.time()
            # NOTE(review): floor() presumably means a trailing partial batch
            # is never requested from image_batcher -- confirm.
            num_batches = np.floor(
                len(combined_files) /
                float(config.validation_batch)).astype(int)
            batches = image_batcher(start=0,
                                    num_batches=num_batches,
                                    images=combined_files,
                                    labels=combined_labels,
                                    config=config,
                                    training_max=training_max,
                                    training_min=training_min)
            for image_batch, label_batch, file_batch in tqdm(
                    batches, total=num_batches):
                sc, tyh = sess.run([scores, preds],
                                   feed_dict={images: image_batch})
                dec_scores.append(sc)
                yhat = np.append(yhat, tyh)
                y = np.append(y, label_batch)
                file_array = np.append(file_array, file_batch)
        ckpt_yhat.append(yhat)
        ckpt_y.append(y)
        ckpt_scores.append(dec_scores)
        ckpt_file_array.append(file_array)
        print('Batch %d took %.1f seconds' % (idx, time.time() - start_time))

    # Save everything
    np.savez(os.path.join(out_dir, 'validation_accuracies'),
             ckpt_yhat=ckpt_yhat,
             ckpt_y=ckpt_y,
             ckpt_scores=ckpt_scores,
             ckpt_names=ckpts,
             combined_files=ckpt_file_array)

    # Run SVM on the encodings of the (last) checkpoint
    features = np.concatenate(dec_scores)
    svm = LinearSVC(C=C, dual=False, class_weight='balanced')
    clf = make_pipeline(preprocessing.StandardScaler(), svm)
    predictions = cross_val_predict(clf, features, y, cv=k_folds)
    cv_performance = metrics.accuracy_score(y, predictions)
    p_value = randomization_test(y=y, yhat=predictions)
    clf.fit(features, y)
    print('%s-fold SVM performance: accuracy = %s%% , p = %.5f' %
          (k_folds, np.mean(cv_performance * 100), p_value))
    np.savez(os.path.join(out_dir, 'svm_data'),
             yhat=predictions,
             y=y,
             scores=dec_scores,
             ckpts=ckpts,
             cv_performance=cv_performance,
             p_value=p_value,
             k_folds=k_folds,
             C=C)

    # Also save a csv with item/guess pairs. file_array holds the files the
    # batcher returned alongside the predictions, so rows stay aligned with
    # yhat / y even when some files were not scored.
    try:
        trimmed_files = np.asarray(
            [re.split('/', x)[-1] for x in file_array])
        df = pd.DataFrame(np.hstack(
            (trimmed_files.reshape(-1, 1), np.asarray(yhat).reshape(-1, 1),
             np.asarray(y).reshape(-1, 1))),
                          columns=['files', 'guesses', 'true label'])
        df.to_csv(os.path.join(out_dir, '%s.csv' % output_csv))
        print('Saved csv to: %s' % out_dir)
    except Exception as e:
        print('X' * 60)
        print('Could not save a spreadsheet of file info')
        print('Error: %s' % e)
        print('X' * 60)

    # save the classifier
    with open('%s.pkl' % svm_model, 'wb') as fid:
        cPickle.dump(clf, fid)
    print('Saved svm model to: %s.pkl' % svm_model)
示例#10
0
def test_vgg16(image_dir,
               model_file,
               output_csv='prediction_file',
               training_max=None):
    """Run a trained GEDI VGG16 checkpoint over a directory of images.

    Predictions and class scores are saved to ``validation_accuracies.npz``
    and to a CSV named after the last component of ``image_dir`` (or
    ``output.csv`` when that component is empty) in the results directory.

    Args:
        image_dir: Directory with the images to classify.
        model_file: Checkpoint to restore. Its directory must also contain
            ``train_maximum_value.npz``; that directory's name is used as the
            results sub-directory under ``config.results``.
        output_csv: Currently unused; kept for interface compatibility.
        training_max: Optional normalization maximum; defaults to the maximum
            stored in ``train_maximum_value.npz``.

    Raises:
        RuntimeError: If ``image_dir`` is None, no images are found or the
            training meta file does not exist.
    """
    print(image_dir)
    config = GEDIconfig()
    if image_dir is None:
        raise RuntimeError(
            'You need to supply a directory path to the images.')

    combined_files = np.asarray(
        glob(os.path.join(image_dir, '*%s' % config.raw_im_ext)))
    if len(combined_files) == 0:
        raise RuntimeError('Could not find any files. Check your image path.')

    # The normalization meta file is expected next to the checkpoint.
    model_file_path = os.path.sep.join(model_file.split(os.path.sep)[:-1])
    print('model file path', model_file_path)
    meta_file_pointer = os.path.join(model_file_path,
                                     'train_maximum_value.npz')
    if not os.path.exists(meta_file_pointer):
        raise RuntimeError(
            'Cannot find the training data meta file: '
            'train_maximum_value.npz. '
            'Closest I could find from directory %s was %s. '
            'Download this from the link described in the README.md.' %
            (model_file_path, glob(os.path.join(model_file_path, '*.npz'))))
    meta_data = np.load(meta_file_pointer)

    # Prepare image normalization values
    if training_max is None:
        training_max = np.max(meta_data['max_array']).astype(np.float32)
    training_min = np.min(meta_data['min_array']).astype(np.float32)

    # Results go to a directory named after the checkpoint's directory
    ds_dt_stamp = re.split('/', model_file)[-2]
    out_dir = os.path.join(config.results, ds_dt_stamp)
    print('out_dir', out_dir)

    # Make output directories if they do not exist
    for d in (config.results, out_dir):
        make_dir(d)

    # Prepare data on CPU
    if config.model_image_size[-1] < 3:
        print('*' * 60)
        print('Warning: model is expecting a H/W/1 image. '
              'Do you mean to set the last dimension of '
              'config.model_image_size to 3?')
        print('*' * 60)

    images = tf.placeholder(tf.float32,
                            shape=[None] + config.model_image_size,
                            name='images')

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn'):
            vgg = vgg16.model_struct(vgg16_npy_path=config.vgg16_weight_path,
                                     fine_tune_layers=config.fine_tune_layers)
            vgg.build(images, output_shape=config.output_shape)

        # Setup validation op: class probabilities and the argmax prediction
        scores = vgg.prob
        preds = tf.argmax(vgg.prob, 1)

    # Set up saver
    saver = tf.train.Saver(tf.global_variables())

    # Loop through each checkpoint then test the entire validation set
    ckpts = [model_file]
    ckpt_yhat, ckpt_y, ckpt_scores, ckpt_file_array = [], [], [], []
    print('-' * 60)
    print('Beginning evaluation')
    print('-' * 60)

    if config.validation_batch > len(combined_files):
        print('Trimming validation_batch size to %s (same as # of files).' %
              len(combined_files))
        config.validation_batch = len(combined_files)

    for idx, c in tqdm(enumerate(ckpts), desc='Running checkpoints'):
        dec_scores, yhat, file_array = [], [], []
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer()))
            saver.restore(sess, c)
            start_time = time.time()
            num_batches = np.floor(
                len(combined_files) /
                float(config.validation_batch)).astype(int)
            batches = image_batcher(start=0,
                                    num_batches=num_batches,
                                    images=combined_files,
                                    config=config,
                                    training_max=training_max,
                                    training_min=training_min)
            for image_batch, file_batch in tqdm(batches, total=num_batches):
                sc, tyh = sess.run([scores, preds],
                                   feed_dict={images: image_batch})
                # np.append flattens the per-image score rows into 1-D.
                dec_scores = np.append(dec_scores, sc)
                yhat = np.append(yhat, tyh)
                file_array = np.append(file_array, file_batch)
            ckpt_yhat.append(yhat)
            ckpt_scores.append(dec_scores)
            ckpt_file_array.append(file_array)
            print('Batch %d took %.1f seconds' %
                  (idx, time.time() - start_time))

    # Save everything
    print('Save npz.')
    print(os.path.join(out_dir, 'validation_accuracies'))
    np.savez(os.path.join(out_dir, 'validation_accuracies'),
             ckpt_yhat=ckpt_yhat,
             ckpt_scores=ckpt_scores,
             ckpt_names=ckpts,
             combined_files=ckpt_file_array)

    # Also save a csv with item/guess pairs for the (last) checkpoint
    try:
        dec_scores = np.asarray(dec_scores)
        yhat = np.asarray(yhat)
        # NOTE(review): assumes exactly two class scores per image, in the
        # (dead, live) order implied by the column names -- confirm.
        df = pd.DataFrame(np.hstack(
            (np.asarray(file_array).reshape(-1, 1), yhat.reshape(-1, 1),
             dec_scores.reshape(dec_scores.shape[0] // 2, 2))),
                          columns=[
                              'files', 'live_guesses', 'classifier score dead',
                              'classifier score live'
                          ])
        output_name = image_dir.split('/')[-1]
        if output_name is None or len(output_name) == 0:
            output_name = 'output'
        csv_path = os.path.join(out_dir, '%s.csv' % output_name)
        df.to_csv(csv_path)
        print('Saved csv to: %s' % csv_path)
    except Exception as e:
        print('X' * 60)
        print('Could not save a spreadsheet of file info')
        print('Error: %s' % e)
        print('X' * 60)

    # Plot everything.
    # NOTE(review): ckpt_y is never filled in this function (no ground-truth
    # labels are available), so these plots are unlikely to be meaningful;
    # any failure is reported below.
    try:
        plot_accuracies(ckpt_y, ckpt_yhat, config, ckpts,
                        os.path.join(out_dir, 'validation_accuracies.png'))
        plot_std(ckpt_y, ckpt_yhat, ckpts,
                 os.path.join(out_dir, 'validation_stds.png'))
        plot_cms(ckpt_y, ckpt_yhat, config,
                 os.path.join(out_dir, 'confusion_matrix.png'))
        plot_pr(ckpt_y, ckpt_yhat, ckpt_scores,
                os.path.join(out_dir, 'precision_recall.png'))
    except Exception as e:
        print('X' * 60)
        print('Could not plot the evaluation results')
        print('Error: %s' % e)
        print('X' * 60)
def extract_tf_records_from_GEDI_tiffs():
    """Extract data directly from GEDI tiffs and insert them into tf records.

    This allows us to offload the normalization procedure to either right
    before training (via sampling) or during training (via normalization with
    a batch's max). One tfrecords file is written per split listed in
    ``config.tvt_flags``.

    Raises:
        RuntimeError: If no image files are found in any split.
    """
    # Grab the global config
    config = GEDIconfig()

    # If requested load in the ratio file
    ratio_file = os.path.join(
        config.home_dir, config.original_image_dir, config.ratio_stem,
        '%s%s.csv' % (config.ratio_prefix, config.experiment_image_set))
    if config.ratio_prefix is not None and os.path.exists(ratio_file):
        ratio_list = read_csv(ratio_file)
    else:
        ratio_list = None

    # Allow option for per-timestep image extraction
    if config.timestep_delta_frames:
        extraction = extract_to_tf_records
    else:
        extraction = vtf

    # Make dirs if they do not exist
    dir_list = [
        config.train_directory, config.validation_directory,
        config.tfrecord_dir, config.train_checkpoint
    ]
    if 'test' in config.tvt_flags:
        dir_list += [config.test_directory]
        # Raw images are then read from '<dir>_train' and '<dir>_test'.
        config.raw_im_dirs = [x + '_train' for x in config.raw_im_dirs]
        config.raw_im_dirs += [
            x.split('_train')[0] + '_test' for x in config.raw_im_dirs
        ]
    for d in dir_list:
        make_dir(d)
    im_lists = get_image_dict(config)

    # Sample from training for validation images
    if 'val' in config.tvt_flags:
        im_lists['val'], im_lists['train'] = sample_files(
            im_lists['train'], config.train_proportion, config.tvt_flags)
    if config.encode_time_of_death is not None:
        death_timepoints = pd.read_csv(
            config.encode_time_of_death)[['plate_well_neuron', 'dead_tp']]
        keep_experiments = pd.read_csv(config.time_of_death_experiments)
        im_labels = {}
        for k, v in im_lists.items():
            # Value comparison: identity (`is`) against a string literal is
            # not a reliable equality test.
            if k != 'test':
                proc_ims, proc_labels = find_timepoint(
                    images=v,
                    data=death_timepoints,
                    keep_experiments=keep_experiments,
                    remove_thresh=config.mask_timepoint_value)
                im_labels[k] = proc_labels
                im_lists[k] = proc_ims
                df = pd.DataFrame(np.vstack(
                    (proc_ims, proc_labels)).transpose(),
                                  columns=['image', 'timepoint'])
                df.to_csv('%s.csv' % k)
            else:
                im_labels[k] = find_label(v)
    else:
        im_labels = {k: find_label(v) for k, v in im_lists.items()}

    if isinstance(config.tvt_flags, str):
        tvt_flags = [config.tvt_flags]
    else:
        tvt_flags = config.tvt_flags
    # Count files across all splits; this also works when im_lists is empty.
    if not sum(len(v) for v in im_lists.values()):
        raise RuntimeError('Could not find any files.')
    for flag in tvt_flags:
        write_labels(flag=flag, im_lists=im_lists, config=config)

    # NOTE(review): the include_GEDI_in_tfrecords suffix computed here is
    # always overwritten by the extra_image branch below, so it never reaches
    # the output file names. Kept to avoid renaming existing records; confirm
    # the intended precedence.
    if config.include_GEDI_in_tfrecords > 0:
        tf_flag = '_%sgedi' % config.include_GEDI_in_tfrecords
    else:
        tf_flag = ''
    if config.extra_image:
        tf_flag = '_1image'
    else:
        tf_flag = ''

    if isinstance(config.tvt_flags, str):
        # Single split: the record is named '<flag><tf_flag>.tfrecords'.
        files = im_lists[config.tvt_flags]
        labels = im_labels[config.tvt_flags]
        output_pointer = os.path.join(
            config.tfrecord_dir,
            '%s%s.tfrecords' % (config.tvt_flags, tf_flag))
        extraction(files=files,
                   label_list=labels,
                   output_pointer=output_pointer,
                   ratio_list=ratio_list,
                   config=config,
                   k=config.tvt_flags)
    else:
        # Several splits: records are named from config.tf_record_names.
        for k in config.tvt_flags:
            files = im_lists[k]
            labels = im_labels[k]
            output_pointer = os.path.join(
                config.tfrecord_dir,
                '%s%s.tfrecords' % (tf_flag, config.tf_record_names[k]))
            extraction(files=files,
                       label_list=labels,
                       output_pointer=output_pointer,
                       ratio_list=ratio_list,
                       config=config,
                       k=k)
示例#12
0
def train_model(train_dir=None, validation_dir=None):
    """Train the two-frame GEDI matching model.

    Each frame is encoded by a GEDI VGG16 (``fc7``, built with
    ``trainable=False`` for training) and by a ``matching_gedi`` model. The
    two frame vectors are concatenated or subtracted according to
    ``config.matching_combine`` and fed to a softmax output layer.
    Checkpoints and summaries are written every ``config.validation_steps``
    steps.

    Args:
        train_dir: Directory holding the training tfrecords
            (``config.tf_record_names['train']``) and the
            ``<tvt_flags[0]>_<max_file>`` meta file. ``None`` uses
            ``config.tfrecord_dir``.
        validation_dir: Directory holding the validation tfrecords
            (``config.tf_record_names['val']``). ``None`` uses
            ``config.tfrecord_dir``; ``False`` disables validation.
    """
    config = GEDIconfig()
    # Training records and their meta file live in the same directory.
    data_dir = config.tfrecord_dir if train_dir is None else train_dir
    train_data = os.path.join(data_dir, config.tf_record_names['train'])
    meta_data = np.load(
        os.path.join(data_dir,
                     '%s_%s' % (config.tvt_flags[0], config.max_file)))

    # Prepare image normalization values
    if config.max_gedi is None:
        max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
        if max_value == 0:
            max_value = None
            print('Derived max value is 0')
        else:
            print('Normalizing with empirical max.')
        if 'min_array' in meta_data.keys():
            min_value = np.min(meta_data['min_array']).astype(np.float32)
            print('Normalizing with empirical min.')
        else:
            min_value = None
            print('Not normalizing with a min.')
    else:
        max_value = config.max_gedi
        min_value = config.min_gedi
    ratio = meta_data['ratio']
    print('Ratio is: %s' % ratio)

    if validation_dir is False:
        validation_data = False  # Do not use validation data during training
    elif validation_dir is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir,
                                       config.tf_record_names['val'])
    else:
        validation_data = os.path.join(validation_dir,
                                       config.tf_record_names['val'])
    use_validation = validation_data is not False

    # Make output directories if they do not exist
    dt_stamp = re.split(
        r'\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(config.train_checkpoint,
                                           dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    dir_list = [
        config.train_checkpoint, config.train_summaries, config.results,
        out_dir
    ]
    for d in dir_list:
        make_dir(d)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Prepare data on CPU
    assert os.path.exists(train_data), 'Missing %s' % train_data
    if use_validation:
        assert os.path.exists(validation_data), (
            'Missing %s' % validation_data)
    assert os.path.exists(config.vgg16_weight_path)
    with tf.device('/cpu:0'):
        train_images_0, train_images_1, train_labels, train_times = inputs(
            train_data,
            config.train_batch,
            im_shape,
            config.model_image_size[:2],
            max_value=max_value,
            min_value=min_value,
            train=config.data_augmentations,
            num_epochs=config.epochs,
            normalize=config.normalize,
            return_filename=True)
        if use_validation:
            val_images_0, val_images_1, val_labels, val_times = inputs(
                validation_data,
                config.validation_batch,
                im_shape,
                config.model_image_size[:2],
                max_value=max_value,
                min_value=min_value,
                num_epochs=config.epochs,
                normalize=config.normalize,
                return_filename=True)
        tf.summary.image('train image frame 0', train_images_0)
        tf.summary.image('train image frame 1', train_images_1)
        if use_validation:
            tf.summary.image('validation image frame 0', val_images_0)
            tf.summary.image('validation image frame 1', val_images_1)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('gedi'):
            # Build training GEDI model for frame 0
            vgg_train_mode = tf.get_variable(name='vgg_training',
                                             initializer=False)
            gedi_model_0 = vgg16.model_struct(
                vgg16_npy_path=config.gedi_weight_path, trainable=False)
            gedi_model_0.build(prep_images_for_gedi(train_images_0),
                               output_shape=2,
                               train_mode=vgg_train_mode)
            gedi_scores_0 = gedi_model_0.fc7

        with tf.variable_scope('match'):
            # Build matching model for frame 0
            model_0 = matching_gedi.model_struct()
            model_0.build(train_images_0)

            # Build frame 0 vector
            frame_0 = tf.concat([gedi_scores_0, model_0.output], axis=-1)

            # Validate the combine mode early (output_shape is recomputed
            # from the combined scores below).
            if config.matching_combine == 'concatenate':
                output_shape = [int(frame_0.get_shape()[-1]) * 2, 2]
            elif config.matching_combine == 'subtract':
                output_shape = [int(frame_0.get_shape()[-1]), 2]
            else:
                raise RuntimeError

        # Build GEDI model for frame 1
        with tf.variable_scope('gedi', reuse=True):
            gedi_model_1 = vgg16.model_struct(
                vgg16_npy_path=config.gedi_weight_path, trainable=False)
            gedi_model_1.build(prep_images_for_gedi(train_images_1),
                               output_shape=2,
                               train_mode=vgg_train_mode)
            gedi_scores_1 = gedi_model_1.fc7

        with tf.variable_scope('match', reuse=True):
            # Build matching model for frame 1
            model_1 = matching_gedi.model_struct()
            model_1.build(train_images_1)

        # Build frame 1 vector
        frame_1 = tf.concat([gedi_scores_1, model_1.output], axis=-1)

        with tf.variable_scope('output'):
            # Concatenate or subtract
            if config.matching_combine == 'concatenate':
                output_scores = tf.concat([frame_0, frame_1], axis=-1)
            elif config.matching_combine == 'subtract':
                output_scores = frame_0 - frame_1
            else:
                raise NotImplementedError

            # Build output layer
            output_shape = [int(output_scores.get_shape()[-1]), 2]
            output_weights = tf.get_variable(
                name='output_weights',
                shape=output_shape,
                initializer=tf.contrib.layers.xavier_initializer(
                    uniform=False))
            output_bias = tf.get_variable(name='output_bias',
                                          initializer=tf.truncated_normal(
                                              [output_shape[-1]], .0, .001))
            decision_logits = tf.nn.bias_add(
                tf.matmul(output_scores, output_weights), output_bias)
            train_soft_decisions = tf.nn.softmax(decision_logits)
            cost = softmax_cost(decision_logits, train_labels)
            tf.summary.scalar("cce loss", cost)
            cost += tf.nn.l2_loss(output_weights)

            # Weight decay
            if config.wd_layers is not None:
                _, l2_wd_layers = fine_tune_prepare_layers(
                    tf.trainable_variables(), config.wd_layers)
                l2_wd_layers = [
                    x for x in l2_wd_layers if 'biases' not in x.name
                ]
                if len(l2_wd_layers) > 0:
                    cost += (config.wd_penalty *
                             tf.add_n([tf.nn.l2_loss(x)
                                       for x in l2_wd_layers]))

        # Optimize
        train_op = tf.train.AdamOptimizer(config.new_lr).minimize(cost)
        train_accuracy = class_accuracy(train_soft_decisions,
                                        train_labels)  # training accuracy
        tf.summary.scalar("training accuracy", train_accuracy)

        # Setup validation op
        if use_validation:
            with tf.variable_scope('gedi', reuse=tf.AUTO_REUSE):  # FIX THIS
                # Validation graph is the same as training except no batchnorm
                val_gedi_model_0 = vgg16.model_struct(
                    vgg16_npy_path=config.gedi_weight_path)
                val_gedi_model_0.build(prep_images_for_gedi(val_images_0),
                                       output_shape=2,
                                       train_mode=vgg_train_mode)
                val_gedi_scores_0 = val_gedi_model_0.fc7

                # Build GEDI model for frame 1
                val_gedi_model_1 = vgg16.model_struct(
                    vgg16_npy_path=config.gedi_weight_path)
                val_gedi_model_1.build(prep_images_for_gedi(val_images_1),
                                       output_shape=2,
                                       train_mode=vgg_train_mode)
                val_gedi_scores_1 = val_gedi_model_1.fc7

            with tf.variable_scope('match', reuse=tf.AUTO_REUSE):
                # Build matching model for frame 0
                val_model_0 = matching_gedi.model_struct()
                val_model_0.build(val_images_0)

                # Build matching model for frame 1
                val_model_1 = matching_gedi.model_struct()
                val_model_1.build(val_images_1)

            # Build frame 0 and frame 1 vectors
            val_frame_0 = tf.concat([val_gedi_scores_0, val_model_0.output],
                                    axis=-1)
            val_frame_1 = tf.concat([val_gedi_scores_1, val_model_1.output],
                                    axis=-1)

            # Concatenate or subtract
            if config.matching_combine == 'concatenate':
                val_output_scores = tf.concat([val_frame_0, val_frame_1],
                                              axis=-1)
            elif config.matching_combine == 'subtract':
                val_output_scores = val_frame_0 - val_frame_1
            else:
                raise NotImplementedError

            with tf.variable_scope('output', reuse=tf.AUTO_REUSE):
                # Reuse the *trained* output layer. The weights must be looked
                # up under the same name as the bias ('output_weights');
                # a differently named variable would be a fresh, never-trained
                # random matrix.
                val_output_weights = tf.get_variable(
                    name='output_weights',
                    shape=output_shape,
                    trainable=False,
                    initializer=tf.contrib.layers.xavier_initializer(
                        uniform=False))
                val_output_bias = tf.get_variable(
                    name='output_bias',
                    trainable=False,
                    initializer=tf.truncated_normal([output_shape[-1]], .0,
                                                    .001))
                val_decision_logits = tf.nn.bias_add(
                    tf.matmul(val_output_scores, val_output_weights),
                    val_output_bias)
                val_soft_decisions = tf.nn.softmax(val_decision_logits)

            # Calculate validation accuracy
            val_accuracy = class_accuracy(val_soft_decisions, val_labels)
            tf.summary.scalar("validation accuracy", val_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=config.keep_checkpoints)
    summary_op = tf.summary.merge_all()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Need to initialize both of these if supplying num_epochs to inputs
    sess.run(
        tf.group(tf.global_variables_initializer(),
                 tf.local_variables_initializer()))
    summary_dir = os.path.join(config.train_summaries,
                               config.which_dataset + '_' + dt_stamp)
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Start training loop
    np.save(os.path.join(out_dir, 'meta_info'), config)
    step, losses = 0, []
    val_acc = None  # Latest validation accuracy (None without validation)
    try:
        while not coord.should_stop():
            start_time = time.time()
            # Only training ops here: validation batches are pulled solely at
            # validation steps, so the validation queue is not drained on
            # every iteration.
            _, loss_value, train_acc = sess.run(
                [train_op, cost, train_accuracy])
            losses.append(loss_value)
            duration = time.time() - start_time
            if np.isnan(loss_value):
                raise RuntimeError('Model loss = NaN')

            if step % config.validation_steps == 0:
                if use_validation:
                    val_acc = sess.run(val_accuracy)

                # Summaries
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Training status and validation accuracy
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Validation accuracy = %s | '
                              'logdir = %s')
                print(format_str %
                      (datetime.now(), step,
                       loss_value, config.train_batch / duration,
                       float(duration), train_acc, val_acc, summary_dir))

                # Store every checkpoint (no best-validation filtering)
                saver.save(sess,
                           os.path.join(config.train_checkpoint,
                                        'model_' + str(step) + '.ckpt'),
                           global_step=step)
            else:
                # Training status
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Training loss = %s')
                print(format_str %
                      (datetime.now(), step, loss_value,
                       config.train_batch / duration, float(duration),
                       train_acc, loss_value))
            # End iteration
            step += 1

    except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
    finally:
        coord.request_stop()
        np.save(os.path.join(config.tfrecord_dir, 'training_loss'), losses)
    coord.join(threads)
    sess.close()
示例#13
0
def test_vgg16(validation_data, model_dir, which_set, selected_ckpts=-1):
    """Evaluate saved VGG16 checkpoints on a TFRecord dataset.

    Every selected checkpoint returned by ``find_ckpts(config, model_dir)`` is
    restored and run once over ``validation_data`` (``num_epochs=1``, batch
    size 1). Predictions, labels and scores for every checkpoint are saved to
    ``<config.results>/<checkpoint dir name>/validation_accuracies.npz``; a
    per-file CSV and diagnostic plots are then written on a best-effort basis.

    Args:
        validation_data: Path to a ``.tfrecords`` file whose normalization
            meta data lives in ``<stem>_maximum_value.npz``, or None to use
            ``config.tf_record_names[which_set]`` in ``config.tfrecord_dir``.
        model_dir: Location passed to ``find_ckpts``.
        which_set: Key into ``config.tf_record_names`` (only used when
            ``validation_data`` is None).
        selected_ckpts: None keeps every checkpoint; a negative value keeps
            the last ``-selected_ckpts`` checkpoints; otherwise the first
            ``selected_ckpts`` checkpoints are kept.
    """
    config = GEDIconfig()
    if validation_data is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir,
                                       config.tf_record_names[which_set])
        meta_data = np.load(
            os.path.join(config.tfrecord_dir, 'val_%s' % config.max_file))
    else:
        meta_data = np.load('%s_maximum_value.npz' %
                            validation_data.split('.tfrecords')[0])
    label_list = os.path.join(
        config.processed_image_patch_dir,
        'list_of_' + '_'.join(x
                              for x in config.image_prefixes) + '_labels.txt')
    with open(label_list) as f:
        file_pointers = [l.rstrip('\n') for l in f.readlines()]

    # Prepare image normalization values, falling back to the configured
    # constants when the meta file does not provide the arrays.
    try:
        max_value = np.max(meta_data['max_array']).astype(np.float32)
    except Exception:
        max_value = np.asarray([config.max_gedi])
    try:
        # The floor is the smallest entry of min_array (np.min, not np.max).
        min_value = np.min(meta_data['min_array']).astype(np.float32)
    except Exception:
        min_value = np.asarray([config.min_gedi])

    # Find model checkpoints. Results go to a directory named after the
    # parent directory of the first checkpoint.
    ckpts, ckpt_names = find_ckpts(config, model_dir)
    ds_dt_stamp = ckpts[0].split('/')[-2]
    out_dir = os.path.join(config.results, ds_dt_stamp)
    try:
        # meta_info.npy stores a pickled config object, so numpy must be
        # allowed to unpickle it.
        config = np.load(os.path.join(out_dir, 'meta_info.npy'),
                         allow_pickle=True).item()
        # Make sure this is always at 1
        config.validation_batch = 1
        print('-' * 60)
        print('Loading config meta data for:%s' % out_dir)
        print('-' * 60)
    except Exception:
        print('-' * 60)
        print('Using config from gedi_config.py for model:%s' % out_dir)
        print('-' * 60)

    # Make output directories if they do not exist
    for d in (config.results, out_dir):
        make_dir(d)
    im_shape = config.gedi_image_size

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        val_images, val_labels = inputs(validation_data,
                                        1,
                                        im_shape,
                                        config.model_image_size[:2],
                                        max_value=max_value,
                                        min_value=min_value,
                                        num_epochs=1,
                                        normalize=config.normalize)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn'):
            vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                              fine_tune_layers=config.fine_tune_layers)
            vgg.build(val_images, output_shape=config.output_shape)

        # Setup validation op
        scores = vgg.prob
        preds = tf.argmax(vgg.prob, 1)
        targets = tf.cast(val_labels, dtype=tf.int64)

    # Set up saver
    saver = tf.train.Saver(tf.global_variables())

    # Loop through each checkpoint then test the entire validation set
    ckpt_yhat, ckpt_y, ckpt_scores = [], [], []
    print('-' * 60)
    print('Beginning evaluation')
    print('-' * 60)

    if selected_ckpts is not None:
        # Select a specific ckpt
        if selected_ckpts < 0:
            ckpts = ckpts[selected_ckpts:]
        else:
            ckpts = ckpts[:selected_ckpts]

    for idx, c in tqdm(enumerate(ckpts), desc='Running checkpoints'):
        dec_scores, yhat, y = [], [], []
        # Initialize the graph and the queue runners before entering the try
        # so the finally clause always refers to this iteration's coordinator.
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer()))
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            saver.restore(sess, c)
            start_time = time.time()
            # The input pipeline runs for one epoch; OutOfRangeError marks
            # the end of the data.
            while not coord.should_stop():
                sc, tyh, ty = sess.run([scores, preds, targets])
                dec_scores = np.append(dec_scores, sc)
                yhat = np.append(yhat, tyh)
                y = np.append(y, ty)
        except tf.errors.OutOfRangeError:
            ckpt_yhat.append(yhat)
            ckpt_y.append(y)
            ckpt_scores.append(dec_scores)
            print('Iteration accuracy: %s' % np.mean(yhat == y))
            print('Iteration pvalue: %.5f' %
                  randomization_test(y=y, yhat=yhat))
            print('Batch %d took %.1f seconds' %
                  (idx, time.time() - start_time))
        finally:
            coord.request_stop()
        coord.join(threads)
        sess.close()

    # Save everything
    np.savez(os.path.join(out_dir, 'validation_accuracies'),
             ckpt_yhat=ckpt_yhat,
             ckpt_y=ckpt_y,
             ckpt_scores=ckpt_scores,
             ckpt_names=ckpt_names,
             file_pointers=file_pointers)

    # Also save a csv with item/guess pairs. This uses the predictions of the
    # last evaluated checkpoint only.
    try:
        trimmed_files = np.asarray([x.split('/')[-1] for x in file_pointers])
        dec_scores = np.asarray(dec_scores)
        yhat = np.asarray(yhat)
        df = pd.DataFrame(
            np.hstack((trimmed_files.reshape(-1, 1), yhat.reshape(-1, 1),
                       # assumes two class scores per example (dead, live)
                       dec_scores.reshape(dec_scores.shape[0] // 2, 2))),
            columns=['files', 'guesses', 'score dead', 'score live'])
        df.to_csv(os.path.join(out_dir, 'prediction_file.csv'))
        print('Saved csv to: %s' % out_dir)
    except Exception as err:
        print('X' * 60)
        print('Could not save a spreadsheet of file info')
        print('Error: %s' % err)
        print('X' * 60)

    # Plot everything (best effort)
    try:
        plot_accuracies(ckpt_y, ckpt_yhat, config, ckpt_names,
                        os.path.join(out_dir, 'validation_accuracies.png'))
        plot_std(ckpt_y, ckpt_yhat, ckpt_names,
                 os.path.join(out_dir, 'validation_stds.png'))
        plot_cms(ckpt_y, ckpt_yhat, config,
                 os.path.join(out_dir, 'confusion_matrix.png'))
        plot_pr(ckpt_y, ckpt_yhat, ckpt_scores,
                os.path.join(out_dir, 'precision_recall.png'))
        plot_cost(os.path.join(out_dir, 'training_loss.npy'), ckpt_names,
                  os.path.join(out_dir, 'training_costs.png'))
    except Exception as err:
        print('X' * 60)
        print('Could not locate the loss numpy')
        print('Error: %s' % err)
        print('X' * 60)
示例#14
0
def _match_embed_frame(model, frame, config):
    """Build ``model`` on one frame and optionally L2-normalize the output.

    ``config.output_shape`` and ``config.include_GEDI`` are forwarded to
    ``model.build``; ``config.l2_norm`` / ``config.norm_axis`` control the
    optional normalization.
    """
    activity = model.build(frame,
                           output_shape=config.output_shape,
                           include_GEDI=config.include_GEDI)
    if config.l2_norm:
        activity = tf_fun.l2_normalize(activity, axis=config.norm_axis)
    return activity


def _match_pos_neg_dists(frame_activity, dist_fun):
    """Return (anchor-positive, anchor-negative) distances along axis 1.

    Frame 0 is the anchor, frame 1 the positive and frame 2 the negative.

    Raises:
        NotImplementedError: if ``dist_fun`` is not 'l2' or 'pearson'.
    """
    if dist_fun == 'l2':
        dist = tf_fun.l2_dist
    elif dist_fun == 'pearson':
        dist = tf_fun.pearson_dist
    else:
        raise NotImplementedError(dist_fun)
    pos = dist(frame_activity[0], frame_activity[1], axis=1)
    neg = dist(frame_activity[0], frame_activity[2], axis=1)
    return pos, neg


def _match_triplet_loss(pos, neg, margin, per_batch):
    """Hinge triplet loss on ``pos - neg + margin``.

    With ``per_batch`` the hinge is applied to the batch mean; otherwise it
    is applied per example and then averaged.
    """
    if per_batch:
        return tf.maximum(tf.reduce_mean(pos - neg) + margin, 0.)
    return tf.reduce_mean(tf.maximum(pos - neg + margin, 0.))


def _match_triplet_accuracy(pos, neg, labels):
    """Fraction of examples whose positive distance is below the negative."""
    return tf.reduce_mean(
        tf.cast(
            tf.equal(
                tf.nn.relu(tf.sign(neg - pos)),  # 1 if pos < neg
                tf.cast(tf.ones_like(labels), tf.float32)),
            tf.float32))


def train_model(train_dir=None,
                validation_dir=None,
                debug=True,
                resume_ckpt=None,
                resume_meta=None,
                margin=.4):
    """Train the 'match' triplet model on TFRecord frame sequences.

    Frame 0 of each example is the anchor, frame 1 the positive and frame 2
    the negative. The loss pushes anchor-positive distances at least
    ``margin`` below anchor-negative distances. Checkpoints go to a
    timestamped subdirectory of ``config.train_checkpoint`` and summaries to
    ``config.train_summaries``.

    Args:
        train_dir: Directory holding the training TFRecord
            (``config.tf_record_names['train']``) and its
            ``<tvt_flags[0]>_<max_file>`` meta file. None uses
            ``config.tfrecord_dir``.
        validation_dir: Directory holding the validation TFRecord
            (``config.tf_record_names['val']``). None uses
            ``config.tfrecord_dir``.
        debug: If True, the input frames are also fetched on every run.
        resume_ckpt: Optional checkpoint to restore before training.
        resume_meta: Optional ``.npy`` file holding a pickled config that
            replaces the default ``GEDIconfig()``.
        margin: Triplet-loss margin.

    Raises:
        NotImplementedError: if ``validation_dir`` is False; the graph below
            always builds validation ops.
    """
    config = GEDIconfig()
    if resume_meta is not None:
        # The meta file stores a pickled config object.
        config = np.load(resume_meta, allow_pickle=True).item()
    assert margin is not None, 'Need a margin for the loss.'
    if validation_dir is False:
        raise NotImplementedError(
            'Training without validation data is not supported.')

    # Training records and their meta data live in the same directory.
    train_root = config.tfrecord_dir if train_dir is None else train_dir
    train_data = os.path.join(train_root, config.tf_record_names['train'])
    meta_data = np.load(
        os.path.join(train_root,
                     '%s_%s' % (config.tvt_flags[0], config.max_file)))

    # Report image normalization values. NOTE: max_value is only reported;
    # inputs() below is not given max/min normalization values.
    if config.max_gedi is None:
        max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
        if max_value == 0:
            max_value = None
            print('Derived max value is 0')
        else:
            print('Normalizing with empirical max.')
        if 'min_array' in meta_data.keys():
            print('Normalizing with empirical min.')
        else:
            print('Not normalizing with a min.')
    else:
        max_value = config.max_gedi
    ratio = meta_data['ratio']
    print('Ratio is: %s' % ratio)

    val_root = config.tfrecord_dir if validation_dir is None else validation_dir
    validation_data = os.path.join(val_root, config.tf_record_names['val'])

    # Make output directories if they do not exist
    dt_stamp = str(datetime.now()).split('.')[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(config.train_checkpoint,
                                           dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    for d in (config.train_checkpoint, config.train_summaries, config.results,
              out_dir):
        tf_fun.make_dir(d)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Prepare data on CPU
    assert os.path.exists(train_data)
    assert os.path.exists(validation_data)
    assert os.path.exists(config.vgg16_weight_path)
    with tf.device('/cpu:0'):
        train_images, train_labels, train_times = inputs(
            train_data,
            config.train_batch,
            im_shape,
            config.model_image_size,
            train=config.data_augmentations,
            num_epochs=config.epochs,
            normalize=config.normalize,
            return_filename=True)
        val_images, val_labels, val_times = inputs(
            validation_data,
            config.validation_batch,
            im_shape,
            config.model_image_size,
            num_epochs=config.epochs,
            normalize=config.normalize,
            return_filename=True)
        # Split the frame axis (axis 1) into one tensor per frame.
        train_image_list, val_image_list = [], []
        for idx in range(int(train_images.get_shape()[1])):
            train_image_list += [tf.gather(train_images, idx, axis=1)]
            val_image_list += [tf.gather(val_images, idx, axis=1)]
            tf.summary.image('train_image_frame_%s' % idx,
                             train_image_list[idx])
            tf.summary.image('validation_image_frame_%s' % idx,
                             val_image_list[idx])

    # Prepare model on GPU
    config.l2_norm = False
    config.norm_axis = 0
    config.dist_fun = 'pearson'
    config.per_batch = False
    config.include_GEDI = True  # False
    config.output_shape = 32
    config.margin = margin
    with tf.device('/gpu:0'):
        with tf.variable_scope('match'):
            # Build matching model for frame 0
            model_0 = vgg16.model_struct(
                vgg16_npy_path=config.gedi_weight_path)
            frame_activity = [
                _match_embed_frame(model_0, train_image_list[0], config)
            ]

        with tf.variable_scope('match', reuse=tf.AUTO_REUSE):
            # Build matching model for the other frames in the same scope
            for frame in train_image_list[1:]:
                frame_activity.append(
                    _match_embed_frame(model_0, frame, config))

        pos, neg = _match_pos_neg_dists(frame_activity, config.dist_fun)
        loss = _match_triplet_loss(pos, neg, margin, config.per_batch)
        tf.summary.scalar('Triplet_loss', loss)

        # Weight decay
        if config.wd_layers is not None:
            _, l2_wd_layers = tf_fun.fine_tune_prepare_layers(
                tf.trainable_variables(), config.wd_layers)
            l2_wd_layers = [x for x in l2_wd_layers if 'biases' not in x.name]
            if len(l2_wd_layers) > 0:
                loss += (config.wd_penalty *
                         tf.add_n([tf.nn.l2_loss(x) for x in l2_wd_layers]))

        # Optimize
        train_op = tf.train.AdamOptimizer(config.new_lr).minimize(loss)
        train_accuracy = _match_triplet_accuracy(pos, neg, train_labels)
        tf.summary.scalar('training_accuracy', train_accuracy)

        # Setup validation ops: a second model_struct built inside the
        # 'match' scope with variable reuse switched on.
        with tf.variable_scope('match', reuse=True):
            val_model_0 = vgg16.model_struct(
                vgg16_npy_path=config.gedi_weight_path)
            val_frame_activity = [
                _match_embed_frame(val_model_0, frame, config)
                for frame in val_image_list
            ]
        val_pos, val_neg = _match_pos_neg_dists(val_frame_activity,
                                                config.dist_fun)
        val_loss = _match_triplet_loss(val_pos, val_neg, margin,
                                       config.per_batch)
        tf.summary.scalar('Validation_triplet_loss', val_loss)

        # Calculate validation accuracy
        val_accuracy = _match_triplet_accuracy(val_pos, val_neg, val_labels)
        tf.summary.scalar('val_accuracy', val_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=config.keep_checkpoints)
    summary_op = tf.summary.merge_all()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Need to initialize both of these if supplying num_epochs to inputs
    sess.run(
        tf.group(tf.global_variables_initializer(),
                 tf.local_variables_initializer()))
    summary_dir = os.path.join(config.train_summaries,
                               config.which_dataset + '_' + dt_stamp)
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Train operations. val_accuracy is deliberately not fetched here:
    # fetching it would dequeue a validation batch on every training step.
    train_dict = {
        'train_op': train_op,
        'loss': loss,
        'pos': pos,
        'neg': neg,
        'train_accuracy': train_accuracy,
    }
    val_dict = {'val_accuracy': val_accuracy}
    if debug:
        for idx, im in enumerate(train_image_list):
            train_dict['train_im_%s' % idx] = im
        for idx, im in enumerate(val_image_list):
            val_dict['val_im_%s' % idx] = im

    # Resume training if requested
    if resume_ckpt is not None:
        print('*' * 50)
        print('Resuming training from: %s' % resume_ckpt)
        print('*' * 50)
        saver.restore(sess, resume_ckpt)

    # Start training loop
    np.save(out_dir + 'meta_info', config)
    step, losses = 0, []
    # Best validation accuracies seen so far (seeded with 0).
    top_vals = np.asarray([0])
    try:
        while not coord.should_stop():
            start_time = time.time()
            # Fetching a dict returns a dict with the same keys.
            it_train_dict = sess.run(train_dict)
            losses += [it_train_dict['loss']]
            duration = time.time() - start_time
            assert not np.isnan(it_train_dict['loss']), 'Model loss = NaN'

            if step % config.validation_steps == 0:
                val_accs = []
                for _ in range(config.validation_iterations):
                    it_val_dict = sess.run(val_dict)
                    val_accs += [it_val_dict['val_accuracy']]
                val_acc = np.nanmean(val_accs)

                # Summaries
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Training status and validation accuracy
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Validation accuracy = %s | '
                              'logdir = %s')
                print(format_str %
                      (datetime.now(), step,
                       it_train_dict['loss'], config.train_batch / duration,
                       float(duration), it_train_dict['train_accuracy'],
                       val_acc, summary_dir))

                # Save the model checkpoint if it beats any of the best
                # num_keep_checkpoints validation accuracies so far. Sort
                # descending before truncating so the kept values are the top.
                top_vals = np.sort(top_vals)[::-1][:config.num_keep_checkpoints]
                if (val_acc > top_vals).any():
                    saver.save(sess,
                               os.path.join(config.train_checkpoint,
                                            'model_' + str(step) + '.ckpt'),
                               global_step=step)
                    # Store the new validation accuracy
                    top_vals = np.append(top_vals, val_acc)

            else:
                # Training status
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | ')
                print(format_str %
                      (datetime.now(), step,
                       it_train_dict['loss'], config.train_batch / duration,
                       float(duration), it_train_dict['train_accuracy']))
            # End iteration
            step += 1

    except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
    finally:
        coord.request_stop()
        np.save(os.path.join(config.tfrecord_dir, 'training_loss'), losses)
    coord.join(threads)
    sess.close()
示例#15
0
def test_vgg16(model_file,
               trained_svm,
               ims,
               dead_ims=None,
               output_csv='prediction_file',
               training_max=None,
               C=1e-3,
               k_folds=10):
    """Test an SVM you've trained on a new dataset.

    Images are run through the VGG16 checkpoint ``model_file``. Its ``fc7``
    activations are classified by the pickled classifier in ``trained_svm``.
    Results are written under ``config.results/<parent dir of model_file>``.

    Args:
        model_file: Checkpoint path. Its directory must also contain
            ``train_maximum_value.npz`` with ``max_array`` and ``min_array``.
        trained_svm: Path to a pickled classifier exposing ``predict``.
        ims: Directory of images; labelled 0 when ``dead_ims`` is given.
        dead_ims: Optional directory of images labelled 1. When given,
            accuracy and a randomization-test p-value are reported.
        output_csv: Base name (without ``.csv``) of the prediction CSV.
        training_max: Overrides the max taken from the meta file.
        C, k_folds: Unused; kept for interface compatibility.

    Raises:
        RuntimeError: if the images, the SVM file or the meta file are
            missing.
    """
    config = GEDIconfig()
    if ims is None:
        raise RuntimeError(
            'You need to supply a directory path to the images.')
    if dead_ims is None:
        print('Assuming all of your images are in the ims folder ' +
              '-- will not derive labels to calculate accuracy.')
    try:
        # NOTE(review): unpickling runs arbitrary code; only load trusted
        # SVM files.
        with open(trained_svm, 'rb') as svm_file:
            clf = cPickle.load(svm_file)
    except Exception:
        raise RuntimeError('Cannot find SVM file: %s' % trained_svm)

    live_files = glob(os.path.join(ims, '*%s' % config.raw_im_ext))
    if dead_ims is not None:
        dead_files = glob(os.path.join(dead_ims, '*%s' % config.raw_im_ext))
        # Live images are labelled 0, dead images 1.
        combined_labels = np.concatenate(
            (np.zeros(len(live_files)), np.ones(len(dead_files))))
        combined_files = np.concatenate((live_files, dead_files))
    else:
        combined_labels = None
        combined_files = np.asarray(live_files)
    if len(combined_files) == 0:
        raise RuntimeError('Could not find any files. Check your image path.')

    model_file_path = os.path.sep.join(model_file.split(os.path.sep)[:-1])
    meta_file_pointer = os.path.join(model_file_path,
                                     'train_maximum_value.npz')
    if not os.path.exists(meta_file_pointer):
        raise RuntimeError(
            'Cannot find the training data meta file: '
            'train_maximum_value.npz. '
            'Closest I could find from directory %s was %s. '
            'Download this from the link described in the README.md.' %
            (model_file_path, glob(os.path.join(model_file_path, '*.npz'))))
    meta_data = np.load(meta_file_pointer)

    # Prepare image normalization values
    if training_max is None:
        training_max = np.max(meta_data['max_array']).astype(np.float32)
    training_min = np.min(meta_data['min_array']).astype(np.float32)

    # Results go in a directory named after the checkpoint's parent directory
    ds_dt_stamp = model_file.split('/')[-2]
    out_dir = os.path.join(config.results, ds_dt_stamp)

    # Make output directories if they do not exist
    for d in (config.results, out_dir):
        make_dir(d)

    # Prepare data on CPU
    images = tf.placeholder(tf.float32,
                            shape=[None] + config.model_image_size,
                            name='images')

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn'):
            vgg = vgg16.model_struct(vgg16_npy_path=config.vgg16_weight_path,
                                     fine_tune_layers=config.fine_tune_layers)
            vgg.build(images, output_shape=config.output_shape)

        # fc7 activations are the SVM features; preds are the CNN's own guesses
        scores = vgg.fc7
        preds = tf.argmax(vgg.prob, 1)

    # Set up saver
    saver = tf.train.Saver(tf.global_variables())

    # Loop through each checkpoint then test the entire validation set
    ckpts = [model_file]
    ckpt_yhat, ckpt_y, ckpt_scores, ckpt_file_array = [], [], [], []
    print('-' * 60)
    print('Beginning evaluation')
    print('-' * 60)

    if config.validation_batch > len(combined_files):
        print('Trimming validation_batch size to %s (same as # of files).' %
              len(combined_files))
        config.validation_batch = len(combined_files)

    for idx, c in tqdm(enumerate(ckpts), desc='Running checkpoints'):
        dec_scores, yhat, y, file_array = [], [], [], []
        # Initialize the graph
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer()))
        saver.restore(sess, c)
        start_time = time.time()
        # NOTE(review): floor drops any trailing partial batch — confirm
        # image_batcher cannot serve a short final batch.
        num_batches = np.floor(
            len(combined_files) / float(config.validation_batch)).astype(int)
        batches = image_batcher(start=0,
                                num_batches=num_batches,
                                images=combined_files,
                                labels=combined_labels,
                                config=config,
                                training_max=training_max,
                                training_min=training_min)
        for image_batch, label_batch, file_batch in tqdm(batches,
                                                         total=num_batches):
            feed_dict = {images: image_batch}
            sc, tyh = sess.run([scores, preds], feed_dict=feed_dict)
            dec_scores += [sc]
            yhat = np.append(yhat, tyh)
            y = np.append(y, label_batch)
            file_array = np.append(file_array, file_batch)
        ckpt_yhat.append(yhat)
        ckpt_y.append(y)
        ckpt_scores.append(dec_scores)
        ckpt_file_array.append(file_array)
        print('Batch %d took %.1f seconds' % (idx, time.time() - start_time))
        # Close every session, not only the last one.
        sess.close()

    # Save everything
    new_dt_string = str(datetime.now()).split('.')[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    np.savez(os.path.join(out_dir, '%s_validation_accuracies' % new_dt_string),
             ckpt_yhat=ckpt_yhat,
             ckpt_y=ckpt_y,
             ckpt_scores=ckpt_scores,
             ckpt_names=ckpts,
             combined_files=ckpt_file_array)

    # Run SVM on the fc7 features of the last checkpoint
    all_scores = np.concatenate(dec_scores)
    predictions = clf.predict(all_scores)
    if dead_ims is not None:
        mean_acc = np.mean(predictions == y)
        p_value = randomization_test(y=y, yhat=predictions)
        # mean_acc is a fraction; scale it so the printed '%' is correct.
        print('SVM performance: mean accuracy = %s%%, p = %.5f' %
              (mean_acc * 100, p_value))
    else:
        mean_acc, p_value = None, None
        # No labels were supplied; store the CNN guesses as a placeholder.
        y = np.copy(yhat)
    np.savez(os.path.join(out_dir, '%s_svm_test_data' % new_dt_string),
             yhat=yhat,
             y=y,
             scores=dec_scores,
             ckpts=ckpts,
             p_value=p_value)

    # Also save a csv with item/guess pairs.
    # NOTE(review): 'guesses' holds the CNN argmax (yhat), not the SVM
    # predictions — confirm this is intended.
    trimmed_files = np.asarray([
        x.split(os.path.sep)[-1] for x in np.asarray(ckpt_file_array).ravel()
    ])
    yhat = np.asarray(yhat)
    df = pd.DataFrame(
        np.hstack((trimmed_files.reshape(-1, 1), yhat.reshape(-1, 1))),
        columns=['files', 'guesses'])
    df.to_csv(os.path.join(out_dir, '%s.csv' % output_csv))
    print('Saved csv to: %s' % out_dir)
def test_placeholder(
        image_path,
        model_file,
        model_meta,
        out_dir,
        n_images=3,
        first_n_images=1,
        debug=True,
        margin=.1,
        autopsy_csv=None,
        C=1,
        k_folds=10,
        embedding_type='tsne',
        autopsy_model='match'):
    """Extract model activities for a set of images and analyze them.

    Builds either the 'match' model (matching_vgg16, restored from
    ``model_file``) or the GEDI baseline VGG16 on placeholders, then runs
    every image found at ``image_path`` through it via ``image_batcher``.
    Raw activities are saved to ``validation_accuracies.npz``.

    When ``first_n_images == 1`` each image is labelled with a pathology
    looked up in ``autopsy_csv`` (matched on the 'line' and 'wells' columns
    against fields 1 and 4 of the '_'-separated file name). The activities
    are then used to score a cross-validated linear SVM and to produce 2-D
    embeddings (raw and z-scored), written to csv and plotted.
    Otherwise the sign of each score is written out as a decision.

    Args:
        image_path: directory of images (matched on ``config.raw_im_ext``)
            or a single image file.
        model_file: checkpoint restored for the 'match' model.
        model_meta: ``.npy`` file holding the model's saved config object.
        out_dir: root output directory; a timestamped subdir is created.
        n_images: forwarded to ``image_batcher``.
        first_n_images: number of input placeholders; also forwarded to
            ``image_batcher``.
        debug: unused.
        margin: must not be None (only asserted).
        autopsy_csv: csv with 'line', 'wells' and 'type' columns; required
            when ``first_n_images == 1``.
        C: LinearSVC regularization strength.
        k_folds: number of cross-validation folds.
        embedding_type: 'tsne'/'TSNE', 'pca'/'PCA' or 'spectral'.
        autopsy_model: 'match' or 'GEDI'/'gedi'.

    Raises:
        RuntimeError: wrong meta file, no images found, or autopsy data
            missing when it is needed.
        ValueError: the 'match' comparison is requested with < 3 frames.
        NotImplementedError: unsupported model, distance or embedding type.
    """
    config = GEDIconfig()
    assert margin is not None, 'Need a margin for the loss.'
    assert image_path is not None, 'Provide a path to an image directory.'
    assert model_file is not None, 'Provide a path to the model file.'

    try:
        # Load the model's config
        config = np.load(model_meta).item()
    except Exception:
        print('Could not load model config, falling back to default config.')
    config.model_image_size[-1] = 1
    if not hasattr(config, 'include_GEDI'):
        raise RuntimeError('You need to pass the correct meta file.')

    # Load autopsy information; it labels images in the single-frame analysis
    autopsy_data = None
    if autopsy_csv is not None:
        try:
            autopsy_data = pd.read_csv(autopsy_csv)
        except IOError:
            print('Unable to load autopsy file.')
    if first_n_images == 1 and autopsy_data is None:
        # Fail before the expensive feature extraction rather than after it
        raise RuntimeError(
            'A readable autopsy_csv is required when first_n_images == 1.')

    if os.path.isdir(image_path):
        combined_files = np.asarray(
            glob(os.path.join(image_path, '*%s' % config.raw_im_ext)))
    else:
        combined_files = [image_path]
    if len(combined_files) == 0:
        raise RuntimeError('Could not find any files. Check your image path.')

    # Make output directories if they do not exist
    dt_stamp = re.split(
        r'\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(
        config.train_checkpoint, dt_dataset)  # timestamp this run
    out_dir = os.path.join(out_dir, dt_dataset)
    dir_list = [out_dir]
    [tf_fun.make_dir(d) for d in dir_list]

    # Prepare data on CPU: one placeholder per frame
    with tf.device('/cpu:0'):
        images = []
        for idx in range(first_n_images):
            images += [tf.placeholder(
                tf.float32,
                shape=[None] + config.model_image_size,
                name='images_%s' % idx)]

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        if autopsy_model == 'match':
            from models import matching_vgg16 as model_type
            with tf.variable_scope('match'):
                # Build matching model for frame 0
                model_0 = model_type.model_struct(
                    vgg16_npy_path=config.gedi_weight_path)
                model_activity = model_0.build(
                    images[0],
                    output_shape=config.output_shape,
                    include_GEDI=config.include_GEDI)
                if config.l2_norm:
                    # Normalize exactly like the other frames below
                    model_activity = tf_fun.l2_normalize(model_activity)
                frame_activity = [model_activity]
            if first_n_images > 1:
                with tf.variable_scope('match', reuse=tf.AUTO_REUSE):
                    # Build matching model for other frames (shared weights)
                    for idx in range(1, len(images)):
                        model_activity = model_0.build(
                            images[idx],
                            output_shape=config.output_shape,
                            include_GEDI=config.include_GEDI)
                        if config.l2_norm:
                            model_activity = tf_fun.l2_normalize(
                                model_activity)
                        frame_activity += [model_activity]
                if len(frame_activity) < 3:
                    raise ValueError(
                        'The match model compares frame 0 against frames 1 '
                        'and 2; set first_n_images >= 3.')
                if config.dist_fun == 'l2':
                    dist_fun = tf_fun.l2_dist
                elif config.dist_fun == 'pearson':
                    dist_fun = tf_fun.pearson_dist
                else:
                    raise NotImplementedError(config.dist_fun)
                pos = dist_fun(frame_activity[0], frame_activity[1], axis=1)
                neg = dist_fun(frame_activity[0], frame_activity[2], axis=1)
                # Store the difference in distances
                model_activity = pos - neg
        elif autopsy_model == 'GEDI' or autopsy_model == 'gedi':
            from models import baseline_vgg16 as model_type
            model = model_type.model_struct(
                vgg16_npy_path=config.gedi_weight_path)
            model.build(
                images[0],
                output_shape=config.output_shape)
            model_activity = model.fc7
        else:
            raise NotImplementedError(autopsy_model)

    if config.validation_batch > len(combined_files):
        print(
            'Trimming validation_batch size to %s '
            '(same as # of files).' % len(combined_files))
        config.validation_batch = len(combined_files)

    # Set up saver
    saver = tf.train.Saver(tf.global_variables())

    # Initialize the graph; always release the session, even on error
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    try:
        # Need to initialize both of these if supplying num_epochs to inputs
        sess.run(
            tf.group(
                tf.global_variables_initializer(),
                tf.local_variables_initializer()))
        if autopsy_model == 'match':
            saver.restore(sess, model_file)
        start_time = time.time()
        # Files beyond the last full batch are not processed
        num_batches = np.floor(
            len(combined_files) / float(
                config.validation_batch)).astype(int)
        score_array, file_array = [], []
        for image_batch, file_batch in tqdm(
                image_batcher(
                    start=0,
                    num_batches=num_batches,
                    images=combined_files,
                    config=config,
                    first_n_images=first_n_images,
                    n_images=n_images),
                total=num_batches):
            # NOTE(review): each placeholder is fed on its own with the same
            # batch; when the graph depends on several placeholders
            # (first_n_images > 1) they all need feeding in one run —
            # confirm what image_batcher yields in that case.
            for im_head in images:
                activity = sess.run(
                    model_activity,
                    feed_dict={im_head: image_batch})
                score_array += [activity]
            file_array += [file_batch]
        print('Image processing took %.1f seconds' % (
            time.time() - start_time))
    finally:
        sess.close()

    score_array = np.concatenate(score_array, axis=0)
    if score_array.ndim == 1:
        # One scalar per image (e.g. pos - neg): keep it as a column vector
        score_array = score_array.reshape(-1, 1)
    else:
        score_array = score_array.reshape(-1, score_array.shape[-1])
    file_array = np.concatenate(file_array, axis=0)

    # Save everything
    np.savez(
        os.path.join(out_dir, 'validation_accuracies'),
        score_array=score_array,
        file_array=file_array)

    if first_n_images == 1:
        # Derive pathologies from file names: field 1 is the line, field 4
        # the well
        pathologies = []
        for f in combined_files:
            sf = f.split(os.path.sep)[-1].split('_')
            line = sf[1]
            well = sf[4]
            disease = autopsy_data[
                np.logical_and(
                    autopsy_data['line'] == line,
                    autopsy_data['wells'] == well)]['type'].values
            pathologies += [disease[0] if len(disease) else 'Not_found']
        # Only the files that made it into full batches have scores
        pathologies = np.asarray(pathologies)[:len(score_array)]

        mu = score_array.mean(0)
        sd = score_array.std(0)
        z_score_array = (score_array - mu) / (sd + 1e-4)
        if embedding_type in ('TSNE', 'tsne'):
            emb = manifold.TSNE(n_components=2, init='pca', random_state=0)
        elif embedding_type in ('PCA', 'pca'):
            emb = PCA(n_components=2, svd_solver='randomized', random_state=0)
        elif embedding_type == 'spectral':
            emb = manifold.SpectralEmbedding(n_components=2, random_state=0)
        else:
            raise NotImplementedError(embedding_type)

        y = emb.fit_transform(score_array)

        # Do a classification analysis: integer-encode the pathologies
        labels = np.unique(pathologies.reshape(-1, 1), return_inverse=True)[1]

        # Run SVM
        svm = LinearSVC(C=C, dual=False, class_weight='balanced')
        clf = make_pipeline(preprocessing.StandardScaler(), svm)
        predictions = cross_val_predict(clf, score_array, labels, cv=k_folds)
        cv_performance = metrics.accuracy_score(predictions, labels)
        clf.fit(score_array, labels)
        print('%s-fold SVM performance: accuracy = %s%%' % (
            k_folds,
            np.mean(cv_performance * 100)))
        np.savez(
            os.path.join(out_dir, 'svm_data'),
            yhat=score_array,
            y=labels,
            cv_performance=cv_performance,
            C=C)

        # Output csv of the raw-feature embedding
        df = pd.DataFrame(
            np.hstack((
                y,
                pathologies.reshape(-1, 1),
                file_array.reshape(-1, 1))),
            columns=['dim1', 'dim2', 'pathology', 'filename'])
        out_name = os.path.join(out_dir, 'raw_embedding.csv')
        df.to_csv(out_name)
        print('Saved csv to: %s' % out_name)

        create_figs(
            emb=emb,
            out_dir=out_dir,
            out_name=out_name,
            embedding_type=embedding_type,
            embedding_name='raw_embedding')

        # Now work on zscored data
        y = emb.fit_transform(z_score_array)

        # Output csv of the z-scored embedding
        df = pd.DataFrame(
            np.hstack((
                y,
                pathologies.reshape(-1, 1),
                file_array.reshape(-1, 1))),
            columns=['dim1', 'dim2', 'pathology', 'filename'])
        out_name = os.path.join(out_dir, 'embedding.csv')
        df.to_csv(out_name)
        print('Saved csv to: %s' % out_name)

        # Create plot
        create_figs(
            emb=emb,
            out_dir=out_dir,
            out_name=out_name,
            embedding_type=embedding_type,
            embedding_name='normalized_embedding')

    else:
        # Do a classification (sign of the score), one row per score
        flat_scores = score_array.reshape(-1)
        decisions = np.sign(flat_scores)
        df = pd.DataFrame(
            np.column_stack((decisions, flat_scores)),
            columns=['Decisions', 'Scores'])
        df.to_csv(
            os.path.join(
                out_dir, 'tracking_model_scores.csv'))
示例#17
0
def train_vgg16(train_dir=None, validation_dir=None):
    """Fine-tune VGG16 on GEDI tfrecords, checkpointing at each validation step.

    Args:
        train_dir: directory holding the training tfrecords
            (``config.tf_record_names['train']``) and the
            ``<tvt_flags[0]>_<max_file>`` meta file. None uses
            ``config.tfrecord_dir``.
        validation_dir: directory holding the validation tfrecords
            (``config.tf_record_names['val']``). None uses
            ``config.tfrecord_dir``; False trains without validation.

    Raises:
        IOError: a required tfrecord or weight file does not exist.
        ValueError: unsupported ``config.ordinal_classification`` or
            ``config.optimizer``.
        RuntimeError: the training loss became NaN.
    """
    config = GEDIconfig()
    if train_dir is None:  # Use globals
        train_dir = config.tfrecord_dir
    # NOTE(review): a custom train_dir is assumed to use the same tfrecord
    # file name as the global one.
    train_data = os.path.join(train_dir, config.tf_record_names['train'])
    meta_data = np.load(
        os.path.join(train_dir,
                     '%s_%s' % (config.tvt_flags[0], config.max_file)))

    # Prepare image normalization values
    if config.max_gedi is None:
        max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
        if max_value == 0:
            max_value = None
            print('Derived max value is 0')
        else:
            print('Normalizing with empirical max.')
        if 'min_array' in meta_data.keys():
            min_value = np.min(meta_data['min_array']).astype(np.float32)
            print('Normalizing with empirical min.')
        else:
            min_value = None
            print('Not normalizing with a min.')
    else:
        max_value = config.max_gedi
        min_value = config.min_gedi
        print('max value: %s' % max_value)
    ratio = meta_data['ratio']
    if config.encode_time_of_death:
        tod = pd.read_csv(config.encode_time_of_death)
        tod_data = tod['dead_tp'].values
        # Drop neurons without a death time and empty wells
        mask = np.isnan(tod_data).astype(int) + (
            tod['plate_well_neuron'] == 'empty').values.astype(int)
        tod_data = tod_data[mask == 0]
        tod_data = tod_data[tod_data < 10]  # throw away high time points
        config.output_shape = len(np.unique(tod_data))
        ratio = class_weight.compute_class_weight('balanced',
                                                  np.sort(np.unique(tod_data)),
                                                  tod_data)
        flip_ratio = False
    else:
        flip_ratio = True
    print('Ratio is: %s' % ratio)

    if validation_dir is False:
        validation_data = False  # Do not use validation data during training
    else:
        if validation_dir is None:  # Use globals
            validation_dir = config.tfrecord_dir
        validation_data = os.path.join(validation_dir,
                                       config.tf_record_names['val'])
    use_validation = validation_data is not False

    # Make output directories if they do not exist
    dt_stamp = re.split(
        r'\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(config.train_checkpoint,
                                           dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    dir_list = [
        config.train_checkpoint, config.train_summaries, config.results,
        out_dir
    ]
    [make_dir(d) for d in dir_list]
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Check the inputs exist (explicit checks survive `python -O`)
    required_files = [train_data, config.vgg16_weight_path]
    if use_validation:
        required_files.append(validation_data)
    for required_file in required_files:
        if not os.path.exists(required_file):
            raise IOError('Could not find: %s' % required_file)

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        train_images, train_labels = inputs(train_data,
                                            config.train_batch,
                                            im_shape,
                                            config.model_image_size[:2],
                                            max_value=max_value,
                                            min_value=min_value,
                                            train=config.data_augmentations,
                                            num_epochs=config.epochs,
                                            normalize=config.normalize)
        if use_validation:
            val_images, val_labels = inputs(validation_data,
                                            config.validation_batch,
                                            im_shape,
                                            config.model_image_size[:2],
                                            max_value=max_value,
                                            min_value=min_value,
                                            num_epochs=config.epochs,
                                            normalize=config.normalize)
        tf.summary.image('train images', train_images)
        if use_validation:
            tf.summary.image('validation images', val_images)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn') as scope:
            if config.ordinal_classification == 'ordinal':
                # Two logits (below/above threshold) per ordinal unit
                vgg_output = config.output_shape * 2
            elif config.ordinal_classification == 'regression':
                vgg_output = 1
            elif config.ordinal_classification is None:
                vgg_output = config.output_shape
            else:
                raise ValueError('Unknown ordinal_classification: %s' %
                                 config.ordinal_classification)
            vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                              fine_tune_layers=config.fine_tune_layers)
            train_mode = tf.get_variable(name='training', initializer=True)
            vgg.build(train_images,
                      output_shape=vgg_output,
                      train_mode=train_mode,
                      batchnorm=config.batchnorm_layers)

            # Prepare the cost function
            if config.ordinal_classification == 'ordinal':
                # Encode y w/ k-hot and yhat w/ sigmoid ce. units capture dist.
                enc = tf.tile(
                    tf.reshape(tf.range(0, config.output_shape), [1, -1]),
                    [config.train_batch, 1])
                enc_train_labels = tf.cast(
                    tf.greater_equal(enc, tf.expand_dims(train_labels,
                                                         axis=1)), tf.float32)
                split_labs = tf.split(enc_train_labels,
                                      config.output_shape,
                                      axis=1)
                res_output = tf.reshape(
                    vgg.fc8, [config.train_batch, 2, config.output_shape])
                split_logs = tf.split(res_output, config.output_shape, axis=2)
                unit_costs = [
                    tf.nn.softmax_cross_entropy_with_logits(
                        labels=tf.one_hot(tf.cast(tf.squeeze(s), tf.int32),
                                          2),
                        logits=tf.squeeze(l))
                    for s, l in zip(split_labs, split_logs)
                ]
                if config.balance_cost:
                    # Weight each ordinal unit by its class ratio
                    unit_costs = [c * r for c, r in zip(unit_costs, ratio)]
                cost = tf.reduce_mean(tf.add_n(unit_costs))
            elif config.ordinal_classification == 'regression':
                # fc8 has one unit; flatten it so the error lines up with the
                # per-example labels instead of broadcasting to (batch, batch)
                reg_error = tf.reshape(vgg.fc8, [-1]) - tf.cast(
                    train_labels, tf.float32)
                if config.balance_cost:
                    # Per-example weight = class weight of that example's label
                    weight_vec = tf.cast(tf.gather(ratio, train_labels),
                                         tf.float32)
                    cost = tf.reduce_mean(tf.pow(reg_error, 2) * weight_vec)
                else:
                    cost = tf.nn.l2_loss(reg_error)
            else:
                if config.balance_cost:
                    cost = softmax_cost(vgg.fc8,
                                        train_labels,
                                        ratio=ratio,
                                        flip_ratio=flip_ratio)
                else:
                    cost = softmax_cost(vgg.fc8, train_labels)
            tf.summary.scalar("cost", cost)

            # Weight decay on the requested layers (weights only, no biases)
            if config.wd_layers is not None:
                _, l2_wd_layers = fine_tune_prepare_layers(
                    tf.trainable_variables(), config.wd_layers)
                l2_wd_layers = [
                    x for x in l2_wd_layers if 'biases' not in x.name
                ]
                cost += (config.wd_penalty *
                         tf.add_n([tf.nn.l2_loss(x) for x in l2_wd_layers]))

            # Separate learning rates for fine-tuned vs. other variables
            other_opt_vars, ft_opt_vars = fine_tune_prepare_layers(
                tf.trainable_variables(), config.fine_tune_layers)
            optimizers = {
                'adam': tf.train.AdamOptimizer,
                'sgd': tf.train.GradientDescentOptimizer,
            }
            if config.optimizer not in optimizers:
                raise ValueError('Unknown optimizer: %s' % config.optimizer)
            train_op = ft_non_optimized(cost, other_opt_vars, ft_opt_vars,
                                        optimizers[config.optimizer],
                                        config.hold_lr, config.new_lr)

            if config.ordinal_classification == 'ordinal':
                # Predicted class = number of units voting "above threshold"
                arg_guesses = tf.cast(
                    tf.reduce_sum(tf.squeeze(tf.argmax(res_output, axis=1)),
                                  reduction_indices=[1]), tf.int32)
                train_accuracy = tf.reduce_mean(
                    tf.cast(tf.equal(arg_guesses, train_labels), tf.float32))
            elif config.ordinal_classification == 'regression':
                train_accuracy = tf.reduce_mean(
                    tf.cast(
                        tf.equal(tf.cast(tf.round(vgg.prob), tf.int32),
                                 train_labels), tf.float32))
            else:
                train_accuracy = class_accuracy(
                    vgg.prob, train_labels)  # training accuracy
            tf.summary.scalar("training accuracy", train_accuracy)

            # Setup validation op
            if use_validation:
                scope.reuse_variables()
                # Validation graph is the same as training except no batchnorm
                val_vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                                      fine_tune_layers=config.fine_tune_layers)
                val_vgg.build(val_images, output_shape=vgg_output)
                # Calculate validation accuracy
                if config.ordinal_classification == 'ordinal':
                    val_res_output = tf.reshape(
                        val_vgg.fc8,
                        [config.validation_batch, 2, config.output_shape])
                    val_arg_guesses = tf.cast(
                        tf.reduce_sum(tf.squeeze(
                            tf.argmax(val_res_output, axis=1)),
                                      reduction_indices=[1]), tf.int32)
                    val_accuracy = tf.reduce_mean(
                        tf.cast(tf.equal(val_arg_guesses, val_labels),
                                tf.float32))
                elif config.ordinal_classification == 'regression':
                    val_accuracy = tf.reduce_mean(
                        tf.cast(
                            tf.equal(tf.cast(tf.round(val_vgg.prob), tf.int32),
                                     val_labels), tf.float32))
                else:
                    val_accuracy = class_accuracy(val_vgg.prob, val_labels)
                tf.summary.scalar("validation accuracy", val_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=config.keep_checkpoints)
    summary_op = tf.summary.merge_all()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Need to initialize both of these if supplying num_epochs to inputs
    sess.run(
        tf.group(tf.global_variables_initializer(),
                 tf.local_variables_initializer()))
    summary_dir = os.path.join(config.train_summaries,
                               config.which_dataset + '_' + dt_stamp)
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Start training loop
    np.save(os.path.join(out_dir, 'meta_info'), config)
    step, losses = 0, []
    val_acc = -1  # Reported when there is no validation data
    try:
        while not coord.should_stop():
            start_time = time.time()
            _, loss_value, train_acc = sess.run(
                [train_op, cost, train_accuracy])
            losses.append(loss_value)
            duration = time.time() - start_time
            if np.isnan(loss_value):
                raise RuntimeError('Model diverged with loss = NaN')

            if step % config.validation_steps == 0:
                if use_validation:
                    # Evaluate only: running train_op here would take an
                    # extra, unlogged training step.
                    val_acc = sess.run(val_accuracy)

                # Summaries
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Training status and validation accuracy
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Validation accuracy = %s | logdir = %s')
                print(format_str %
                      (datetime.now(), step,
                       loss_value, config.train_batch / duration,
                       float(duration), train_acc, val_acc, summary_dir))

                # Store a checkpoint at every validation step
                saver.save(sess,
                           os.path.join(config.train_checkpoint,
                                        'model_' + str(step) + '.ckpt'),
                           global_step=step)
            else:
                # Training status
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s')
                print(format_str %
                      (datetime.now(), step, loss_value, config.train_batch /
                       duration, float(duration), train_acc))
            # End iteration
            step += 1

    except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
    finally:
        coord.request_stop()
        np.save(os.path.join(config.tfrecord_dir, 'training_loss'), losses)
    coord.join(threads)
    sess.close()
示例#18
0
def visualize_model(
        live_ims,
        dead_ims,
        model_file,
        output_folder,
        num_channels,
        smooth_iterations=50,
        untargeted=False,
        viz='none',
        per_timepoint=True):
    """Save noise-averaged input-gradient maps of a trained GEDI VGG16.

    Live images are labelled 0 and dead images 1. For each batch the
    model's fc7 scores and argmax predictions are computed. The gradient of
    the fc8 activity with respect to the input is then averaged over
    ``smooth_iterations`` noisy copies of the batch (made by ``add_noise``).
    Unless ``untargeted``, fc8 is first masked to the true label. The
    result is post-processed with ``visualization_function(grads, viz)``.

    Each gradient map is saved as a png in ``output_folder``, and per-batch
    overlay mosaics are produced with ``alpha_mosaic``. Scores,
    predictions, labels, file names and gradient maps are saved to
    ``<config.results>/<parent dir of model_file>/validation_accuracies.npz``.

    Args:
        live_ims: directory of live-cell images.
        dead_ims: directory of dead-cell images.
        model_file: checkpoint to restore; its directory must also contain
            ``train_maximum_value.npz``.
        output_folder: where the per-image png files are written.
        num_channels: forwarded to ``image_batcher``.
        smooth_iterations: number of noisy copies averaged per batch.
        untargeted: if True, take the gradient of all fc8 units.
        viz: post-processing mode for ``visualization_function``; 'none'
            also splits the mosaic into positive/negative gradients.
        per_timepoint: forwarded to ``image_batcher``.

    Raises:
        RuntimeError: missing image directories, no images found, or no
            ``train_maximum_value.npz`` next to ``model_file``.
        ValueError: ``smooth_iterations`` is less than 1.
    """
    config = GEDIconfig()
    if live_ims is None:
        raise RuntimeError(
            'You need to supply a directory path to the live images.')
    if dead_ims is None:
        raise RuntimeError(
            'You need to supply a directory path to the dead images.')
    if smooth_iterations < 1:
        # The gradients are divided by this count below
        raise ValueError('smooth_iterations must be at least 1.')

    live_files = glob(os.path.join(live_ims, '*%s' % config.raw_im_ext))
    dead_files = glob(os.path.join(dead_ims, '*%s' % config.raw_im_ext))
    combined_labels = np.concatenate((
        np.zeros(len(live_files)),
        np.ones(len(dead_files))))
    combined_files = np.concatenate((live_files, dead_files))
    if len(combined_files) == 0:
        raise RuntimeError('Could not find any files. Check your image path.')

    model_file_path = os.path.sep.join(model_file.split(os.path.sep)[:-1])
    meta_file_pointer = os.path.join(
        model_file_path,
        'train_maximum_value.npz')
    if not os.path.exists(meta_file_pointer):
        raise RuntimeError(
            'Cannot find the training data meta file: train_maximum_value.npz. '
            'Closest I could find from directory %s was %s. '
            'Download this from the link described in the README.md.'
            % (model_file_path, glob(os.path.join(model_file_path, '*.npz'))))
    meta_data = np.load(meta_file_pointer)

    # Prepare image normalization values
    training_max = np.max(meta_data['max_array']).astype(np.float32)
    training_min = np.min(meta_data['min_array']).astype(np.float32)

    # Outputs are grouped by the checkpoint's parent directory name
    ds_dt_stamp = re.split('/', model_file)[-2]
    out_dir = os.path.join(config.results, ds_dt_stamp)

    # Make output directories if they do not exist
    dir_list = [config.results, out_dir, output_folder]
    [make_dir(d) for d in dir_list]

    # Prepare data on CPU
    images = tf.placeholder(
        tf.float32,
        shape=[None] + config.model_image_size,
        name='images')
    labels = tf.placeholder(
        tf.int64,
        shape=[None],
        name='labels')

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn'):
            vgg = vgg16.model_struct(
                vgg16_npy_path=config.vgg16_weight_path,
                fine_tune_layers=config.fine_tune_layers)
            vgg.build(
                images,
                output_shape=config.output_shape)

        # Setup validation op
        scores = vgg.fc7
        preds = tf.argmax(vgg.prob, 1)
        activity_pattern = vgg.fc8
        if not untargeted:
            # Only backpropagate from the true-label unit
            oh_labels = tf.one_hot(labels, config.output_shape)
            activity_pattern *= oh_labels
        grad_image = tf.gradients(activity_pattern, images)

    # Set up saver
    saver = tf.train.Saver(tf.global_variables())

    # Loop through each checkpoint then test the entire validation set
    ckpts = [model_file]
    ckpt_yhat, ckpt_y, ckpt_scores = [], [], []
    ckpt_file_array, ckpt_viz_images = [], []
    print('-' * 60)
    print('Beginning evaluation')
    print('-' * 60)

    if config.validation_batch > len(combined_files):
        print('Trimming validation_batch to %s (same as # of files).' % len(
            combined_files))
        config.validation_batch = len(combined_files)

    count = 0
    for ckpt_idx, c in tqdm(enumerate(ckpts), desc='Running checkpoints'):
        dec_scores, yhat, y, file_array, viz_images = [], [], [], [], []
        # Initialize the graph; close this checkpoint's session when done
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        try:
            sess.run(
                tf.group(
                    tf.global_variables_initializer(),
                    tf.local_variables_initializer()))
            saver.restore(sess, c)
            start_time = time.time()
            num_batches = np.floor(
                len(combined_files) / float(
                    config.validation_batch)).astype(int)
            for image_batch, label_batch, file_batch in tqdm(
                    image_batcher(
                        start=0,
                        num_batches=num_batches,
                        images=combined_files,
                        labels=combined_labels,
                        config=config,
                        training_max=training_max,
                        training_min=training_min,
                        num_channels=num_channels,
                        per_timepoint=per_timepoint),
                    total=num_batches):
                feed_dict = {
                    images: image_batch,
                    labels: label_batch
                }
                it_grads = np.zeros(image_batch.shape)
                sc, tyh = sess.run(
                    [scores, preds],
                    feed_dict=feed_dict)
                # Average the input gradient over noisy copies of the batch
                for _ in range(smooth_iterations):
                    it_grad = sess.run(
                        grad_image,
                        feed_dict={
                            images: add_noise(image_batch),
                            labels: label_batch
                        })
                    it_grads += it_grad[0]
                it_grads /= smooth_iterations  # Mean across iterations
                it_grads = visualization_function(it_grads, viz)

                # Save each grad individually as <image name>.png
                for grad_i, pred_i, file_i, label_i in zip(
                        it_grads, tyh, file_batch, label_batch):
                    base_name = os.path.splitext(
                        file_i.split(os.path.sep)[-1])[0]
                    out_pointer = os.path.join(
                        output_folder, base_name + '.png')
                    f = plt.figure()
                    plt.imshow(grad_i)
                    plt.title('Pred=%s, label=%s' % (pred_i, label_i))
                    plt.savefig(out_pointer)
                    plt.close(f)

                # Plot a mosaic of the grads
                if viz == 'none':
                    pos_grads = normalize(np.maximum(it_grads, 0))
                    neg_grads = normalize(np.minimum(it_grads, 0))
                    alpha_mosaic(
                        image_batch,
                        pos_grads,
                        'pos_batch_%s.pdf' % count,
                        title='Positive gradient overlays.',
                        rc=1,
                        cc=len(image_batch),
                        cmap=plt.cm.Reds)
                    alpha_mosaic(
                        image_batch,
                        neg_grads,
                        'neg_batch_%s.pdf' % count,
                        title='Negative gradient overlays.',
                        rc=1,
                        cc=len(image_batch),
                        cmap=plt.cm.Reds)
                else:
                    alpha_mosaic(
                        image_batch,
                        it_grads,
                        'batch_%s.pdf' % count,
                        title='Gradient overlays.',
                        rc=1,
                        cc=len(image_batch),
                        cmap=plt.cm.Reds)
                count += 1

                # Store the results
                dec_scores += [sc]
                yhat = np.append(yhat, tyh)
                y = np.append(y, label_batch)
                file_array = np.append(file_array, file_batch)
                viz_images += [it_grads]
            ckpt_yhat.append(yhat)
            ckpt_y.append(y)
            ckpt_scores.append(dec_scores)
            ckpt_file_array.append(file_array)
            ckpt_viz_images.append(viz_images)
            print('Checkpoint %d took %.1f seconds' % (
                ckpt_idx, time.time() - start_time))
        finally:
            sess.close()

    # Save everything
    np.savez(
        os.path.join(out_dir, 'validation_accuracies'),
        ckpt_yhat=ckpt_yhat,
        ckpt_y=ckpt_y,
        ckpt_scores=ckpt_scores,
        ckpt_names=ckpts,
        combined_files=ckpt_file_array,
        ckpt_viz_images=ckpt_viz_images)
示例#19
0
def test_vgg16(validation_data, model_dir, which_set, selected_ckpts):
    """Evaluate a VGG16 checkpoint on a TFRecord set and plot gradient heatmaps.

    The record file is streamed once (``num_epochs=1``) until the input queue
    is exhausted. Each fetched example is bucketed by comparing the argmax
    prediction with its label: true/false *negative* when the predicted class
    is 0, true/false *positive* otherwise. The normalized input-gradient
    heatmaps and images of each bucket are plotted with ``loop_plot`` under
    ``<config.heatmap_source_images>/<model run name>/{tn,tp,fn,fp}``.

    NOTE(review): ``inputs()`` is called with ``1`` (presumably the batch
    size) and only element 0 of each fetched batch is used.

    Args:
        validation_data: Path to a ``.tfrecords`` file, or None to use
            ``config.tf_record_names[which_set]`` inside
            ``config.tfrecord_dir``. Normalization meta data is read from
            ``<validation_data stem>_maximum_value.npz``, or from
            ``val_<config.max_file>`` when None is passed.
        model_dir: Model run directory; its last path component names the
            results and heatmap output directories.
        which_set: Key into ``config.tf_record_names``; only used when
            ``validation_data`` is None.
        selected_ckpts: Path of the single checkpoint to restore.
    """
    config = GEDIconfig()
    blur_kernel = config.hm_blur
    if validation_data is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir,
                                       config.tf_record_names[which_set])
        meta_data = np.load(
            os.path.join(config.tfrecord_dir, 'val_%s' % config.max_file))
    else:
        meta_data = np.load('%s_maximum_value.npz' %
                            validation_data.split('.tfrecords')[0])
    label_list = os.path.join(
        config.processed_image_patch_dir,
        'list_of_' + '_'.join(config.image_prefixes) + '_labels.txt')
    # NOTE(review): file_pointers is not used yet. Reading it still fails
    # fast when the label list is missing.
    with open(label_list) as f:
        file_pointers = [l.rstrip('\n') for l in f.readlines()]

    # Prepare image normalization values. Fall back to the config defaults
    # when the arrays cannot be read from the meta file.
    try:
        max_value = np.max(meta_data['max_array']).astype(np.float32)
    except Exception:
        max_value = np.asarray([config.max_gedi])
    try:
        # The lower bound is the minimum of the stored minima.
        min_value = np.min(meta_data['min_array']).astype(np.float32)
    except Exception:
        min_value = np.asarray([config.min_gedi])

    # Name of the model run; tolerate a trailing slash on model_dir.
    ds_dt_stamp = model_dir.rstrip('/').split('/')[-1]
    out_dir = os.path.join(config.results, ds_dt_stamp + '/')

    # Prefer the config saved alongside the model run, if it can be loaded.
    # NOTE(review): numpy >= 1.16.3 refuses pickled .npy files unless
    # allow_pickle=True is passed, so on newer numpy this always falls back
    # to GEDIconfig -- confirm that is intended.
    try:
        saved_config = np.load(os.path.join(out_dir, 'meta_info.npy')).item()
        # Make sure this is always at 1
        saved_config.validation_batch = 1
    except Exception:
        saved_config = None
    print('-' * 60)
    if saved_config is None:
        print('Using config from gedi_config.py for model:%s' % out_dir)
    else:
        config = saved_config
        print('Loading config meta data for:%s' % out_dir)
    print('-' * 60)

    im_shape = config.gedi_image_size

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        val_images, val_labels = inputs(validation_data,
                                        1,
                                        im_shape,
                                        config.model_image_size[:2],
                                        max_value=max_value,
                                        min_value=min_value,
                                        num_epochs=1,
                                        normalize=config.normalize)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn'):
            vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                              fine_tune_layers=config.fine_tune_layers)
            vgg.build(val_images, output_shape=config.output_shape)

        # Setup validation op
        preds = tf.argmax(vgg.prob, 1)
        targets = tf.cast(val_labels, dtype=tf.int64)
        grad_labels = tf.one_hot(val_labels,
                                 config.output_shape,
                                 dtype=tf.float32)
        # Gradient of the labelled class's fc8 score w.r.t. the input images.
        heatmap_op = tf.gradients(vgg.fc8 * grad_labels, val_images)[0]

    # Set up saver
    saver = tf.train.Saver(tf.global_variables())
    ckpts = [selected_ckpts]

    # Output buckets, in plotting order: (subdirectory, plot title).
    outcomes = [('tn', 'True negative'), ('tp', 'True positive'),
                ('fn', 'False negative'), ('fp', 'False positive')]
    hm_by_outcome = {stem: [] for stem, _ in outcomes}
    im_by_outcome = {stem: [] for stem, _ in outcomes}

    # Loop through each checkpoint then test the entire validation set
    print('-' * 60)
    print('Beginning evaluation on ckpt: %s' % ckpts)
    print('-' * 60)
    for idx, c in tqdm(enumerate(ckpts), desc='Running checkpoints'):
        # Create the session and coordinator before the try so the finally
        # clause can always stop the threads and close the session.
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        coord = tf.train.Coordinator()
        threads = []
        start_time = time.time()
        try:
            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer()))

            # Set up exemplar threading
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver.restore(sess, c)
            start_time = time.time()
            while not coord.should_stop():
                tyh, ty, thm, tim = sess.run(
                    [preds, targets, heatmap_op, val_images])
                tyh = tyh[0]
                ty = ty[0]
                tim = (tim / tim.max()).squeeze()
                # 't'/'f': prediction matches label or not;
                # 'p'/'n': predicted class is nonzero or zero.
                outcome = ('t' if tyh == ty else 'f') + ('p' if tyh else 'n')
                hm_by_outcome[outcome].append(hm_normalize(thm))
                im_by_outcome[outcome].append(tim)
        except tf.errors.OutOfRangeError:
            # The single-epoch input queue is exhausted: evaluation is done.
            print('Batch %d took %.1f seconds' % (idx,
                                                  time.time() - start_time))
        finally:
            coord.request_stop()
            coord.join(threads)
            sess.close()

    # Plot images -- add to a dict and incorporate file_pointers
    dir_pointer = os.path.join(config.heatmap_source_images, ds_dt_stamp)
    for d in [dir_pointer] + [os.path.join(dir_pointer, stem)
                              for stem, _ in outcomes]:
        make_dir(d)
    for stem, title in outcomes:
        loop_plot(im_by_outcome[stem],
                  hm_by_outcome[stem],
                  title,
                  os.path.join(dir_pointer, stem),
                  blur=blur_kernel)
示例#20
0
def _autopsy_pathologies(file_paths, autopsy_data):
    """Return an array with the autopsy 'disease' label of each image path.

    The lookup key is fields 1-3 of the '_'-separated file name, re-joined
    with '_'. It is matched against the 'plate_well_neuron' column. The first
    matching row wins, and images with no match are labelled 'Absent'.
    """
    disease_by_key = {}
    for key, disease in zip(autopsy_data['plate_well_neuron'].values,
                            autopsy_data['disease'].values):
        # setdefault keeps the first row seen for each key.
        disease_by_key.setdefault(key, disease)
    pathologies = []
    for path in file_paths:
        key = '_'.join(path.split('/')[-1].split('_')[1:4])
        pathologies.append(disease_by_key.get(key, 'Absent'))
    return np.asarray(pathologies)


def _embedding_model(embedding_type):
    """Build the 2-D embedding estimator named by ``embedding_type``.

    Accepts 'tsne', 'pca' or 'spectral' (case-insensitive).

    Raises:
        ValueError: if ``embedding_type`` names none of these.
    """
    kind = str(embedding_type).lower()
    if kind == 'tsne':
        return manifold.TSNE(n_components=2, init='pca', random_state=0)
    if kind == 'pca':
        return PCA(n_components=2, svd_solver='randomized', random_state=0)
    if kind == 'spectral':
        return manifold.SpectralEmbedding(n_components=2, random_state=0)
    raise ValueError(
        "Unknown embedding_type %r; expected 'PCA', 'TSNE' or 'spectral'." %
        (embedding_type,))


def test_vgg16(image_dir,
               model_file,
               autopsy_csv=None,
               autopsy_path=None,
               output_csv='prediction_file',
               target_layer='fc7',
               save_npy=False,
               shuffle_images=True,
               embedding_type='PCA'):
    """Embed VGG16 features of raw images and plot them by autopsy pathology.

    Images matching ``*<config.raw_im_ext>`` in ``image_dir`` are scored
    with the checkpoint ``model_file``. The ``target_layer`` activations are
    z-scored per feature and reduced to 2-D with the chosen embedding. The
    result is written to ``<config.results>/<run>/embedding.csv`` and
    scatter-plotted by pathology; the plot is saved as ``embedding.png`` in
    the current directory. Only images that fill whole batches are embedded.

    Args:
        image_dir: Directory searched for images (required).
        model_file: Checkpoint path. The text before its first ``/model``
            must be a directory holding ``train_maximum_value.npz``. Its
            second-to-last path component names the results directory.
        autopsy_csv: File name of the autopsy CSV (required). It needs
            ``plate_well_neuron`` and ``disease`` columns.
        autopsy_path: Directory containing ``autopsy_csv`` (required).
        output_csv: Currently unused.
        target_layer: Name of the model layer whose activations are embedded.
        save_npy: If True, also save the raw per-checkpoint outputs to
            ``validation_accuracies.npz`` in the results directory.
        shuffle_images: Shuffle the file order before batching.
        embedding_type: 'PCA', 'TSNE' or 'spectral' (case-insensitive).

    Raises:
        AssertionError: if ``autopsy_csv`` or ``autopsy_path`` is None.
        ValueError: if ``embedding_type`` is not recognized.
        RuntimeError: if ``image_dir`` is None, no images are found, or the
            training meta file is missing.
    """
    assert autopsy_csv is not None, 'You must pass an autopsy file name.'
    assert autopsy_path is not None, 'You must pass an autopsy path.'
    # Validate the embedding choice before any expensive work.
    emb = _embedding_model(embedding_type)

    # Load autopsy information
    autopsy_data = pd.read_csv(os.path.join(autopsy_path, autopsy_csv))

    # Load config and begin preparing data
    config = GEDIconfig()
    if image_dir is None:
        raise RuntimeError(
            'You need to supply a directory path to the images.')

    combined_files = np.asarray(
        glob(os.path.join(image_dir, '*%s' % config.raw_im_ext)))
    if shuffle_images:
        combined_files = combined_files[np.random.permutation(
            len(combined_files))]
    if len(combined_files) == 0:
        raise RuntimeError('Could not find any files. Check your image path.')

    meta_file_pointer = os.path.join(
        model_file.split('/model')[0], 'train_maximum_value.npz')
    if not os.path.exists(meta_file_pointer):
        raise RuntimeError(
            'Cannot find the training data meta file. '
            'Download this from the link described in the README.md.')
    meta_data = np.load(meta_file_pointer)

    # Prepare image normalization values
    training_max = np.max(meta_data['max_array']).astype(np.float32)
    training_min = np.min(meta_data['min_array']).astype(np.float32)

    # Find model checkpoints
    ds_dt_stamp = re.split('/', model_file)[-2]
    out_dir = os.path.join(config.results, ds_dt_stamp)

    # Make output directories if they do not exist
    for d in [config.results, out_dir]:
        make_dir(d)

    # Prepare data on CPU
    images = tf.placeholder(tf.float32,
                            shape=[None] + config.model_image_size,
                            name='images')

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn'):
            vgg = vgg16.model_struct(vgg16_npy_path=config.vgg16_weight_path,
                                     fine_tune_layers=config.fine_tune_layers)
            vgg.build(images, output_shape=config.output_shape)

        # Setup validation op
        scores = vgg[target_layer]
        preds = tf.argmax(vgg.prob, 1)

    # Derive pathologies from file names
    pathologies = _autopsy_pathologies(combined_files, autopsy_data)

    # Set up saver
    saver = tf.train.Saver(tf.global_variables())

    # Loop through each checkpoint then test the entire validation set
    ckpts = [model_file]
    ckpt_yhat, ckpt_scores, ckpt_file_array = [], [], []
    print('-' * 60)
    print('Beginning evaluation')
    print('-' * 60)

    if config.validation_batch > len(combined_files):
        print('Trimming validation_batch size to %s.' % len(combined_files))
        config.validation_batch = len(combined_files)
    # Only whole batches are evaluated; any remainder is dropped.
    num_batches = np.floor(
        len(combined_files) / float(config.validation_batch)).astype(int)

    for idx, c in tqdm(enumerate(ckpts), desc='Running checkpoints'):
        dec_scores, yhat, file_array = [], [], []
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        try:
            # Initialize the graph, then load the checkpoint weights.
            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer()))
            saver.restore(sess, c)
            start_time = time.time()
            batches = image_batcher(start=0,
                                    num_batches=num_batches,
                                    images=combined_files,
                                    config=config,
                                    training_max=training_max,
                                    training_min=training_min)
            for image_batch, file_batch in tqdm(batches, total=num_batches):
                sc, tyh = sess.run([scores, preds],
                                   feed_dict={images: image_batch})
                dec_scores.append(sc)
                yhat.append(tyh)
                file_array.append(file_batch)
        finally:
            sess.close()
        ckpt_yhat.append(yhat)
        ckpt_scores.append(dec_scores)
        ckpt_file_array.append(file_array)
        print('Batch %d took %.1f seconds' % (idx, time.time() - start_time))

    # Create an embedding of the last checkpoint's features. Pathology labels
    # are truncated to the images that fit in whole batches.
    im_path_map = pathologies[:num_batches * config.validation_batch]
    dec_scores = np.concatenate(dec_scores)
    mu = dec_scores.mean(0)[None, :]
    sd = dec_scores.std(0)[None, :]
    # A feature that is constant across all images has sd == 0; dividing by
    # it would yield NaNs that break the embedding, so leave it centred only.
    sd[sd == 0] = 1.
    dec_scores = (dec_scores - mu) / sd
    file_array = np.concatenate(file_array)
    y = emb.fit_transform(dec_scores)

    # Ouput csv
    df = pd.DataFrame(np.hstack(
        (y, im_path_map.reshape(-1, 1), file_array.reshape(-1, 1))),
                      columns=['D1', 'D2', 'pathology', 'filename'])
    out_name = os.path.join(out_dir, 'embedding.csv')
    df.to_csv(out_name)
    print('Saved csv to: %s' % out_name)

    # Create plot
    fig, _ = plt.subplots()
    unique_cats = np.unique(im_path_map)
    # Colormaps treat integer inputs as lookup-table indices and clip
    # out-of-range ones. Use floats spread over [0, 1] so each category
    # gets a distinct colour.
    color_step = 1. / max(len(unique_cats) - 1, 1)
    h = []
    for idx, cat in enumerate(unique_cats):
        mask = im_path_map == cat
        h.append(
            plt.scatter(y[mask, 0],
                        y[mask, 1],
                        color=plt.cm.Spectral(idx * color_step)))
    plt.legend(h, unique_cats)
    plt.axis('tight')
    # Save before show(): with interactive backends the figure may be empty
    # once show() returns.
    plt.savefig('embedding.png')
    plt.show()
    plt.close(fig)

    # Save everything
    if save_npy:
        np.savez(os.path.join(out_dir, 'validation_accuracies'),
                 ckpt_yhat=ckpt_yhat,
                 ckpt_scores=ckpt_scores,
                 ckpt_names=ckpts,
                 combined_files=ckpt_file_array)