Ejemplo n.º 1
0
def train_vgg16(train_dir=None, validation_dir=None):
    """Fine-tune VGG16 on GEDI tfrecords, writing summaries and checkpoints.

    Builds the input queues, the training graph and (optionally) a
    weight-sharing validation graph, then trains until the input queues are
    exhausted (``tf.errors.OutOfRangeError``).

    Args:
        train_dir: Directory holding the training tfrecord and the
            normalization meta file. ``None`` uses ``config.tfrecord_dir``.
        validation_dir: Directory holding the validation tfrecord. ``None``
            uses ``config.tfrecord_dir``; ``False`` disables validation.

    Raises:
        AssertionError: If a required input file is missing or the loss
            becomes NaN.
        NotImplementedError: If ``config.optimizer`` is neither ``'adam'``
            nor ``'sgd'``.

    Side effects:
        Creates the checkpoint/summary/result directories, writes
        checkpoints and summaries, saves the config as ``meta_info`` in the
        run's result directory and the per-step losses as ``training_loss``
        in ``config.tfrecord_dir``.
    """
    config = GEDIconfig()

    # Resolve the training inputs. An explicit train_dir is assumed to hold
    # a tfrecord named like the global one (config.tf_record_names['train'])
    # next to the meta file -- TODO confirm against the data layout.
    if train_dir is None:  # Use globals
        train_dir = config.tfrecord_dir
    train_data = os.path.join(train_dir, config.tf_record_names['train'])
    meta_data = np.load(
        os.path.join(train_dir,
                     '%s_%s' % (config.tvt_flags[0], config.max_file)))

    # Prepare image normalization values
    if config.max_gedi is None:
        max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
        if max_value == 0:
            max_value = None
            print('Derived max value is 0')
        else:
            print('Normalizing with empirical max.')
        if 'min_array' in meta_data.keys():
            min_value = np.min(meta_data['min_array']).astype(np.float32)
            print('Normalizing with empirical min.')
        else:
            min_value = None
            print('Not normalizing with a min.')
    else:
        max_value = config.max_gedi
        min_value = config.min_gedi
        print('max value', max_value)

    # Class weights: either the ratio stored with the data, or balanced
    # weights derived from the time-of-death table.
    ratio = meta_data['ratio']
    if config.encode_time_of_death:
        tod = pd.read_csv(config.encode_time_of_death)
        # .values replaces the removed pandas .as_matrix()
        tod_data = tod['dead_tp'].values
        # Drop rows with a missing time point or an 'empty' well.
        mask = np.isnan(tod_data).astype(int) + (
            tod['plate_well_neuron'] == 'empty').values.astype(int)
        tod_data = tod_data[mask == 0]
        tod_data = tod_data[tod_data <
                            10]  # throw away values have a high number
        config.output_shape = len(np.unique(tod_data))
        ratio = class_weight.compute_class_weight(
            'balanced',
            classes=np.sort(np.unique(tod_data)),
            y=tod_data)
        flip_ratio = False
    else:
        flip_ratio = True
    print('Ratio is: %s' % ratio)

    # Resolve the validation input; False means "train without validation".
    if validation_dir is False:
        validation_data = False
    elif validation_dir is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir,
                                       config.tf_record_names['val'])
    else:
        validation_data = os.path.join(validation_dir,
                                       config.tf_record_names['val'])
    use_validation = validation_data is not False

    # Make output directories if they do not exist
    dt_stamp = re.split(
        r'\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(config.train_checkpoint,
                                           dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    dir_list = [
        config.train_checkpoint, config.train_summaries, config.results,
        out_dir
    ]
    for d in dir_list:
        make_dir(d)
    # im_shape = get_image_size(config)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Prepare data on CPU
    assert os.path.exists(train_data)
    if use_validation:
        assert os.path.exists(validation_data)
    assert os.path.exists(config.vgg16_weight_path)
    with tf.device('/cpu:0'):
        train_images, train_labels = inputs(train_data,
                                            config.train_batch,
                                            im_shape,
                                            config.model_image_size[:2],
                                            max_value=max_value,
                                            min_value=min_value,
                                            train=config.data_augmentations,
                                            num_epochs=config.epochs,
                                            normalize=config.normalize)
        tf.summary.image('train images', train_images)
        if use_validation:
            val_images, val_labels = inputs(validation_data,
                                            config.validation_batch,
                                            im_shape,
                                            config.model_image_size[:2],
                                            max_value=max_value,
                                            min_value=min_value,
                                            num_epochs=config.epochs,
                                            normalize=config.normalize)
            tf.summary.image('validation images', val_images)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn') as scope:
            # Size of the final layer depends on the prediction mode.
            if config.ordinal_classification == 'ordinal':
                # Two logits (a binary decision) per ordinal unit.
                vgg_output = config.output_shape * 2
            elif config.ordinal_classification == 'regression':
                vgg_output = 1
            else:
                # Plain classification; mirrors the `else` of the cost below.
                vgg_output = config.output_shape
            vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                              fine_tune_layers=config.fine_tune_layers)
            train_mode = tf.get_variable(name='training', initializer=True)
            vgg.build(train_images,
                      output_shape=vgg_output,
                      train_mode=train_mode,
                      batchnorm=config.batchnorm_layers)

            # Prepare the cost function
            if config.ordinal_classification == 'ordinal':
                # Encode the label as a binary vector: unit j is 1 when
                # j >= label. Each unit is then trained as its own 2-way
                # softmax.
                enc = tf.concat([
                    tf.reshape(tf.range(0, config.output_shape), [1, -1])
                    for x in range(config.train_batch)
                ],
                                axis=0)
                enc_train_labels = tf.cast(
                    tf.greater_equal(enc, tf.expand_dims(train_labels,
                                                         axis=1)), tf.float32)
                split_labs = tf.split(enc_train_labels,
                                      config.output_shape,
                                      axis=1)
                res_output = tf.reshape(
                    vgg.fc8, [config.train_batch, 2, config.output_shape])
                split_logs = tf.split(res_output, config.output_shape, axis=2)
                if config.balance_cost:
                    # Weight each unit's loss by its entry in `ratio`.
                    cost = tf.add_n([
                        tf.nn.softmax_cross_entropy_with_logits(
                            labels=tf.one_hot(tf.cast(tf.squeeze(s), tf.int32),
                                              2),
                            logits=tf.squeeze(l)) * r
                        for s, l, r in zip(split_labs, split_logs, ratio)
                    ])
                else:
                    cost = tf.add_n([
                        tf.nn.softmax_cross_entropy_with_logits(
                            labels=tf.one_hot(tf.cast(tf.squeeze(s), tf.int32),
                                              2),
                            logits=tf.squeeze(l))
                        for s, l in zip(split_labs, split_logs)
                    ])
                cost = tf.reduce_mean(cost)
            elif config.ordinal_classification == 'regression':
                # NOTE(review): assumes vgg.fc8 and train_labels line up
                # elementwise (no unintended broadcasting) -- confirm shapes.
                if config.balance_cost:
                    # Per-example weight = class weight of the example's
                    # label. tf.gather takes (params, indices), so the
                    # weights come first and the labels index into them.
                    weight_vec = tf.gather(
                        tf.constant(np.asarray(ratio, dtype=np.float32)),
                        train_labels)
                    cost = tf.reduce_mean(
                        tf.pow(
                            (vgg.fc8) - tf.cast(train_labels, tf.float32), 2) *
                        weight_vec)
                else:
                    cost = tf.nn.l2_loss((vgg.fc8) -
                                         tf.cast(train_labels, tf.float32))
            else:
                if config.balance_cost:
                    cost = softmax_cost(vgg.fc8,
                                        train_labels,
                                        ratio=ratio,
                                        flip_ratio=flip_ratio)
                else:
                    cost = softmax_cost(vgg.fc8, train_labels)
            tf.summary.scalar("cost", cost)

            # Weight decay on the selected layers (biases excluded)
            if config.wd_layers is not None:
                _, l2_wd_layers = fine_tune_prepare_layers(
                    tf.trainable_variables(), config.wd_layers)
                l2_wd_layers = [
                    x for x in l2_wd_layers if 'biases' not in x.name
                ]
                cost += (config.wd_penalty *
                         tf.add_n([tf.nn.l2_loss(x) for x in l2_wd_layers]))

            # Split the trainable variables into fine-tuned layers and the
            # rest so each group can get its own learning rate.
            other_opt_vars, ft_opt_vars = fine_tune_prepare_layers(
                tf.trainable_variables(), config.fine_tune_layers)
            if config.optimizer == 'adam':
                train_op = ft_non_optimized(cost, other_opt_vars, ft_opt_vars,
                                            tf.train.AdamOptimizer,
                                            config.hold_lr, config.new_lr)
            elif config.optimizer == 'sgd':
                train_op = ft_non_optimized(cost, other_opt_vars, ft_opt_vars,
                                            tf.train.GradientDescentOptimizer,
                                            config.hold_lr, config.new_lr)
            else:
                # Fail here rather than with a NameError in the train loop.
                raise NotImplementedError(config.optimizer)

            # Training accuracy
            if config.ordinal_classification == 'ordinal':
                # Predicted class = number of units whose argmax is 1.
                arg_guesses = tf.cast(
                    tf.reduce_sum(tf.squeeze(tf.argmax(res_output, axis=1)),
                                  reduction_indices=[1]), tf.int32)
                train_accuracy = tf.reduce_mean(
                    tf.cast(tf.equal(arg_guesses, train_labels), tf.float32))
            elif config.ordinal_classification == 'regression':
                train_accuracy = tf.reduce_mean(
                    tf.cast(
                        tf.equal(tf.cast(tf.round(vgg.prob), tf.int32),
                                 train_labels), tf.float32))
            else:
                train_accuracy = class_accuracy(
                    vgg.prob, train_labels)  # training accuracy
            tf.summary.scalar("training accuracy", train_accuracy)

            # Setup validation op
            if use_validation:
                scope.reuse_variables()
                # Validation graph is the same as training except no batchnorm
                val_vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                                      fine_tune_layers=config.fine_tune_layers)
                val_vgg.build(val_images, output_shape=vgg_output)
                # Calculate validation accuracy
                if config.ordinal_classification == 'ordinal':
                    val_res_output = tf.reshape(
                        val_vgg.fc8,
                        [config.validation_batch, 2, config.output_shape])
                    val_arg_guesses = tf.cast(
                        tf.reduce_sum(tf.squeeze(
                            tf.argmax(val_res_output, axis=1)),
                                      reduction_indices=[1]), tf.int32)
                    val_accuracy = tf.reduce_mean(
                        tf.cast(tf.equal(val_arg_guesses, val_labels),
                                tf.float32))
                elif config.ordinal_classification == 'regression':
                    val_accuracy = tf.reduce_mean(
                        tf.cast(
                            tf.equal(tf.cast(tf.round(val_vgg.prob), tf.int32),
                                     val_labels), tf.float32))
                else:
                    val_accuracy = class_accuracy(val_vgg.prob, val_labels)
                tf.summary.scalar("validation accuracy", val_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=config.keep_checkpoints)
    summary_op = tf.summary.merge_all()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Need to initialize both of these if supplying num_epochs to inputs
    sess.run(
        tf.group(tf.global_variables_initializer(),
                 tf.local_variables_initializer()))
    summary_dir = os.path.join(config.train_summaries,
                               config.which_dataset + '_' + dt_stamp)
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Start training loop
    # out_dir ends with '/', so this writes <out_dir>/meta_info.npy
    np.save(out_dir + 'meta_info', config)
    step, losses = 0, []
    try:
        while not coord.should_stop():
            start_time = time.time()
            _, loss_value, train_acc = sess.run(
                [train_op, cost, train_accuracy])
            losses.append(loss_value)
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % config.validation_steps == 0:
                if use_validation:
                    # Evaluate only; fetching train_op here would silently
                    # run an extra, unlogged training step.
                    val_acc = sess.run(val_accuracy)
                else:
                    val_acc = None  # No validation; store every checkpoint

                # Summaries
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Training status and validation accuracy
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Validation accuracy = %s | logdir = %s')
                print(format_str %
                      (datetime.now(), step,
                       loss_value, config.train_batch / duration,
                       float(duration), train_acc, val_acc, summary_dir))

                # A checkpoint is stored at every validation step.
                saver.save(sess,
                           os.path.join(config.train_checkpoint,
                                        'model_' + str(step) + '.ckpt'),
                           global_step=step)

            else:
                # Training status
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s')
                print(format_str %
                      (datetime.now(), step, loss_value, config.train_batch /
                       duration, float(duration), train_acc))
            # End iteration
            step += 1

    except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
    finally:
        # Always stop the queue threads and release the session, including
        # when an unexpected error ends training.
        coord.request_stop()
        try:
            np.save(os.path.join(config.tfrecord_dir, 'training_loss'), losses)
            coord.join(threads)
        finally:
            sess.close()
