示例#1
0
def train_vgg16(train_dir=None, validation_dir=None):
    """Fine-tune a VGG16 model on GEDI tfrecords and periodically checkpoint.

    Args:
        train_dir: Directory holding the training tfrecord and meta files,
            or None to use the paths from the global GEDIconfig.
        validation_dir: Directory holding the validation tfrecord, None to
            use the config globals, or False to train without validation.
    """
    config = GEDIconfig()
    if train_dir is None:  # Use globals
        train_data = os.path.join(config.tfrecord_dir,
                                  config.tf_record_names['train'])
        meta_data = np.load(
            os.path.join(config.tfrecord_dir,
                         '%s_%s' % (config.tvt_flags[0], config.max_file)))
    else:
        # FIX: train_data was never set on this branch, which raised a
        # NameError at the os.path.exists assert below. Mirror the globals
        # branch inside the supplied directory.
        train_data = os.path.join(train_dir,
                                  config.tf_record_names['train'])
        meta_data = np.load(
            os.path.join(train_dir,
                         '%s_%s' % (config.tvt_flags[0], config.max_file)))

    # Prepare image normalization values
    if config.max_gedi is None:
        # Derive the normalization range from the dataset's stored stats.
        max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
        if max_value == 0:
            max_value = None
            print('Derived max value is 0')
        else:
            print('Normalizing with empirical max.')
        if 'min_array' in meta_data.keys():
            min_value = np.min(meta_data['min_array']).astype(np.float32)
            print('Normalizing with empirical min.')
        else:
            min_value = None
            print('Not normalizing with a min.')
    else:
        max_value = config.max_gedi
        min_value = config.min_gedi
        print('max value', max_value)
    ratio = meta_data['ratio']
    if config.encode_time_of_death:
        # Recast the task as time-of-death prediction: labels become the
        # set of death timepoints and the class ratio is rebalanced.
        tod = pd.read_csv(config.encode_time_of_death)
        # NOTE(review): DataFrame.as_matrix is removed in modern pandas;
        # kept here for compatibility with the pandas version this repo
        # pins — migrate to .values when upgrading.
        tod_data = tod['dead_tp'].as_matrix()
        mask = np.isnan(tod_data).astype(int) + (
            tod['plate_well_neuron'] == 'empty').as_matrix().astype(int)
        tod_data = tod_data[mask == 0]
        tod_data = tod_data[tod_data <
                            10]  # throw away values have a high number
        config.output_shape = len(np.unique(tod_data))
        ratio = class_weight.compute_class_weight('balanced',
                                                  np.sort(np.unique(tod_data)),
                                                  tod_data)
        flip_ratio = False
    else:
        flip_ratio = True
    print('Ratio is: %s' % ratio)

    if validation_dir is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir,
                                       config.tf_record_names['val'])
    elif validation_dir is False:
        # FIX: later code tests `validation_data is not False`, but this
        # branch previously left validation_data undefined (NameError).
        validation_data = False  # Do not use validation data during training
    else:
        # FIX: an explicit validation directory previously fell through
        # with validation_data undefined; mirror the globals branch.
        validation_data = os.path.join(validation_dir,
                                       config.tf_record_names['val'])

    # Make output directories if they do not exist
    dt_stamp = re.split(
        r'\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(config.train_checkpoint,
                                           dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    dir_list = [
        config.train_checkpoint, config.train_summaries, config.results,
        out_dir
    ]
    [make_dir(d) for d in dir_list]
    # im_shape = get_image_size(config)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Prepare data on CPU
    assert os.path.exists(train_data)
    if validation_data is not False:
        # FIX: only check the validation path when validation is enabled.
        assert os.path.exists(validation_data)
    assert os.path.exists(config.vgg16_weight_path)
    with tf.device('/cpu:0'):
        train_images, train_labels = inputs(train_data,
                                            config.train_batch,
                                            im_shape,
                                            config.model_image_size[:2],
                                            max_value=max_value,
                                            min_value=min_value,
                                            train=config.data_augmentations,
                                            num_epochs=config.epochs,
                                            normalize=config.normalize)
        tf.summary.image('train images', train_images)
        if validation_data is not False:
            # FIX: skip the validation input pipeline entirely when disabled.
            val_images, val_labels = inputs(validation_data,
                                            config.validation_batch,
                                            im_shape,
                                            config.model_image_size[:2],
                                            max_value=max_value,
                                            min_value=min_value,
                                            num_epochs=config.epochs,
                                            normalize=config.normalize)
            tf.summary.image('validation images', val_images)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn') as scope:
            # Output width depends on the classification mode: ordinal uses
            # 2 logits per class, regression a single scalar, otherwise a
            # plain softmax over output_shape classes.
            if config.ordinal_classification == 'ordinal':  # config.output_shape > 2:  # Hardcoded fix for timecourse pred
                vgg_output = config.output_shape * 2
            elif config.ordinal_classification == 'regression':
                vgg_output = 1
            elif config.ordinal_classification is None:
                vgg_output = config.output_shape
            vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                              fine_tune_layers=config.fine_tune_layers)
            train_mode = tf.get_variable(name='training', initializer=True)
            vgg.build(train_images,
                      output_shape=vgg_output,
                      train_mode=train_mode,
                      batchnorm=config.batchnorm_layers)

            # Prepare the cost function
            if config.ordinal_classification == 'ordinal':
                # Encode y w/ k-hot and yhat w/ sigmoid ce. units capture dist.
                enc = tf.concat([
                    tf.reshape(tf.range(0, config.output_shape), [1, -1])
                    for x in range(config.train_batch)
                ],
                                axis=0)
                enc_train_labels = tf.cast(
                    tf.greater_equal(enc, tf.expand_dims(train_labels,
                                                         axis=1)), tf.float32)
                split_labs = tf.split(enc_train_labels,
                                      config.output_shape,
                                      axis=1)
                res_output = tf.reshape(
                    vgg.fc8, [config.train_batch, 2, config.output_shape])
                split_logs = tf.split(res_output, config.output_shape, axis=2)
                if config.balance_cost:
                    # Weight each per-class binary loss by its class ratio.
                    cost = tf.add_n([
                        tf.nn.softmax_cross_entropy_with_logits(
                            labels=tf.one_hot(tf.cast(tf.squeeze(s), tf.int32),
                                              2),
                            logits=tf.squeeze(l)) * r
                        for s, l, r in zip(split_labs, split_logs, ratio)
                    ])
                else:
                    cost = tf.add_n([
                        tf.nn.softmax_cross_entropy_with_logits(
                            labels=tf.one_hot(tf.cast(tf.squeeze(s), tf.int32),
                                              2),
                            logits=tf.squeeze(l))
                        for s, l in zip(split_labs, split_logs)
                    ])
                cost = tf.reduce_mean(cost)
            elif config.ordinal_classification == 'regression':
                if config.balance_cost:
                    weight_vec = tf.gather(train_labels, ratio)
                    cost = tf.reduce_mean(
                        tf.pow(
                            (vgg.fc8) - tf.cast(train_labels, tf.float32), 2) *
                        weight_vec)
                else:
                    cost = tf.nn.l2_loss((vgg.fc8) -
                                         tf.cast(train_labels, tf.float32))
            else:
                if config.balance_cost:
                    cost = softmax_cost(vgg.fc8,
                                        train_labels,
                                        ratio=ratio,
                                        flip_ratio=flip_ratio)
                else:
                    cost = softmax_cost(vgg.fc8, train_labels)
            tf.summary.scalar("cost", cost)

            # Finetune the learning rates
            if config.wd_layers is not None:
                # L2 weight decay on the selected layers (biases excluded).
                _, l2_wd_layers = fine_tune_prepare_layers(
                    tf.trainable_variables(), config.wd_layers)
                l2_wd_layers = [
                    x for x in l2_wd_layers if 'biases' not in x.name
                ]
                cost += (config.wd_penalty *
                         tf.add_n([tf.nn.l2_loss(x) for x in l2_wd_layers]))

            # for all variables in trainable variables
            # print name if there's duplicates you f****d up
            other_opt_vars, ft_opt_vars = fine_tune_prepare_layers(
                tf.trainable_variables(), config.fine_tune_layers)
            if config.optimizer == 'adam':
                train_op = ft_non_optimized(cost, other_opt_vars, ft_opt_vars,
                                            tf.train.AdamOptimizer,
                                            config.hold_lr, config.new_lr)
            elif config.optimizer == 'sgd':
                train_op = ft_non_optimized(cost, other_opt_vars, ft_opt_vars,
                                            tf.train.GradientDescentOptimizer,
                                            config.hold_lr, config.new_lr)

            # Training accuracy, decoded per classification mode.
            if config.ordinal_classification == 'ordinal':
                arg_guesses = tf.cast(
                    tf.reduce_sum(tf.squeeze(tf.argmax(res_output, axis=1)),
                                  reduction_indices=[1]), tf.int32)
                train_accuracy = tf.reduce_mean(
                    tf.cast(tf.equal(arg_guesses, train_labels), tf.float32))
            elif config.ordinal_classification == 'regression':
                train_accuracy = tf.reduce_mean(
                    tf.cast(
                        tf.equal(tf.cast(tf.round(vgg.prob), tf.int32),
                                 train_labels), tf.float32))
            else:
                train_accuracy = class_accuracy(
                    vgg.prob, train_labels)  # training accuracy
            tf.summary.scalar("training accuracy", train_accuracy)

            # Setup validation op
            if validation_data is not False:
                scope.reuse_variables()
                # Validation graph is the same as training except no batchnorm
                val_vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                                      fine_tune_layers=config.fine_tune_layers)
                val_vgg.build(val_images, output_shape=vgg_output)
                # Calculate validation accuracy
                if config.ordinal_classification == 'ordinal':
                    val_res_output = tf.reshape(
                        val_vgg.fc8,
                        [config.validation_batch, 2, config.output_shape])
                    val_arg_guesses = tf.cast(
                        tf.reduce_sum(tf.squeeze(
                            tf.argmax(val_res_output, axis=1)),
                                      reduction_indices=[1]), tf.int32)
                    val_accuracy = tf.reduce_mean(
                        tf.cast(tf.equal(val_arg_guesses, val_labels),
                                tf.float32))
                elif config.ordinal_classification == 'regression':
                    val_accuracy = tf.reduce_mean(
                        tf.cast(
                            tf.equal(tf.cast(tf.round(val_vgg.prob), tf.int32),
                                     val_labels), tf.float32))
                else:
                    val_accuracy = class_accuracy(val_vgg.prob, val_labels)
                tf.summary.scalar("validation accuracy", val_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=config.keep_checkpoints)
    summary_op = tf.summary.merge_all()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Need to initialize both of these if supplying num_epochs to inputs
    sess.run(
        tf.group(tf.global_variables_initializer(),
                 tf.local_variables_initializer()))
    summary_dir = os.path.join(config.train_summaries,
                               config.which_dataset + '_' + dt_stamp)
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Start training loop
    np.save(out_dir + 'meta_info', config)
    step, losses = 0, []  # val_max = 0
    # FIX: val_acc was unbound on the `val_acc -= 1` path below when
    # validation is disabled, raising a NameError on the first summary step.
    val_acc = -1
    try:
        # print response
        while not coord.should_stop():
            start_time = time.time()
            _, loss_value, train_acc = sess.run(
                [train_op, cost, train_accuracy])
            losses.append(loss_value)
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % config.validation_steps == 0:
                if validation_data is not False:
                    _, val_acc = sess.run([train_op, val_accuracy])
                else:
                    val_acc -= 1  # Store every checkpoint

                # Summaries
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Training status and validation accuracy
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Validation accuracy = %s | logdir = %s')
                print(format_str %
                      (datetime.now(), step,
                       loss_value, config.train_batch / duration,
                       float(duration), train_acc, val_acc, summary_dir))

                # Save the model checkpoint if it's the best yet
                if 1:  # val_acc >= val_max:  # currently saves every time
                    saver.save(sess,
                               os.path.join(config.train_checkpoint,
                                            'model_' + str(step) + '.ckpt'),
                               global_step=step)
                    # Store the new max validation accuracy
                    # val_max = val_acc

            else:
                # Training status
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s')
                print(format_str %
                      (datetime.now(), step, loss_value, config.train_batch /
                       duration, float(duration), train_acc))
            # End iteration
            step += 1

    except tf.errors.OutOfRangeError:
        # Queue exhausted: the requested number of epochs has completed.
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
    finally:
        coord.request_stop()
        np.save(os.path.join(config.tfrecord_dir, 'training_loss'), losses)
    coord.join(threads)
    sess.close()