Ejemplo n.º 2
0
def train_model(train_dir=None,
                validation_dir=None,
                debug=True,
                resume_ckpt=None,
                resume_meta=None,
                margin=.4):
    """Train a frame-matching embedding model with a margin (triplet) loss.

    Each example is a stack of image frames. Every frame is embedded with a
    shared network; the loss pushes the distance between frames 0 and 1
    (``pos``) below the distance between frames 0 and 2 (``neg``) by at
    least ``margin``.

    Args:
        train_dir: Directory holding the training tfrecord and the meta
            file. ``None`` uses ``config.tfrecord_dir``.
        validation_dir: Directory holding the validation tfrecord. ``None``
            uses ``config.tfrecord_dir``; ``False`` disables validation.
        debug: When True, the per-frame image tensors are also fetched on
            every train/validation run.
        resume_ckpt: Optional checkpoint path restored before training.
        resume_meta: Optional path to a config previously stored with
            ``np.save``; it replaces the default config.
        margin: Margin of the triplet loss; must not be ``None``.

    Raises:
        AssertionError: If ``margin`` is ``None``, a required input file is
            missing, or the loss becomes NaN.
        NotImplementedError: If ``config.dist_fun`` is not supported.

    Side effects:
        Creates the checkpoint/summary/result directories, writes
        checkpoints and summaries, saves the config as ``meta_info`` in the
        run's result directory and the per-step losses as ``training_loss``
        in ``config.tfrecord_dir``.
    """
    config = GEDIconfig()
    if resume_meta is not None:
        # The stored config is a pickled object, which np.load only reads
        # with allow_pickle=True. SECURITY: unpickling executes arbitrary
        # code -- only pass meta files you created yourself.
        config = np.load(resume_meta, allow_pickle=True).item()
    assert margin is not None, 'Need a margin for the loss.'

    # Resolve the training inputs. An explicit train_dir is assumed to hold
    # a tfrecord named like the global one (config.tf_record_names['train'])
    # next to the meta file -- TODO confirm against the data layout.
    if train_dir is None:  # Use globals
        train_dir = config.tfrecord_dir
    train_data = os.path.join(train_dir, config.tf_record_names['train'])
    meta_data = np.load(
        os.path.join(train_dir,
                     '%s_%s' % (config.tvt_flags[0], config.max_file)))

    # Prepare image normalization values.
    # NOTE: the min/max arguments of `inputs` are commented out below, so
    # max_value is only reported here and not forwarded.
    if config.max_gedi is None:
        max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
        if max_value == 0:
            max_value = None
            print('Derived max value is 0')
        else:
            print('Normalizing with empirical max.')
        if 'min_array' in meta_data.keys():
            # min_value = np.min(meta_data['min_array']).astype(np.float32)
            print('Normalizing with empirical min.')
        else:
            # min_value = None
            print('Not normalizing with a min.')
    else:
        max_value = config.max_gedi
        # min_value = config.min_gedi
    ratio = meta_data['ratio']
    print('Ratio is: %s' % ratio)

    # Resolve the validation input; False means "train without validation".
    if validation_dir is False:
        validation_data = False
    elif validation_dir is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir,
                                       config.tf_record_names['val'])
    else:
        validation_data = os.path.join(validation_dir,
                                       config.tf_record_names['val'])
    use_validation = validation_data is not False

    # Make output directories if they do not exist
    dt_stamp = re.split(
        r'\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(config.train_checkpoint,
                                           dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    dir_list = [
        config.train_checkpoint, config.train_summaries, config.results,
        out_dir
    ]
    for d in dir_list:
        tf_fun.make_dir(d)
    # im_shape = get_image_size(config)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Prepare data on CPU
    assert os.path.exists(train_data)
    if use_validation:
        assert os.path.exists(validation_data)
    assert os.path.exists(config.vgg16_weight_path)
    with tf.device('/cpu:0'):
        # The third return value (requested via return_filename) is unused.
        train_images, train_labels, _ = inputs(
            train_data,
            config.train_batch,
            im_shape,
            config.model_image_size,
            # max_value=max_value,
            # min_value=min_value,
            train=config.data_augmentations,
            num_epochs=config.epochs,
            normalize=config.normalize,
            return_filename=True)
        # Split the frame axis (axis 1) into one image tensor per frame.
        train_image_list = []
        for idx in range(int(train_images.get_shape()[1])):
            frame = tf.gather(train_images, idx, axis=1)
            train_image_list.append(frame)
            tf.summary.image('train_image_frame_%s' % idx, frame)

        val_image_list = []
        if use_validation:
            val_images, val_labels, _ = inputs(
                validation_data,
                config.validation_batch,
                im_shape,
                config.model_image_size,
                # max_value=max_value,
                # min_value=min_value,
                num_epochs=config.epochs,
                normalize=config.normalize,
                return_filename=True)
            for idx in range(int(val_images.get_shape()[1])):
                frame = tf.gather(val_images, idx, axis=1)
                val_image_list.append(frame)
                tf.summary.image('validation_image_frame_%s' % idx, frame)

    # Prepare model on GPU
    config.l2_norm = False
    config.norm_axis = 0
    config.dist_fun = 'pearson'
    config.per_batch = False
    config.include_GEDI = True  # False
    config.output_shape = 32
    config.margin = margin
    with tf.device('/gpu:0'):
        with tf.variable_scope('match'):
            # Build matching model for frame 0
            model_0 = vgg16.model_struct(
                vgg16_npy_path=config.gedi_weight_path)  # ,
            frame_activity = []
            model_activity = model_0.build(train_image_list[0],
                                           output_shape=config.output_shape,
                                           include_GEDI=config.include_GEDI)
            if config.l2_norm:
                model_activity = tf_fun.l2_normalize(model_activity,
                                                     axis=config.norm_axis)
            frame_activity.append(model_activity)

        with tf.variable_scope('match', reuse=tf.AUTO_REUSE):
            # Build matching model for other frames (shared weights)
            for frame in train_image_list[1:]:
                model_activity = model_0.build(
                    frame,
                    output_shape=config.output_shape,
                    include_GEDI=config.include_GEDI)
                if config.l2_norm:
                    model_activity = tf_fun.l2_normalize(model_activity,
                                                         axis=config.norm_axis)
                frame_activity.append(model_activity)

        # pos: distance frame 0 <-> frame 1; neg: distance frame 0 <-> frame 2
        if config.dist_fun == 'l2':
            pos = tf_fun.l2_dist(frame_activity[0], frame_activity[1], axis=1)
            neg = tf_fun.l2_dist(frame_activity[0], frame_activity[2], axis=1)
        elif config.dist_fun == 'pearson':
            pos = tf_fun.pearson_dist(frame_activity[0],
                                      frame_activity[1],
                                      axis=1)
            neg = tf_fun.pearson_dist(frame_activity[0],
                                      frame_activity[2],
                                      axis=1)
        else:
            raise NotImplementedError(config.dist_fun)
        if config.per_batch:
            # Hinge applied once to the batch-averaged difference
            loss = tf.maximum(tf.reduce_mean(pos - neg) + margin, 0.)
        else:
            # Hinge applied per example, then averaged
            loss = tf.reduce_mean(tf.maximum(pos - neg + margin, 0.))
        tf.summary.scalar('Triplet_loss', loss)

        # Weight decay
        if config.wd_layers is not None:
            _, l2_wd_layers = tf_fun.fine_tune_prepare_layers(
                tf.trainable_variables(), config.wd_layers)
            l2_wd_layers = [x for x in l2_wd_layers if 'biases' not in x.name]
            if len(l2_wd_layers) > 0:
                loss += (config.wd_penalty *
                         tf.add_n([tf.nn.l2_loss(x) for x in l2_wd_layers]))

        # Optimize
        train_op = tf.train.AdamOptimizer(config.new_lr).minimize(loss)
        # An example counts as correct when pos < neg.
        train_accuracy = tf.reduce_mean(
            tf.cast(
                tf.equal(
                    tf.nn.relu(tf.sign(neg - pos)),  # 1 if pos < neg
                    tf.cast(tf.ones_like(train_labels), tf.float32)),
                tf.float32))
        tf.summary.scalar('training_accuracy', train_accuracy)

        # Setup validation op (entirely skipped when validation is disabled)
        if use_validation:
            # reuse=True: the validation graph must pick up the training
            # variables and never create new ones.
            with tf.variable_scope('match', reuse=True):
                # Build matching model for frame 0
                val_model_0 = vgg16.model_struct(
                    vgg16_npy_path=config.gedi_weight_path)
                val_frame_activity = []
                model_activity = val_model_0.build(
                    val_image_list[0],
                    output_shape=config.output_shape,
                    include_GEDI=config.include_GEDI)
                if config.l2_norm:
                    model_activity = tf_fun.l2_normalize(model_activity,
                                                         axis=config.norm_axis)
                val_frame_activity.append(model_activity)

                # Build matching model for other frames
                for frame in val_image_list[1:]:
                    model_activity = val_model_0.build(
                        frame,
                        output_shape=config.output_shape,
                        include_GEDI=config.include_GEDI)
                    if config.l2_norm:
                        model_activity = tf_fun.l2_normalize(
                            model_activity, axis=config.norm_axis)
                    val_frame_activity.append(model_activity)
            if config.dist_fun == 'l2':
                val_pos = tf_fun.l2_dist(val_frame_activity[0],
                                         val_frame_activity[1],
                                         axis=1)
                val_neg = tf_fun.l2_dist(val_frame_activity[0],
                                         val_frame_activity[2],
                                         axis=1)
            elif config.dist_fun == 'pearson':
                val_pos = tf_fun.pearson_dist(val_frame_activity[0],
                                              val_frame_activity[1],
                                              axis=1)
                val_neg = tf_fun.pearson_dist(val_frame_activity[0],
                                              val_frame_activity[2],
                                              axis=1)
            else:
                raise NotImplementedError(config.dist_fun)
            if config.per_batch:
                val_loss = tf.maximum(
                    tf.reduce_mean(val_pos - val_neg) + margin, 0.)
            else:
                val_loss = tf.reduce_mean(
                    tf.maximum(val_pos - val_neg + margin, 0.))
            tf.summary.scalar('Validation_triplet_loss', val_loss)

            # Calculate validation accuracy
            val_accuracy = tf.reduce_mean(
                tf.cast(
                    tf.equal(tf.nn.relu(tf.sign(val_neg - val_pos)),
                             tf.cast(tf.ones_like(val_labels), tf.float32)),
                    tf.float32))
            tf.summary.scalar('val_accuracy', val_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=config.keep_checkpoints)
    summary_op = tf.summary.merge_all()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Need to initialize both of these if supplying num_epochs to inputs
    sess.run(
        tf.group(tf.global_variables_initializer(),
                 tf.local_variables_initializer()))
    summary_dir = os.path.join(config.train_summaries,
                               config.which_dataset + '_' + dt_stamp)
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Train operations. Validation accuracy is deliberately not fetched
    # here: doing so would dequeue a validation batch on every train step.
    train_dict = {
        'train_op': train_op,
        'loss': loss,
        'pos': pos,
        'neg': neg,
        'train_accuracy': train_accuracy,
    }
    val_dict = {'val_accuracy': val_accuracy} if use_validation else {}
    if debug:
        for idx, frame in enumerate(train_image_list):
            train_dict['train_im_%s' % idx] = frame
        for idx, frame in enumerate(val_image_list):
            val_dict['val_im_%s' % idx] = frame

    # Resume training if requested
    if resume_ckpt is not None:
        print('*' * 50)
        print('Resuming training from: %s' % resume_ckpt)
        print('*' * 50)
        saver.restore(sess, resume_ckpt)

    # Start training loop
    # out_dir ends with '/', so this writes <out_dir>/meta_info.npy
    np.save(out_dir + 'meta_info', config)
    step, losses = 0, []
    top_vals = np.asarray([0])  # best validation accuracies seen so far
    try:
        while not coord.should_stop():
            start_time = time.time()
            # Session.run accepts a dict of fetches and returns a dict with
            # the same keys.
            it_train_dict = sess.run(train_dict)
            losses.append(it_train_dict['loss'])
            duration = time.time() - start_time
            assert not np.isnan(it_train_dict['loss']),\
                'Model loss = NaN'

            if step % config.validation_steps == 0:
                if use_validation:
                    val_accs = [
                        sess.run(val_dict)['val_accuracy']
                        for _ in range(config.validation_iterations)
                    ]
                    val_acc = np.nanmean(val_accs)
                    # Keep the highest scores (sorted descending) so a new
                    # checkpoint has to beat one of the current best.
                    top_vals = np.sort(
                        top_vals)[::-1][:config.num_keep_checkpoints]
                    save_checkpoint = bool((val_acc > top_vals).any())
                else:
                    val_acc = None
                    save_checkpoint = True  # Store every checkpoint

                # Summaries
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Training status and validation accuracy
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Validation accuracy = %s | '
                              'logdir = %s')
                print(format_str %
                      (datetime.now(), step,
                       it_train_dict['loss'], config.train_batch / duration,
                       float(duration), it_train_dict['train_accuracy'],
                       val_acc, summary_dir))

                # Save the model checkpoint if it's among the best yet
                if save_checkpoint:
                    saver.save(sess,
                               os.path.join(config.train_checkpoint,
                                            'model_' + str(step) + '.ckpt'),
                               global_step=step)
                    # Store the new validation accuracy
                    if val_acc is not None:
                        top_vals = np.append(top_vals, val_acc)

            else:
                # Training status
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | ')
                print(format_str %
                      (datetime.now(), step,
                       it_train_dict['loss'], config.train_batch / duration,
                       float(duration), it_train_dict['train_accuracy']))
            # End iteration
            step += 1

    except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
    finally:
        # Always stop the queue threads and release the session, including
        # when an unexpected error ends training.
        coord.request_stop()
        try:
            np.save(os.path.join(config.tfrecord_dir, 'training_loss'), losses)
            coord.join(threads)
        finally:
            sess.close()