示例#2
0
def train_vgg16(train_dir=None, validation_dir=None):
    """Fine-tune VGG16 with an LSTM over timepoints on GEDI tfrecords.

    NOTE(review): this variant was marked "NOT CORRECT" upstream. The
    structural errors (IndentationError, tf.while_loop arity mismatch,
    undefined loop inputs, leftover debugger breakpoint) are repaired
    below, but the recurrent wiring still needs validation before use.
    """
    config = GEDIconfig()
    if train_dir is None:  # Use globals
        train_data = os.path.join(config.tfrecord_dir, 'train.tfrecords')
        meta_data = np.load(
            os.path.join(
                config.tfrecord_dir, config.tvt_flags[0] + '_' +
                config.max_file))
    else:
        # FIX: train_data was undefined on this branch (NameError at the
        # inputs() call below).
        train_data = os.path.join(train_dir, 'train.tfrecords')
        meta_data = np.load(
            os.path.join(
                train_dir, config.tvt_flags[0] + '_' + config.max_file))

    # Prepare image normalization values
    max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
    if max_value == 0:
        max_value = None
        print('Derived max value is 0')
    else:
        print('Normalizing with empirical max.')
    if 'min_array' in meta_data.keys():
        min_value = np.min(meta_data['min_array']).astype(np.float32)
        print('Normalizing with empirical min.')
    else:
        min_value = None
        print('Not normalizing with a min.')
    ratio = meta_data['ratio']
    print('Ratio is: %s' % ratio)

    if validation_dir is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir, 'val.tfrecords')
    elif validation_dir is False:
        # FIX: set the sentinel tested later (`validation_data is not
        # False`) instead of leaving the name undefined.
        validation_data = False  # Do not use validation data during training

    # Make output directories if they do not exist
    dt_stamp = re.split(
        r'\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(
        config.train_checkpoint, dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    dir_list = [
        config.train_checkpoint, config.train_summaries,
        config.results, out_dir]
    [make_dir(d) for d in dir_list]
    # im_shape = get_image_size(config)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        train_images, train_labels = inputs(
            train_data,
            config.train_batch,
            im_shape,
            config.model_image_size[:2],
            max_value=max_value,
            min_value=min_value,
            train=config.data_augmentations,
            num_epochs=config.epochs,
            normalize=config.normalize)
        tf.summary.image('train images', train_images)
        if validation_data is not False:
            # FIX: skip the validation pipeline entirely when disabled.
            val_images, val_labels = inputs(
                validation_data,
                config.validation_batch,
                im_shape,
                config.model_image_size[:2],
                max_value=max_value,
                min_value=min_value,
                num_epochs=config.epochs,
                normalize=config.normalize)
            tf.summary.image('validation images', val_images)

    # Prepare model on GPU
    num_timepoints = len(config.channel)
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn') as scope:

            # Prepare the loss function
            if config.balance_cost:
                cost_fun = lambda yhat, y: softmax_cost(yhat, y, ratio=ratio)
            else:
                cost_fun = lambda yhat, y: softmax_cost(yhat, y)

            def cond(i, *args):
                # FIX: cond must accept the same loop_vars as body; only
                # the iteration counter matters here (num_timepoints is
                # closed over rather than threaded through the loop).
                return tf.less(i, num_timepoints)

            def body(
                    i,
                    images,
                    label,
                    cell,
                    state,
                    output,
                    loss,
                    vgg,
                    train_mode,
                    output_shape,
                    batchnorm_layers,
                    cost_fun,
                    score_layer='fc7'):
                # Run VGG on the current timepoint, feed its features to
                # the LSTM cell, and record output/loss for this step.
                vgg.build(
                    images[i],
                    output_shape=config.output_shape,
                    train_mode=train_mode,
                    batchnorm=config.batchnorm_layers)
                # FIX: the four lines below were over-indented, which made
                # this function (and the whole module) fail to parse.
                it_output, state = cell(vgg[score_layer], state)
                output = output.write(i, it_output)
                it_loss = cost_fun(output, label)
                loss = loss.write(i, it_loss)
                # FIX: return arity must match loop_vars (cost_fun was
                # missing from the original return tuple).
                return (i + 1, images, label, cell, state, output, loss, vgg,
                        train_mode, output_shape, batchnorm_layers, cost_fun)

            # Prepare LSTM loop
            train_mode = tf.get_variable(name='training', initializer=True)
            cell = lstm_layers(layer_units=config.lstm_units)
            # FIX: the loop inputs below were previously undefined names.
            # NOTE(review): these reconstructions are plausible but should
            # be confirmed against the data pipeline (e.g. whether
            # train_images is indexable per timepoint).
            vgg = vgg16.Vgg16(
                vgg16_npy_path=config.vgg16_weight_path,
                fine_tune_layers=config.fine_tune_layers)
            state = cell.zero_state(config.train_batch, tf.float32)
            images = train_images
            label = train_labels
            output_shape = config.output_shape
            batchnorm_layers = config.batchnorm_layers
            output = tf.TensorArray(tf.float32, num_timepoints)  # output array for LSTM loop -- one slot per timepoint
            loss = tf.TensorArray(tf.float32, num_timepoints)  # loss for LSTM loop -- one slot per timepoint
            i = tf.constant(0)  # iterator for LSTM loop
            loop_vars = [
                i, images, label, cell, state, output, loss, vgg, train_mode,
                output_shape, batchnorm_layers, cost_fun]

            # Run the lstm
            processed_list = tf.while_loop(
                cond=cond,
                body=body,
                loop_vars=loop_vars,
                back_prop=True,
                swap_memory=False)
            output_cell = processed_list[3]
            output_state = processed_list[4]
            output_activity = processed_list[5]
            cost = processed_list[6]

            # Optimize
            # FIX: cost is a TensorArray; stack it into a tensor before
            # reducing, and feed the reduced scalar to the optimizer
            # (train_op previously received the TensorArray itself).
            combined_cost = tf.reduce_sum(cost.stack())
            tf.summary.scalar('cost', combined_cost)
            # FIX: removed a leftover `import ipdb; ipdb.set_trace()`
            # debugging breakpoint.
            # TODO(review): verify the LSTM weights appear in
            # tf.trainable_variables() and receive gradients.
            other_opt_vars, ft_opt_vars = fine_tune_prepare_layers(
                tf.trainable_variables(), config.fine_tune_layers)
            if config.optimizer == 'adam':
                train_op = ft_non_optimized(
                    combined_cost, other_opt_vars, ft_opt_vars,
                    tf.train.AdamOptimizer, config.hold_lr, config.new_lr)
            elif config.optimizer == 'sgd':
                train_op = ft_non_optimized(
                    combined_cost, other_opt_vars, ft_opt_vars,
                    tf.train.GradientDescentOptimizer,
                    config.hold_lr, config.new_lr)

            train_accuracy = class_accuracy(
                vgg.prob, train_labels)  # training accuracy
            tf.summary.scalar("training accuracy", train_accuracy)

            # Setup validation op
            if validation_data is not False:
                scope.reuse_variables()
                # Validation graph is the same as training except no batchnorm
                val_vgg = vgg16.Vgg16(
                    vgg16_npy_path=config.vgg16_weight_path,
                    fine_tune_layers=config.fine_tune_layers)
                val_vgg.build(val_images, output_shape=config.output_shape)
                # Calculate validation accuracy
                val_accuracy = class_accuracy(val_vgg.prob, val_labels)
                tf.summary.scalar("validation accuracy", val_accuracy)
示例#3
0
def test_vgg16(
        which_dataset,
        validation_data=None,
        model_dir=None):  # Fine tuning defaults to wipe out the final two FCs
    """Score every saved checkpoint on the validation tfrecords.

    Runs the full validation set through a VGG16 graph once per
    checkpoint, then plots and saves the per-checkpoint accuracies.

    Args:
        which_dataset: Dataset key used to build the GEDIconfig.
        validation_data: Path to a val tfrecord; None uses the config default.
        model_dir: Unused here; kept for interface compatibility.
    """
    config = GEDIconfig(which_dataset)
    # FIX: compare to None with `is`, not `==` (PEP 8; avoids __eq__ traps).
    if validation_data is None:  # Use globals
        validation_data = config.tfrecord_dir + 'val.tfrecords'

    # Make output directories if they do not exist
    out_dir = config.results + config.which_dataset + '/'
    dir_list = [config.results, out_dir]
    [make_dir(d) for d in dir_list]
    im_shape = get_image_size(config)

    # Find model checkpoints
    ckpts, ckpt_names = find_ckpts(config)

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        # num_epochs=1 so the queue raises OutOfRangeError after one pass.
        val_images, val_labels = inputs(validation_data,
                                        config.validation_batch,
                                        im_shape,
                                        config.model_image_size[:2],
                                        num_epochs=1)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                          fine_tune_layers=config.fine_tune_layers)
        validation_mode = tf.Variable(False, name='training')
        vgg.build(val_images,
                  output_shape=config.output_shape,
                  train_mode=validation_mode)

        # Setup validation op
        eval_accuracy = class_accuracy(vgg.prob,
                                       val_labels)  # training accuracy now...

    # Set up saver
    saver = tf.train.Saver(tf.all_variables())

    # Loop through each checkpoint, loading the model weights, then
    # testing the entire validation set
    ckpt_accs = []
    for idx in tqdm(range(len(ckpts))):
        accs = []
        try:
            # Initialize the graph
            sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
            sess.run(
                tf.group(tf.initialize_all_variables(),
                         tf.initialize_local_variables())
            )  # need to initialize both if supplying num_epochs to inputs

            # Set up exemplar threading
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver.restore(sess, ckpts[idx])
            start_time = time.time()
            while not coord.should_stop():
                accs = np.append(accs, sess.run([eval_accuracy]))

        except tf.errors.OutOfRangeError:
            # Input queue exhausted: this checkpoint is fully scored.
            ckpt_accs.append(accs)
            # FIX: the arguments were passed comma-separated to print
            # (which printed a tuple); use %-interpolation as intended.
            print('Batch %d took %.1f seconds' %
                  (idx, time.time() - start_time))
        finally:
            coord.request_stop()
        coord.join(threads)
        sess.close()

    # Plot everything
    plot_accuracies(ckpt_accs, ckpt_names,
                    out_dir + 'validation_accuracies.png')
    np.savez(out_dir + 'validation_accuracies',
             ckpt_accs=ckpt_accs,
             ckpt_names=ckpt_names)