Ejemplo n.º 3
0
def train_vgg16(train_dir=None, validation_dir=None):
    """Train a VGG16 on tfrecord batches and checkpoint it periodically.

    With ``config.ordinal_classification`` set to ``None`` the labels are
    reduced to their sign and the network is trained with ``softmax_cost``;
    with ``'regression'`` the raw labels are fit with an L2 loss. Every
    ``config.validation_steps`` steps the summaries are written, the
    validation metric is evaluated (when enabled) and a checkpoint is saved.
    The per-step losses are saved to ``config.tfrecord_dir`` on exit.

    Args:
        train_dir: Directory holding the training tfrecord and its meta-data
            file. ``None`` uses ``config.tfrecord_dir``.
        validation_dir: Directory holding the validation tfrecord. ``None``
            uses ``config.tfrecord_dir``; ``False`` disables validation.

    Raises:
        RuntimeError: If ``config.ordinal_classification`` is neither
            ``None`` nor ``'regression'``.
        AssertionError: If a required input path does not exist or the
            training loss becomes NaN.
    """
    config = GEDIconfig()

    # NOTE(review): when train_dir is supplied, the tfrecord is assumed to
    # live next to the meta-data file under the same name that is used for
    # config.tfrecord_dir -- confirm against the data layout.
    train_root = config.tfrecord_dir if train_dir is None else train_dir
    train_data = os.path.join(train_root, config.tf_record_names['train'])
    meta_data = np.load(
        os.path.join(
            train_root,
            '%s_%s' % (config.tvt_flags[0], config.max_file)))

    # Prepare image normalization values
    if config.max_gedi is None:
        max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
        if max_value == 0:
            max_value = None
            print('Derived max value is 0')
        else:
            print('Normalizing with empirical max.')
        if 'min_array' in meta_data.keys():
            min_value = np.min(meta_data['min_array']).astype(np.float32)
            print('Normalizing with empirical min.')
        else:
            min_value = None
            print('Not normalizing with a min.')
    else:
        max_value = config.max_gedi
        min_value = config.min_gedi

    ratio = meta_data['ratio']
    if config.encode_time_of_death:
        tod = pd.read_csv(config.encode_time_of_death)
        tod_data = tod['dead_tp'].values
        # Drop rows with a missing time of death or an 'empty' well
        mask = np.isnan(tod_data).astype(int) + (
            tod['plate_well_neuron'] == 'empty').values.astype(int)
        tod_data = tod_data[mask == 0]
        tod_data = tod_data[tod_data > config.mask_timepoint_value]
        config.output_shape = len(np.unique(tod_data))
        ratio = class_weight.compute_class_weight(
            'balanced',
            classes=np.sort(np.unique(tod_data)),
            y=tod_data)
    print('Ratio is: %s' % ratio)

    if validation_dir is False:
        validation_data = False  # Do not use validation data during training
    else:
        # NOTE(review): same layout assumption as for train_dir above.
        val_root = (
            config.tfrecord_dir if validation_dir is None else validation_dir)
        validation_data = os.path.join(
            val_root, config.tf_record_names['val'])
    use_validation = validation_data is not False

    # Make output directories if they do not exist
    dt_stamp = re.split(r'\.', str(datetime.now()))[0].replace(
        ' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(
        config.train_checkpoint, dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    for d in (
            config.train_checkpoint, config.train_summaries,
            config.results, out_dir):
        make_dir(d)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Prepare data on CPU
    assert os.path.exists(train_data), train_data
    if use_validation:
        assert os.path.exists(validation_data), validation_data
    assert os.path.exists(config.vgg16_weight_path), config.vgg16_weight_path
    with tf.device('/cpu:0'):
        train_images, train_labels, train_gedi_images = inputs(
            train_data,
            config.train_batch,
            im_shape,
            config.model_image_size[:2],
            max_value=max_value,
            min_value=min_value,
            train=config.data_augmentations,
            num_epochs=config.epochs,
            normalize=config.normalize,
            return_gedi=config.include_GEDI_in_tfrecords,
            return_extra_gfp=config.extra_image,
            return_GEDI_derivative=True)
        if use_validation:
            val_images, val_labels, val_gedi_images = inputs(
                validation_data,
                config.validation_batch,
                im_shape,
                config.model_image_size[:2],
                max_value=max_value,
                min_value=min_value,
                num_epochs=config.epochs,
                normalize=config.normalize,
                return_gedi=config.include_GEDI_in_tfrecords,
                return_extra_gfp=config.extra_image,
                return_GEDI_derivative=True)
        if config.include_GEDI_in_tfrecords:
            extra_im_name = 'GEDI at current timepoint'
        else:
            extra_im_name = 'next gfp timepoint'
        tf.summary.image('train images', train_images)
        tf.summary.image('train %s' % extra_im_name, train_gedi_images)
        if use_validation:
            tf.summary.image('validation images', val_images)
            tf.summary.image(
                'validation %s' % extra_im_name, val_gedi_images)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn') as scope:
            if config.ordinal_classification is None:
                vgg_output = 2  # Sign of derivative (inf norm)
                train_labels = tf.cast(tf.sign(train_labels), tf.int32)
                if use_validation:
                    val_labels = tf.cast(tf.sign(val_labels), tf.int32)
            elif config.ordinal_classification == 'regression':
                vgg_output = 1
            else:
                raise RuntimeError(
                    'config.ordinal_classification must be sign or regression.'
                    )
            vgg = vgg16.model_struct()
            train_mode = tf.get_variable(name='training', initializer=True)

            # Mask NAN images from loss. The mask must be derived before
            # the NaNs are zeroed out below.
            image_nan = tf.reduce_sum(
                tf.cast(tf.is_nan(train_images), tf.float32),
                reduction_indices=[1, 2, 3])
            image_mask = tf.cast(tf.equal(image_nan, 0.), tf.float32)
            train_images = tf.where(
                tf.is_nan(train_images),
                tf.zeros_like(train_images),
                train_images)

            # Repeat the channel axis three times for the VGG input
            train_images = tf.concat(
                [train_images, train_images, train_images], axis=3)
            if use_validation:
                val_images = tf.concat(
                    [val_images, val_images, val_images], axis=3)
            vgg.build(
                train_images, output_shape=vgg_output,
                train_mode=train_mode, batchnorm=config.batchnorm_layers)

            # Prepare the cost function
            if config.ordinal_classification is None:
                cost = softmax_cost(
                    vgg.fc8,
                    train_labels,
                    mask=image_mask)
            elif config.ordinal_classification == 'regression':
                cost = tf.nn.l2_loss(tf.squeeze(vgg.fc8) - train_labels)
            tf.summary.scalar("cce cost", cost)

            # Weight decay
            if config.wd_layers is not None:
                _, l2_wd_layers = fine_tune_prepare_layers(
                    tf.trainable_variables(), config.wd_layers)
                l2_wd_layers = [
                    x for x in l2_wd_layers if 'biases' not in x.name]
                if len(l2_wd_layers) > 0:
                    cost += (config.wd_penalty * tf.add_n(
                        [tf.nn.l2_loss(x) for x in l2_wd_layers]))

            # Optimize
            train_op = tf.train.AdamOptimizer(config.new_lr).minimize(cost)
            if config.ordinal_classification is None:
                train_accuracy = class_accuracy(
                    vgg.prob, train_labels)  # training accuracy
            elif config.ordinal_classification == 'regression':
                train_accuracy = tf.nn.l2_loss(
                    tf.squeeze(vgg.fc8) - train_labels)
            tf.summary.scalar("training accuracy", train_accuracy)

            # Setup validation op
            if use_validation:
                scope.reuse_variables()
                # Validation graph is the same as training except no batchnorm
                val_vgg = vgg16.model_struct(
                    fine_tune_layers=config.fine_tune_layers)
                val_vgg.build(val_images, output_shape=vgg_output)
                # Calculate validation accuracy
                if config.ordinal_classification is None:
                    val_accuracy = class_accuracy(val_vgg.prob, val_labels)
                elif config.ordinal_classification == 'regression':
                    val_accuracy = tf.nn.l2_loss(
                        tf.squeeze(val_vgg.fc8) - val_labels)
                tf.summary.scalar("validation accuracy", val_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(
        tf.global_variables(), max_to_keep=config.keep_checkpoints)
    summary_op = tf.summary.merge_all()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Need to initialize both of these if supplying num_epochs to inputs
    sess.run(tf.group(tf.global_variables_initializer(),
             tf.local_variables_initializer()))
    summary_dir = os.path.join(
        config.train_summaries, config.which_dataset + '_' + dt_stamp)
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Restore model if requested
    if config.restore_path is not None:
        print('-' * 60)
        print('Restoring from a previous model: %s' % config.restore_path)
        print('-' * 60)
        saver.restore(sess, config.restore_path)

    # Start training loop
    np.save(os.path.join(out_dir, 'meta_info'), config)
    step, losses = 0, []
    val_acc = None  # Reported as-is when validation is disabled

    try:
        while not coord.should_stop():
            start_time = time.time()
            _, loss_value, train_acc = sess.run(
                [train_op, cost, train_accuracy])
            losses.append(loss_value)
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
            if step % config.validation_steps == 0:
                if use_validation:
                    # Evaluate only; the optimizer must not step here.
                    val_acc = sess.run(val_accuracy)

                # Summaries
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Training status and validation accuracy
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; '
                    '%.3f sec/batch) | Training accuracy = %s | '
                    'Training %s = %s | Training class loss = %s | '
                    'Validation accuracy = %s | Validation %s = %s | '
                    'logdir = %s')
                print(format_str % (
                    datetime.now(), step, loss_value,
                    config.train_batch / duration, float(duration),
                    train_acc, extra_im_name, 0.,
                    0., val_acc, extra_im_name,
                    0., summary_dir))

                # Save a checkpoint at every validation step
                saver.save(
                    sess, os.path.join(
                        config.train_checkpoint,
                        'model_' + str(step) + '.ckpt'), global_step=step)

            else:
                # Training status
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Training %s = %s | Training class loss = %s')
                print(format_str % (datetime.now(), step, loss_value,
                                    config.train_batch / duration,
                                    float(duration), train_acc,
                                    extra_im_name, 0.,
                                    0.))
            # End iteration
            step += 1

    except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
    finally:
        coord.request_stop()
        np.save(os.path.join(config.tfrecord_dir, 'training_loss'), losses)
    coord.join(threads)
    sess.close()
Ejemplo n.º 4
0
def test_placeholder(
        image_path,
        model_file,
        model_meta,
        out_dir,
        n_images=3,
        first_n_images=1,
        debug=True,
        margin=.1,
        autopsy_csv=None,
        C=1,
        k_folds=10,
        embedding_type='tsne',
        autopsy_model='match'):
    """Score images with a trained model and analyse the resulting activity.

    Images are fed through placeholders in batches. The collected scores
    and file names are saved to ``validation_accuracies`` in a timestamped
    sub-directory of ``out_dir``. With ``first_n_images == 1`` the scores are
    additionally labelled with pathologies looked up in the autopsy table,
    embedded in 2-D, classified with a cross-validated linear SVM and written
    to csv/figures. Otherwise the sign of each score is written to
    ``tracking_model_scores.csv``.

    Args:
        image_path: Image file, or directory searched for files ending in
            ``config.raw_im_ext``.
        model_file: Checkpoint restored when ``autopsy_model == 'match'``.
        model_meta: Path to the saved model config; the default
            ``GEDIconfig`` is used when it cannot be loaded.
        out_dir: Parent directory for this run's outputs.
        n_images: Forwarded to ``image_batcher``.
        first_n_images: Number of image placeholders (frames) to build.
        debug: Unused.
        margin: Only checked for not being ``None``.
        autopsy_csv: CSV with ``line``, ``wells`` and ``type`` columns;
            required when ``first_n_images == 1``.
        C: Regularization constant of the linear SVM.
        k_folds: Number of cross-validation folds.
        embedding_type: ``'tsne'``, ``'pca'`` or ``'spectral'``.
        autopsy_model: ``'match'`` or ``'gedi'``.

    Raises:
        RuntimeError: If the config lacks ``include_GEDI``, no image files
            are found, or autopsy data is needed but could not be loaded.
        NotImplementedError: For an unsupported ``autopsy_model``,
            ``config.dist_fun`` or ``embedding_type``.
        ValueError: If the match model is asked for exactly two frames.
    """
    config = GEDIconfig()
    assert margin is not None, 'Need a margin for the loss.'
    assert image_path is not None, 'Provide a path to an image directory.'
    assert model_file is not None, 'Provide a path to the model file.'

    try:
        # Load the model's config
        config = np.load(model_meta).item()
    except Exception:  # Best effort: any load failure falls back to default
        print('Could not load model config, falling back to default config.')
    config.model_image_size[-1] = 1
    try:
        # Load autopsy information
        autopsy_data = pd.read_csv(autopsy_csv)
    except (IOError, ValueError):  # Missing file, or autopsy_csv is None
        autopsy_data = None
        print('Unable to load autopsy file.')
    if not hasattr(config, 'include_GEDI'):
        raise RuntimeError('You need to pass the correct meta file.')
    if os.path.isdir(image_path):
        combined_files = np.asarray(
            glob(os.path.join(image_path, '*%s' % config.raw_im_ext)))
    else:
        combined_files = [image_path]
    if len(combined_files) == 0:
        raise RuntimeError('Could not find any files. Check your image path.')

    # Make output directories if they do not exist
    dt_stamp = re.split(r'\.', str(datetime.now()))[0].replace(
        ' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(
        config.train_checkpoint, dt_dataset)  # timestamp this run
    out_dir = os.path.join(out_dir, dt_dataset)
    tf_fun.make_dir(out_dir)

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        images = [
            tf.placeholder(
                tf.float32,
                shape=[None] + config.model_image_size,
                name='images_%s' % frame)
            for frame in range(first_n_images)]

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        if autopsy_model == 'match':
            from models import matching_vgg16 as model_type
            with tf.variable_scope('match'):
                # Build matching model for frame 0
                model_0 = model_type.model_struct(
                    vgg16_npy_path=config.gedi_weight_path)
                frame_activity = []
                model_activity = model_0.build(
                    images[0],
                    output_shape=config.output_shape,
                    include_GEDI=config.include_GEDI)
                if config.l2_norm:
                    # NOTE(review): normalised like the other frames below
                    # instead of being wrapped in a list -- confirm this
                    # matches how the checkpoint was trained.
                    model_activity = tf_fun.l2_normalize(model_activity)
                frame_activity += [model_activity]
            if first_n_images > 1:
                if first_n_images < 3:
                    # Frames 1 and 2 are both compared against frame 0
                    raise ValueError(
                        'first_n_images must be 1 or >= 3 for the match '
                        'model.')
                with tf.variable_scope('match', reuse=tf.AUTO_REUSE):
                    # Build matching model for other frames
                    for frame in range(1, len(images)):
                        model_activity = model_0.build(
                            images[frame],
                            output_shape=config.output_shape,
                            include_GEDI=config.include_GEDI)
                        if config.l2_norm:
                            model_activity = tf_fun.l2_normalize(
                                model_activity)
                        frame_activity += [model_activity]
                if config.dist_fun == 'l2':
                    pos = tf_fun.l2_dist(
                        frame_activity[0],
                        frame_activity[1], axis=1)
                    neg = tf_fun.l2_dist(
                        frame_activity[0],
                        frame_activity[2], axis=1)
                elif config.dist_fun == 'pearson':
                    pos = tf_fun.pearson_dist(
                        frame_activity[0],
                        frame_activity[1],
                        axis=1)
                    neg = tf_fun.pearson_dist(
                        frame_activity[0],
                        frame_activity[2],
                        axis=1)
                else:
                    raise NotImplementedError(config.dist_fun)
                model_activity = pos - neg  # Store the difference in distances
        elif autopsy_model == 'GEDI' or autopsy_model == 'gedi':
            from models import baseline_vgg16 as model_type
            model = model_type.model_struct(
                vgg16_npy_path=config.gedi_weight_path)
            model.build(
                images[0],
                output_shape=config.output_shape)
            model_activity = model.fc7
        else:
            raise NotImplementedError(autopsy_model)

    if config.validation_batch > len(combined_files):
        print(
            'Trimming validation_batch size to %s '
            '(same as # of files).' % len(combined_files))
        config.validation_batch = len(combined_files)

    # Set up saver
    saver = tf.train.Saver(tf.global_variables())

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    try:
        # Need to initialize both of these if supplying num_epochs to inputs
        sess.run(
            tf.group(
                tf.global_variables_initializer(),
                tf.local_variables_initializer()))

        # Only the match model is restored from the checkpoint
        if autopsy_model == 'match':
            saver.restore(sess, model_file)
        start_time = time.time()
        # Files beyond the last full batch are not processed
        num_batches = np.floor(
            len(combined_files) / float(
                config.validation_batch)).astype(int)
        score_array, file_array = [], []
        for image_batch, file_batch in tqdm(
                image_batcher(
                    start=0,
                    num_batches=num_batches,
                    images=combined_files,
                    config=config,
                    first_n_images=first_n_images,
                    n_images=n_images),
                total=num_batches):
            # NOTE(review): each placeholder is fed on its own; with several
            # frames model_activity presumably needs all of them fed at
            # once -- confirm against image_batcher's output.
            for im_head in images:
                activity = sess.run(
                    model_activity,
                    feed_dict={im_head: image_batch})
                score_array += [activity]
            file_array += [file_batch]
        print('Image processing %d took %.1f seconds' % (
            first_n_images - 1, time.time() - start_time))
    finally:
        sess.close()
    score_array = np.concatenate(score_array, axis=0)
    score_array = score_array.reshape(-1, score_array.shape[-1])
    file_array = np.concatenate(file_array, axis=0)

    # Save everything
    np.savez(
        os.path.join(out_dir, 'validation_accuracies'),
        score_array=score_array,
        file_array=file_array)

    if first_n_images == 1:
        if autopsy_data is None:
            raise RuntimeError(
                'Autopsy data is required to label pathologies. '
                'Check autopsy_csv.')

        # Derive pathologies from file names
        pathologies = []
        for f in combined_files:
            sf = f.split(os.path.sep)[-1].split('_')
            line = sf[1]
            well = sf[4]
            disease = autopsy_data[
                np.logical_and(
                    autopsy_data['line'] == line,
                    autopsy_data['wells'] == well)]['type']
            try:
                disease = disease.values[0]
            except IndexError:  # No autopsy row for this line/well
                disease = 'Not_found'
            pathologies += [disease]
        pathologies = np.asarray(pathologies)[:len(score_array)]

        mu = score_array.mean(0)
        sd = score_array.std(0)
        z_score_array = (score_array - mu) / (sd + 1e-4)
        if embedding_type == 'TSNE' or embedding_type == 'tsne':
            emb = manifold.TSNE(n_components=2, init='pca', random_state=0)
        elif embedding_type == 'PCA' or embedding_type == 'pca':
            emb = PCA(n_components=2, svd_solver='randomized', random_state=0)
        elif embedding_type == 'spectral':
            emb = manifold.SpectralEmbedding(n_components=2, random_state=0)
        else:
            raise NotImplementedError(embedding_type)

        y = emb.fit_transform(score_array)

        # Do a classification analysis
        labels = np.unique(pathologies.reshape(-1, 1), return_inverse=True)[1]

        # Run SVM
        svm = LinearSVC(C=C, dual=False, class_weight='balanced')
        clf = make_pipeline(preprocessing.StandardScaler(), svm)
        predictions = cross_val_predict(clf, score_array, labels, cv=k_folds)
        cv_performance = metrics.accuracy_score(predictions, labels)
        clf.fit(score_array, labels)
        print('%s-fold SVM performance: accuracy = %s%%' % (
            k_folds,
            np.mean(cv_performance * 100)))
        np.savez(
            os.path.join(out_dir, 'svm_data'),
            yhat=score_array,
            y=labels,
            cv_performance=cv_performance,
            C=C)

        # Output csv
        df = pd.DataFrame(
            np.hstack((
                y,
                pathologies.reshape(-1, 1),
                file_array.reshape(-1, 1))),
            columns=['dim1', 'dim2', 'pathology', 'filename'])
        out_name = os.path.join(out_dir, 'raw_embedding.csv')
        df.to_csv(out_name)
        print('Saved csv to: %s' % out_name)

        create_figs(
            emb=emb,
            out_dir=out_dir,
            out_name=out_name,
            embedding_type=embedding_type,
            embedding_name='raw_embedding')

        # Now work on zscored data
        y = emb.fit_transform(z_score_array)

        # Output csv
        df = pd.DataFrame(
            np.hstack((
                y,
                pathologies.reshape(-1, 1),
                file_array.reshape(-1, 1))),
            columns=['dim1', 'dim2', 'pathology', 'filename'])
        out_name = os.path.join(out_dir, 'embedding.csv')
        df.to_csv(out_name)
        print('Saved csv to: %s' % out_name)

        # Create plot
        create_figs(
            emb=emb,
            out_dir=out_dir,
            out_name=out_name,
            embedding_type=embedding_type,
            embedding_name='normalized_embedding')

    else:
        # Do a classification (sign of the score)
        decisions = np.sign(score_array)
        # One row per score: flatten both so they line up as two columns
        df = pd.DataFrame(
            np.column_stack((decisions.ravel(), score_array.ravel())),
            columns=['Decisions', 'Scores'])
        df.to_csv(
            os.path.join(
                out_dir, 'tracking_model_scores.csv'))