示例#4
0
def train_vgg16(which_dataset, train_data=None, validation_data=None):  # Fine tuning defaults to wipe out the final two FCs
    """Fine-tune VGG16 on a dataset's tfrecords with periodic checkpoints.

    Stops early once the average training accuracy over the last 500
    steps exceeds 0.95, otherwise trains for config.epochs epochs.

    Args:
        which_dataset: Dataset key used to build the GEDIconfig.
        train_data: Path to train tfrecords; None uses the config default.
        validation_data: Path to val tfrecords; None uses the config default.
    """
    config = GEDIconfig(which_dataset)
    # FIX: compare to None with `is`, not `==` (PEP 8; avoids __eq__ traps).
    if train_data is None:  # Use globals
        train_data = config.tfrecord_dir + 'train.tfrecords'
    if validation_data is None:  # Use globals
        validation_data = config.tfrecord_dir + 'val.tfrecords'

    # Make output directories if they do not exist
    dt_stamp = re.split(' ', str(datetime.now()))[0].replace('-', '_')
    config.train_checkpoint = os.path.join(
        config.train_checkpoint,
        config.which_dataset + '_' + dt_stamp)  # timestamp this run
    out_dir = config.results + config.which_dataset + '/'
    dir_list = [
        config.train_checkpoint, config.train_summaries, config.results,
        out_dir]
    [make_dir(d) for d in dir_list]
    im_shape = get_image_size(config)

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        train_images, train_labels = inputs(
            train_data, config.train_batch, im_shape,
            config.model_image_size[:2], train=config.data_augmentations,
            num_epochs=config.epochs)
        tf.image_summary('train images', train_images)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                          fine_tune_layers=config.fine_tune_layers)
        train_mode = tf.Variable(True, name='training')
        vgg.build(train_images, output_shape=config.output_shape,
                  train_mode=train_mode)

        # Prepare the cost function
        cost = softmax_cost(vgg.fc8, train_labels)
        tf.scalar_summary("cost", cost)

        # Finetune the learning rates
        # for all variables in trainable variables, print name if there's
        # duplicates you f****d up
        other_opt_vars, ft_opt_vars = fine_tune_prepare_layers(
            tf.trainable_variables(), config.fine_tune_layers)
        train_op = ft_non_optimized(
            cost, other_opt_vars, ft_opt_vars, tf.train.AdamOptimizer,
            config.hold_lr, config.new_lr)  # actually is faster :)

        # Setup validation op
        eval_accuracy = class_accuracy(
            vgg.prob, train_labels)  # training accuracy now...
        tf.scalar_summary("train accuracy", eval_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(tf.all_variables(), max_to_keep=100)
    summary_op = tf.merge_all_summaries()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # need to initialize both if supplying num_epochs to inputs
    sess.run(tf.group(tf.initialize_all_variables(),
                      tf.initialize_local_variables()))
    summary_writer = tf.train.SummaryWriter(
        os.path.join(config.train_summaries,
                     config.which_dataset + '_' + dt_stamp), sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Start training loop
    try:
        step = 0
        losses = []
        accs = []
        while not coord.should_stop():
            start_time = time.time()
            _, loss_value, acc = sess.run([train_op, cost, eval_accuracy])
            losses.append(loss_value)
            accs.append(acc)
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 100 == 0:
                if step % 500 == 0:
                    # Summaries
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)
                else:
                    # Training status
                    format_str = ('%s: step %d, loss = %.2f '
                                  '(%.1f examples/sec; %.3f sec/batch)')
                    print(format_str % (datetime.now(), step, loss_value,
                                        config.train_batch / duration,
                                        float(duration)))

            # Save the model checkpoint periodically.
            if step % 500 == 0 and step > 0:
                saver.save(sess,
                           os.path.join(config.train_checkpoint,
                                        'model_' + str(step) + '.ckpt'),
                           global_step=step)
                # FIX: accs[-500:-1] silently dropped the most recent
                # accuracy from the early-stopping window; average the
                # full last-500 slice.
                if np.average(accs[-500:]) > 0.95:
                    coord.request_stop()
            step += 1

    except tf.errors.OutOfRangeError:
        # Queue exhausted: save a final checkpoint before shutting down.
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
        saver.save(sess,
                   os.path.join(config.train_checkpoint,
                                'model_' + str(step) + '.ckpt'),
                   global_step=step)
    finally:
        coord.request_stop()
    coord.join(threads)
    sess.close()
    np.save(out_dir + 'training_loss', losses)
示例#5
0
def train_att(which_dataset, train_data=None, validation_data=None):
    """Train the attention model on a GEDI tfrecord dataset.

    Fine tuning defaults to wipe out the final two FCs.

    Parameters
    ----------
    which_dataset : str
        Dataset key passed through to GEDIconfig.
    train_data : str or None
        Path to the training tfrecords; None falls back to the config dir.
    validation_data : str or None
        Path to the validation tfrecords; None falls back to the config dir.
        NOTE(review): only read to build the default path — the loop below
        never evaluates on it.
    """
    config = GEDIconfig(which_dataset)
    if train_data is None:  # Use globals
        train_data = config.tfrecord_dir + 'train.tfrecords'
    if validation_data is None:  # Use globals
        validation_data = config.tfrecord_dir + 'val.tfrecords'

    # Make output directories if they do not exist
    dt_stamp = re.split(' ', str(datetime.now()))[0].replace('-', '_')
    config.train_checkpoint = os.path.join(
        config.train_checkpoint,
        config.which_dataset + '_' + dt_stamp)  # timestamp this run
    out_dir = config.results + config.which_dataset + '/'
    dir_list = [
        config.train_checkpoint, config.train_summaries, config.results,
        out_dir]
    [make_dir(d) for d in dir_list]
    im_shape = get_image_size(config)

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        train_images, train_labels = inputs(
            train_data,
            config.train_batch,
            im_shape,
            config.model_image_size[:2],
            train=config.data_augmentations,
            num_epochs=config.epochs)
        tf.image_summary('train images', train_images)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        att_model = att.Attention()
        train_mode = tf.Variable(True, name='training')
        att_model.build(
            train_images,
            enc_size=config.enc_size,
            read_n=config.read_n,
            T=config.T,
            output_shape=config.output_shape,
            train_mode=train_mode)

        # Prepare the cost function
        cost = softmax_cost(att_model.fc2, train_labels)
        tf.scalar_summary("cost", cost)
        print([x.name for x in tf.trainable_variables()])

        # Finetune the learning rates
        optimizer = tf.train.AdamOptimizer(
            learning_rate=config.new_lr, beta1=0.5)
        grads = optimizer.compute_gradients(cost)
        # Clip gradients by norm to stabilize training. (This loop was
        # tab-indented in a space-indented file; normalized to spaces.)
        for i, (g, v) in enumerate(grads):
            if g is not None:
                grads[i] = (tf.clip_by_norm(g, 5), v)
        train_op = optimizer.apply_gradients(grads)

        # Setup validation op (training accuracy for now)
        eval_accuracy = class_accuracy(att_model.prob, train_labels)
        tf.scalar_summary("train accuracy", eval_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(tf.all_variables(), max_to_keep=100)
    summary_op = tf.merge_all_summaries()

    # Initialize the graph; both global and local variables must be
    # initialized when num_epochs is supplied to the input queue.
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    sess.run(tf.group(
        tf.initialize_all_variables(), tf.initialize_local_variables()))
    summary_writer = tf.train.SummaryWriter(
        os.path.join(
            config.train_summaries, config.which_dataset + '_' + dt_stamp),
        sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Start training loop
    try:
        step = 0
        losses = []
        accs = []
        while not coord.should_stop():
            start_time = time.time()
            _, loss_value, acc = sess.run([train_op, cost, eval_accuracy])
            losses.append(loss_value)
            accs.append(acc)
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 100 == 0:
                if step % 2000 == 0:
                    # Summaries
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)
                else:
                    # Training status
                    format_str = ('%s: step %d, loss = %.2f (%.1f '
                                  'examples/sec; %.3f sec/batch)')
                    print(format_str % (
                        datetime.now(), step, loss_value,
                        config.train_batch / duration, float(duration)))

                # Early-stop once recent accuracy or loss looks converged.
                # (At step 0 the slices are empty and np.average yields NaN,
                # which compares False — same as the original behavior.)
                if np.average(accs[-100:-1]) > 0.9 or \
                        np.average(losses[-100:-1]) < 0.1:
                    saver.save(
                        sess,
                        os.path.join(
                            config.train_checkpoint,
                            'model_' + str(step) + '.ckpt'),
                        global_step=step)
                    coord.request_stop()

            # Save the model checkpoint periodically.
            if step % 1000 == 0 and step > 0:
                saver.save(
                    sess,
                    os.path.join(
                        config.train_checkpoint,
                        'model_' + str(step) + '.ckpt'),
                    global_step=step)
            step += 1

    except tf.errors.OutOfRangeError:
        # Input queue exhausted: training ran for the configured epochs.
        print('Done training for %d epochs, %d steps.' % (
            config.epochs, step))
        saver.save(
            sess,
            os.path.join(
                config.train_checkpoint, 'model_' + str(step) + '.ckpt'),
            global_step=step)
    finally:
        coord.request_stop()
    coord.join(threads)
    sess.close()
    np.save(out_dir + 'training_loss', losses)
示例#6
0
def test_vgg16(validation_data, model_dir, which_set, selected_ckpts):
    """Evaluate a trained vgg16 checkpoint and save gradient heatmaps.

    Runs the model one image at a time over a validation tfrecord, computes
    an input-gradient saliency map per image, buckets results into true/false
    positives/negatives, and writes blurred heatmap overlays to disk.

    Parameters
    ----------
    validation_data : str or None
        Path to a tfrecord file; None falls back to the config default split.
    model_dir : str
        Model directory; its basename stamps the output folders.
    which_set : str
        Key into config.tf_record_names selecting the split to evaluate.
    selected_ckpts : str
        Checkpoint path to restore and evaluate.
    """
    config = GEDIconfig()
    blur_kernel = config.hm_blur
    if validation_data is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir,
                                       config.tf_record_names[which_set])
        meta_data = np.load(
            os.path.join(config.tfrecord_dir, 'val_%s' % config.max_file))
    else:
        meta_data = np.load('%s_maximum_value.npz' %
                            validation_data.split('.tfrecords')[0])
    label_list = os.path.join(
        config.processed_image_patch_dir,
        'list_of_' + '_'.join(x
                              for x in config.image_prefixes) + '_labels.txt')
    with open(label_list) as f:
        file_pointers = [l.rstrip('\n') for l in f.readlines()]

    # Prepare image normalization values; fall back to config constants when
    # the meta file lacks the arrays. (Bare `except:` narrowed to Exception
    # so KeyboardInterrupt/SystemExit are not swallowed.)
    try:
        max_value = np.max(meta_data['max_array']).astype(np.float32)
    except Exception:
        max_value = np.asarray([config.max_gedi])
    try:
        # BUGFIX: take the minimum of min_array (was np.max, a copy/paste of
        # the branch above); training normalizes with the empirical min.
        min_value = np.min(meta_data['min_array']).astype(np.float32)
    except Exception:
        min_value = np.asarray([config.min_gedi])

    # Find model checkpoints
    ds_dt_stamp = re.split('/', model_dir)[-1]
    out_dir = os.path.join(config.results, ds_dt_stamp + '/')
    try:
        config = np.load(os.path.join(out_dir, 'meta_info.npy')).item()
        # Make sure this is always at 1
        config.validation_batch = 1
        print('-' * 60)
        print('Loading config meta data for:%s' % out_dir)
        print('-' * 60)
    except Exception:
        print('-' * 60)
        print('Using config from gedi_config.py for model:%s' % out_dir)
        print('-' * 60)

    # Make output directories if they do not exist
    im_shape = config.gedi_image_size

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        val_images, val_labels = inputs(validation_data,
                                        1,
                                        im_shape,
                                        config.model_image_size[:2],
                                        max_value=max_value,
                                        min_value=min_value,
                                        num_epochs=1,
                                        normalize=config.normalize)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn'):
            vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                              fine_tune_layers=config.fine_tune_layers)
            vgg.build(val_images, output_shape=config.output_shape)

        # Setup validation ops: predictions, targets, and a saliency map --
        # the gradient of the target-class logit w.r.t. the input image.
        preds = tf.argmax(vgg.prob, 1)
        targets = tf.cast(val_labels, dtype=tf.int64)
        grad_labels = tf.one_hot(val_labels,
                                 config.output_shape,
                                 dtype=tf.float32)
        heatmap_op = tf.gradients(vgg.fc8 * grad_labels, val_images)[0]

    # Set up saver
    saver = tf.train.Saver(tf.global_variables())
    ckpts = [selected_ckpts]

    # Loop through each checkpoint then test the entire validation set
    print('-' * 60)
    print('Beginning evaluation on ckpt: %s' % ckpts)
    print('-' * 60)
    yhat, y, tn_hms, tp_hms, fn_hms, fp_hms = [], [], [], [], [], []
    tn_ims, tp_ims, fn_ims, fp_ims = [], [], [], []
    for idx, c in tqdm(enumerate(ckpts), desc='Running checkpoints'):
        try:
            # Initialize the graph
            sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer()))

            # Set up exemplar threading
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver.restore(sess, c)
            start_time = time.time()
            while not coord.should_stop():
                tyh, ty, thm, tim = sess.run(
                    [preds, targets, heatmap_op, val_images])
                # Batch size is 1, so unwrap the scalars.
                tyh = tyh[0]
                ty = ty[0]
                tim = (tim / tim.max()).squeeze()
                yhat += [tyh]
                y += [ty]
                if tyh == ty and not tyh:  # True negative
                    tn_hms += [hm_normalize(thm)]
                    tn_ims += [tim]
                elif tyh == ty and tyh:  # True positive
                    tp_hms += [hm_normalize(thm)]
                    tp_ims += [tim]
                elif tyh != ty and not tyh:  # False negative
                    fn_hms += [hm_normalize(thm)]
                    fn_ims += [tim]
                elif tyh != ty and tyh:  # False positive
                    fp_hms += [hm_normalize(thm)]
                    fp_ims += [tim]
        except tf.errors.OutOfRangeError:
            # Queue exhausted: one full pass over the validation set.
            print('Batch %d took %.1f seconds' % (idx,
                                                  time.time() - start_time))
        finally:
            coord.request_stop()
        coord.join(threads)
        sess.close()

    # Plot images -- add to a dict and incorporate file_pointers
    dir_pointer = os.path.join(config.heatmap_source_images, ds_dt_stamp)
    stem_dirs = ['tn', 'tp', 'fn', 'fp']
    dir_list = [dir_pointer]
    dir_list += [os.path.join(dir_pointer, x) for x in stem_dirs]
    [make_dir(d) for d in dir_list]
    loop_plot(tn_ims,
              tn_hms,
              'True negative',
              os.path.join(dir_pointer, 'tn'),
              blur=blur_kernel)
    loop_plot(tp_ims,
              tp_hms,
              'True positive',
              os.path.join(dir_pointer, 'tp'),
              blur=blur_kernel)
    loop_plot(fn_ims,
              fn_hms,
              'False negative',
              os.path.join(dir_pointer, 'fn'),
              blur=blur_kernel)
    loop_plot(fp_ims,
              fp_hms,
              'False positive',
              os.path.join(dir_pointer, 'fp'),
              blur=blur_kernel)
示例#7
0
def test_vgg16(validation_data, model_dir, which_set, selected_ckpts=-1):
    """Evaluate vgg16 checkpoints on a validation tfrecord.

    For each selected checkpoint, runs the full validation split (batch 1),
    records predictions/targets/scores, then saves an npz of accuracies, a
    prediction CSV, and summary plots.

    Parameters
    ----------
    validation_data : str or None
        Path to a tfrecord; None falls back to the config default split.
    model_dir : str
        Directory containing the checkpoints to evaluate.
    which_set : str
        Key into config.tf_record_names selecting the split.
    selected_ckpts : int or None
        Negative -> evaluate the last |n| checkpoints; non-negative -> the
        first n; None -> evaluate every checkpoint found.
    """
    config = GEDIconfig()
    if validation_data is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir,
                                       config.tf_record_names[which_set])
        meta_data = np.load(
            os.path.join(config.tfrecord_dir, 'val_%s' % config.max_file))
    else:
        meta_data = np.load('%s_maximum_value.npz' %
                            validation_data.split('.tfrecords')[0])
    label_list = os.path.join(
        config.processed_image_patch_dir,
        'list_of_' + '_'.join(x
                              for x in config.image_prefixes) + '_labels.txt')
    with open(label_list) as f:
        file_pointers = [l.rstrip('\n') for l in f.readlines()]

    # Prepare image normalization values; fall back to config constants when
    # the meta file lacks the arrays. (Bare `except:` narrowed to Exception
    # so KeyboardInterrupt/SystemExit are not swallowed.)
    try:
        max_value = np.max(meta_data['max_array']).astype(np.float32)
    except Exception:
        max_value = np.asarray([config.max_gedi])
    try:
        # BUGFIX: take the minimum of min_array (was np.max, a copy/paste of
        # the branch above); training normalizes with the empirical min.
        min_value = np.min(meta_data['min_array']).astype(np.float32)
    except Exception:
        min_value = np.asarray([config.min_gedi])

    # Find model checkpoints
    ckpts, ckpt_names = find_ckpts(config, model_dir)
    ds_dt_stamp = re.split('/', ckpts[0])[-2]
    out_dir = os.path.join(config.results, ds_dt_stamp)
    try:
        config = np.load(os.path.join(out_dir, 'meta_info.npy')).item()
        # Make sure this is always at 1
        config.validation_batch = 1
        print('-' * 60)
        print('Loading config meta data for:%s' % out_dir)
        print('-' * 60)
    except Exception:
        print('-' * 60)
        print('Using config from gedi_config.py for model:%s' % out_dir)
        print('-' * 60)

    # Make output directories if they do not exist
    dir_list = [config.results, out_dir]
    [make_dir(d) for d in dir_list]
    # im_shape = get_image_size(config)
    im_shape = config.gedi_image_size

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        val_images, val_labels = inputs(validation_data,
                                        1,
                                        im_shape,
                                        config.model_image_size[:2],
                                        max_value=max_value,
                                        min_value=min_value,
                                        num_epochs=1,
                                        normalize=config.normalize)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn'):
            vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                              fine_tune_layers=config.fine_tune_layers)
            vgg.build(val_images, output_shape=config.output_shape)

        # Setup validation ops
        scores = vgg.prob
        preds = tf.argmax(vgg.prob, 1)
        targets = tf.cast(val_labels, dtype=tf.int64)

    # Set up saver
    saver = tf.train.Saver(tf.global_variables())

    # Loop through each checkpoint then test the entire validation set
    ckpt_yhat, ckpt_y, ckpt_scores = [], [], []
    print('-' * 60)
    print('Beginning evaluation')
    print('-' * 60)

    if selected_ckpts is not None:
        # Select a subset of checkpoints: negative keeps the tail of the
        # list, non-negative keeps the head.
        if selected_ckpts < 0:
            ckpts = ckpts[selected_ckpts:]
        else:
            ckpts = ckpts[:selected_ckpts]

    for idx, c in tqdm(enumerate(ckpts), desc='Running checkpoints'):
        dec_scores, yhat, y = [], [], []
        try:
            # Initialize the graph
            sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer()))

            # Set up exemplar threading
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver.restore(sess, c)
            start_time = time.time()
            while not coord.should_stop():
                sc, tyh, ty = sess.run([scores, preds, targets])
                dec_scores = np.append(dec_scores, sc)
                yhat = np.append(yhat, tyh)
                y = np.append(y, ty)
        except tf.errors.OutOfRangeError:
            # Queue exhausted: one full pass over the validation set.
            ckpt_yhat.append(yhat)
            ckpt_y.append(y)
            ckpt_scores.append(dec_scores)
            print('Iteration accuracy: %s' % np.mean(yhat == y))
            print('Iteration pvalue: %.5f' %
                  randomization_test(y=y, yhat=yhat))
            print('Batch %d took %.1f seconds' %
                  (idx, time.time() - start_time))
        finally:
            coord.request_stop()
        coord.join(threads)
        sess.close()

    # Save everything
    np.savez(os.path.join(out_dir, 'validation_accuracies'),
             ckpt_yhat=ckpt_yhat,
             ckpt_y=ckpt_y,
             ckpt_scores=ckpt_scores,
             ckpt_names=ckpt_names,
             file_pointers=file_pointers)

    # Also save a csv with item/guess pairs.
    # NOTE(review): dec_scores/yhat here hold only the LAST evaluated
    # checkpoint's results -- confirm that is the intended spreadsheet.
    try:
        trimmed_files = [re.split('/', x)[-1] for x in file_pointers]
        trimmed_files = np.asarray(trimmed_files)
        dec_scores = np.asarray(dec_scores)
        yhat = np.asarray(yhat)
        df = pd.DataFrame(
            np.hstack((trimmed_files.reshape(-1, 1), yhat.reshape(-1, 1),
                       dec_scores.reshape(dec_scores.shape[0] // 2, 2))),
            columns=['files', 'guesses', 'score dead', 'score live'])
        df.to_csv(os.path.join(out_dir, 'prediction_file.csv'))
        print('Saved csv to: %s' % out_dir)
    except Exception:
        print('X' * 60)
        print('Could not save a spreadsheet of file info')
        print('X' * 60)

    # Plot everything
    try:
        plot_accuracies(ckpt_y, ckpt_yhat, config, ckpt_names,
                        os.path.join(out_dir, 'validation_accuracies.png'))
        plot_std(ckpt_y, ckpt_yhat, ckpt_names,
                 os.path.join(out_dir, 'validation_stds.png'))
        plot_cms(ckpt_y, ckpt_yhat, config,
                 os.path.join(out_dir, 'confusion_matrix.png'))
        plot_pr(ckpt_y, ckpt_yhat, ckpt_scores,
                os.path.join(out_dir, 'precision_recall.png'))
        plot_cost(os.path.join(out_dir, 'training_loss.npy'), ckpt_names,
                  os.path.join(out_dir, 'training_costs.png'))
    except Exception:
        print('X' * 60)
        print('Could not locate the loss numpy')
        print('X' * 60)
示例#8
0
def test_vgg16(
        which_dataset,
        validation_data=None,
        model_dir=None):  #Fine tuning defaults to wipe out the final two FCs
    config = GEDIconfig(which_dataset)
    if validation_data == None:  #Use globals
        validation_data = config.tfrecord_dir + 'val.tfrecords'

    #Make output directories if they do not exist
    out_dir = config.results + config.which_dataset + '/'
    dir_list = [config.results, out_dir]
    [make_dir(d) for d in dir_list]
    im_shape = get_image_size(config)

    #Find model checkpoints
    ckpts, ckpt_names = find_ckpts(config)
    print ckpts, ckpt_names
    #Prepare data on CPU
    with tf.device('/cpu:0'):
        val_images, val_labels = inputs(validation_data,
                                        config.validation_batch,
                                        im_shape,
                                        config.model_image_size[:2],
                                        num_epochs=1)

    #Prepare model on GPU
    with tf.device('/gpu:0'):
        att_model = att.Attention()

        validation_mode = tf.Variable(False, name='training')
        att_model.build(val_images,
                        enc_size=config.enc_size,
                        read_n=config.read_n,
                        T=config.T,
                        output_shape=config.output_shape,
                        train_mode=validation_mode)
        image_0 = val_images
        image_1 = att_model.image_show
        image_loc = att_model.location
        # print image_0.get_shape()
        # print image_1[0].get_shape()
        #Setup validation op
        eval_accuracy = class_accuracy(att_model.prob,
                                       val_labels)  #training accuracy now...

    #Set up saver
    saver = tf.train.Saver(tf.all_variables())

    #Loop through each checkpoint, loading the model weights, then testing the entire validation set
    ckpt_accs = []
    max_acc = 0
    max_ind = 0
    max_show_0 = []
    max_show_1 = []
    max_loc = []
    for idx in tqdm(range(len(ckpts))):
        print ckpts[idx]
        accs = []
        show_0 = np.array([])
        show_1 = np.array([])
        show_loc = np.array([])
        try:

            #print type(show_0)
            #Initialize the graph
            sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
            sess.run(
                tf.group(tf.initialize_all_variables(),
                         tf.initialize_local_variables())
            )  #need to initialize both if supplying num_epochs to inputs

            #Set up exemplar threading
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver.restore(sess, ckpts[idx])
            start_time = time.time()
            while not coord.should_stop():
                #print '1'
                #print type(accs)
                acc, aa, bb, cc = sess.run(
                    [eval_accuracy, image_0, image_1, image_loc])
                accs = np.append(accs, acc)
                if accs[-1] > 0.8 and show_0.shape[-1] < 5:
                    #print show_0.shape[-1]
                    #print aa.shape, bb
                    aa = aa
                    bb = bb
                    (x1, x2, x3, x4) = aa.shape
                    (y1, y2, y3, y4) = bb.shape
                    (z1, z2, z3) = cc.shape
                    aa = np.reshape(aa, (x1, x2, x3, x4, 1))
                    bb = np.reshape(bb, (y1, y2, y3, y4, 1))
                    cc = np.reshape(cc, (z1, z2, z3, 1))
                    if show_0.shape[0] <= 2:
                        #print sess.run([image_1])

                        show_0 = aa
                        show_1 = bb
                        show_loc = cc
                    else:
                        #print sess.run([image_0])[0].shape, show_0.shape

                        show_0 = np.concatenate((show_0, aa), 4)
                        show_1 = np.concatenate((show_1, bb), 4)
                        show_loc = np.concatenate((show_loc, cc), 3)

        except tf.errors.OutOfRangeError:
            if np.mean(accs) > max_acc:
                max_acc = np.mean(accs)
                max_ind = idx
                max_show_0 = show_0
                max_show_1 = show_1
                max_loc = show_loc
            ckpt_accs.append(accs)
            print('Batch %d took %.1f seconds', idx, time.time() - start_time)
        finally:
            coord.request_stop()
        coord.join(threads)
        sess.close()

    print ckpt_accs, ckpt_names

    #Plot everything
    plot_accuracies(ckpt_accs, ckpt_names,
                    out_dir + 'validation_accuracies.png')
    np.savez(out_dir + 'validation_accuracies',
             ckpt_accs=ckpt_accs,
             ckpt_names=ckpt_names)
    np.savez(out_dir + 'att_verification_' + which_dataset,
             max_show_0=max_show_0,
             max_show_1=max_show_1,
             max_loc=max_loc)
    for idx in range(len(ckpts)):
        if idx != max_ind:
            os.remove(ckpts[idx] + '.data-00000-of-00001')
            os.remove(ckpts[idx] + '.meta')