コード例 #1
0
ファイル: train_vgg16.py プロジェクト: jdlamstein/GEDI3
def train_vgg16(train_dir=None, validation_dir=None):
    """Fine-tune VGG16 on the GEDI tfrecords and checkpoint periodically.

    Builds the TF1 graph (queue-based inputs on the CPU, model on the GPU),
    then runs the training loop until the input queues raise
    ``OutOfRangeError``. Every ``config.validation_steps`` steps it writes
    summaries, prints the training/validation status and saves a checkpoint.
    The per-step losses are saved to ``config.tfrecord_dir`` on exit.

    Args:
        train_dir: Directory holding the normalization meta file and the
            training tfrecord. ``None`` uses ``config.tfrecord_dir``.
        validation_dir: ``None`` uses the validation tfrecord found in
            ``config.tfrecord_dir``; ``False`` disables validation; any other
            value is treated as the directory holding the validation
            tfrecord.

    Raises:
        ValueError: If ``config.ordinal_classification`` or
            ``config.optimizer`` holds an unsupported value.
        AssertionError: If a required input file is missing or the loss
            becomes NaN.
    """
    config = GEDIconfig()
    meta_name = '%s_%s' % (config.tvt_flags[0], config.max_file)
    if train_dir is None:  # Use globals
        train_data = os.path.join(config.tfrecord_dir,
                                  config.tf_record_names['train'])
        meta_data = np.load(os.path.join(config.tfrecord_dir, meta_name))
    else:
        # NOTE(review): assumes the training tfrecord lives next to the meta
        # file inside train_dir -- confirm against the tfrecord writer.
        train_data = os.path.join(train_dir, config.tf_record_names['train'])
        meta_data = np.load(os.path.join(train_dir, meta_name))

    # Prepare image normalization values
    if config.max_gedi is None:
        max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
        if max_value == 0:
            max_value = None
            print('Derived max value is 0')
        else:
            print('Normalizing with empirical max.')
        if 'min_array' in meta_data.keys():
            min_value = np.min(meta_data['min_array']).astype(np.float32)
            print('Normalizing with empirical min.')
        else:
            min_value = None
            print('Not normalizing with a min.')
    else:
        max_value = config.max_gedi
        min_value = config.min_gedi
        print('max value', max_value)
    ratio = meta_data['ratio']
    if config.encode_time_of_death:
        tod = pd.read_csv(config.encode_time_of_death)
        # `.values` replaces `.as_matrix()`, which was removed from pandas.
        tod_data = tod['dead_tp'].values
        mask = np.isnan(tod_data).astype(int) + (
            tod['plate_well_neuron'] == 'empty').values.astype(int)
        tod_data = tod_data[mask == 0]
        tod_data = tod_data[tod_data <
                            10]  # throw away values have a high number
        config.output_shape = len(np.unique(tod_data))
        # Keyword arguments work on both old and new scikit-learn releases.
        ratio = class_weight.compute_class_weight(
            class_weight='balanced',
            classes=np.sort(np.unique(tod_data)),
            y=tod_data)
        flip_ratio = False
    else:
        flip_ratio = True
    print('Ratio is: %s' % ratio)

    if validation_dir is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir,
                                       config.tf_record_names['val'])
    elif validation_dir is False:
        validation_data = False  # Do not use validation data during training
    else:
        # NOTE(review): assumes the validation tfrecord keeps its configured
        # file name inside validation_dir -- confirm.
        validation_data = os.path.join(validation_dir,
                                       config.tf_record_names['val'])
    use_validation = validation_data is not False

    # Make output directories if they do not exist
    dt_stamp = re.split(
        r'\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(config.train_checkpoint,
                                           dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    dir_list = [
        config.train_checkpoint, config.train_summaries, config.results,
        out_dir
    ]
    for d in dir_list:
        make_dir(d)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Prepare data on CPU
    assert os.path.exists(train_data)
    if use_validation:
        assert os.path.exists(validation_data)
    assert os.path.exists(config.vgg16_weight_path)
    with tf.device('/cpu:0'):
        train_images, train_labels = inputs(train_data,
                                            config.train_batch,
                                            im_shape,
                                            config.model_image_size[:2],
                                            max_value=max_value,
                                            min_value=min_value,
                                            train=config.data_augmentations,
                                            num_epochs=config.epochs,
                                            normalize=config.normalize)
        tf.summary.image('train images', train_images)
        if use_validation:
            val_images, val_labels = inputs(validation_data,
                                            config.validation_batch,
                                            im_shape,
                                            config.model_image_size[:2],
                                            max_value=max_value,
                                            min_value=min_value,
                                            num_epochs=config.epochs,
                                            normalize=config.normalize)
            tf.summary.image('validation images', val_images)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn') as scope:
            if config.ordinal_classification == 'ordinal':
                # Two logits (binary decision) per ordinal threshold
                vgg_output = config.output_shape * 2
            elif config.ordinal_classification == 'regression':
                vgg_output = 1
            elif config.ordinal_classification is None:
                vgg_output = config.output_shape
            else:
                raise ValueError(
                    'Unsupported config.ordinal_classification: %r' %
                    (config.ordinal_classification,))
            vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                              fine_tune_layers=config.fine_tune_layers)
            train_mode = tf.get_variable(name='training', initializer=True)
            vgg.build(train_images,
                      output_shape=vgg_output,
                      train_mode=train_mode,
                      batchnorm=config.batchnorm_layers)

            # Prepare the cost function
            if config.ordinal_classification == 'ordinal':
                # Encode y w/ k-hot and yhat w/ sigmoid ce. units capture dist.
                enc = tf.concat([
                    tf.reshape(tf.range(0, config.output_shape), [1, -1])
                    for _ in range(config.train_batch)
                ],
                                axis=0)
                enc_train_labels = tf.cast(
                    tf.greater_equal(enc, tf.expand_dims(train_labels,
                                                         axis=1)), tf.float32)
                split_labs = tf.split(enc_train_labels,
                                      config.output_shape,
                                      axis=1)
                res_output = tf.reshape(
                    vgg.fc8, [config.train_batch, 2, config.output_shape])
                split_logs = tf.split(res_output, config.output_shape, axis=2)
                if config.balance_cost:
                    cost = tf.add_n([
                        tf.nn.softmax_cross_entropy_with_logits(
                            labels=tf.one_hot(tf.cast(tf.squeeze(s), tf.int32),
                                              2),
                            logits=tf.squeeze(l)) * r
                        for s, l, r in zip(split_labs, split_logs, ratio)
                    ])
                else:
                    cost = tf.add_n([
                        tf.nn.softmax_cross_entropy_with_logits(
                            labels=tf.one_hot(tf.cast(tf.squeeze(s), tf.int32),
                                              2),
                            logits=tf.squeeze(l))
                        for s, l in zip(split_labs, split_logs)
                    ])
                cost = tf.reduce_mean(cost)
            elif config.ordinal_classification == 'regression':
                # NOTE(review): if vgg.fc8 is [batch, 1] and train_labels is
                # [batch], the subtraction below broadcasts to
                # [batch, batch] -- confirm the shapes.
                if config.balance_cost:
                    # tf.gather(params, indices): look up each example's
                    # class weight, i.e. ratio[label].
                    weight_vec = tf.cast(
                        tf.gather(ratio, train_labels), tf.float32)
                    cost = tf.reduce_mean(
                        tf.pow(
                            (vgg.fc8) - tf.cast(train_labels, tf.float32), 2) *
                        weight_vec)
                else:
                    cost = tf.nn.l2_loss((vgg.fc8) -
                                         tf.cast(train_labels, tf.float32))
            else:
                if config.balance_cost:
                    cost = softmax_cost(vgg.fc8,
                                        train_labels,
                                        ratio=ratio,
                                        flip_ratio=flip_ratio)
                else:
                    cost = softmax_cost(vgg.fc8, train_labels)
            tf.summary.scalar("cost", cost)

            # Weight decay on the selected layers (biases excluded)
            if config.wd_layers is not None:
                _, l2_wd_layers = fine_tune_prepare_layers(
                    tf.trainable_variables(), config.wd_layers)
                l2_wd_layers = [
                    x for x in l2_wd_layers if 'biases' not in x.name
                ]
                # tf.add_n rejects an empty list, so skip when nothing matched
                if l2_wd_layers:
                    cost += (config.wd_penalty *
                             tf.add_n(
                                 [tf.nn.l2_loss(x) for x in l2_wd_layers]))

            # Split variables into held and fine-tuned groups, each with its
            # own learning rate
            other_opt_vars, ft_opt_vars = fine_tune_prepare_layers(
                tf.trainable_variables(), config.fine_tune_layers)
            if config.optimizer == 'adam':
                optimizer_class = tf.train.AdamOptimizer
            elif config.optimizer == 'sgd':
                optimizer_class = tf.train.GradientDescentOptimizer
            else:
                raise ValueError(
                    'Unsupported config.optimizer: %r' % (config.optimizer,))
            train_op = ft_non_optimized(cost, other_opt_vars, ft_opt_vars,
                                        optimizer_class,
                                        config.hold_lr, config.new_lr)

            if config.ordinal_classification == 'ordinal':
                arg_guesses = tf.cast(
                    tf.reduce_sum(tf.squeeze(tf.argmax(res_output, axis=1)),
                                  reduction_indices=[1]), tf.int32)
                train_accuracy = tf.reduce_mean(
                    tf.cast(tf.equal(arg_guesses, train_labels), tf.float32))
            elif config.ordinal_classification == 'regression':
                train_accuracy = tf.reduce_mean(
                    tf.cast(
                        tf.equal(tf.cast(tf.round(vgg.prob), tf.int32),
                                 train_labels), tf.float32))
            else:
                train_accuracy = class_accuracy(
                    vgg.prob, train_labels)  # training accuracy
            tf.summary.scalar("training accuracy", train_accuracy)

            # Setup validation op
            if use_validation:
                scope.reuse_variables()
                # Validation graph is the same as training except no batchnorm
                val_vgg = vgg16.Vgg16(vgg16_npy_path=config.vgg16_weight_path,
                                      fine_tune_layers=config.fine_tune_layers)
                val_vgg.build(val_images, output_shape=vgg_output)
                # Calculate validation accuracy
                if config.ordinal_classification == 'ordinal':
                    val_res_output = tf.reshape(
                        val_vgg.fc8,
                        [config.validation_batch, 2, config.output_shape])
                    val_arg_guesses = tf.cast(
                        tf.reduce_sum(tf.squeeze(
                            tf.argmax(val_res_output, axis=1)),
                                      reduction_indices=[1]), tf.int32)
                    val_accuracy = tf.reduce_mean(
                        tf.cast(tf.equal(val_arg_guesses, val_labels),
                                tf.float32))
                elif config.ordinal_classification == 'regression':
                    val_accuracy = tf.reduce_mean(
                        tf.cast(
                            tf.equal(tf.cast(tf.round(val_vgg.prob), tf.int32),
                                     val_labels), tf.float32))
                else:
                    val_accuracy = class_accuracy(val_vgg.prob, val_labels)
                tf.summary.scalar("validation accuracy", val_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=config.keep_checkpoints)
    summary_op = tf.summary.merge_all()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Need to initialize both of these if supplying num_epochs to inputs
    sess.run(
        tf.group(tf.global_variables_initializer(),
                 tf.local_variables_initializer()))
    summary_dir = os.path.join(config.train_summaries,
                               config.which_dataset + '_' + dt_stamp)
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Start training loop
    np.save(out_dir + 'meta_info', config)
    step, losses = 0, []
    val_acc = None  # Stays None when validation is disabled
    try:
        while not coord.should_stop():
            start_time = time.time()
            _, loss_value, train_acc = sess.run(
                [train_op, cost, train_accuracy])
            losses.append(loss_value)
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % config.validation_steps == 0:
                if use_validation:
                    # Evaluate only: fetching train_op here would apply an
                    # extra, unlogged gradient update.
                    val_acc = sess.run(val_accuracy)

                # Summaries
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Training status and validation accuracy
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Validation accuracy = %s | logdir = %s')
                print(format_str %
                      (datetime.now(), step,
                       loss_value, config.train_batch / duration,
                       float(duration), train_acc, val_acc, summary_dir))

                # Save a checkpoint at every validation step
                saver.save(sess,
                           os.path.join(config.train_checkpoint,
                                        'model_' + str(step) + '.ckpt'),
                           global_step=step)

            else:
                # Training status
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s')
                print(format_str %
                      (datetime.now(), step, loss_value, config.train_batch /
                       duration, float(duration), train_acc))
            # End iteration
            step += 1

    except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
    finally:
        coord.request_stop()
        np.save(os.path.join(config.tfrecord_dir, 'training_loss'), losses)
    coord.join(threads)
    sess.close()
コード例 #2
0
def train_vgg16(train_dir=None, validation_dir=None):
    """Train a VGG16-style model on GEDI tfrecords with derivative labels.

    Inputs are requested with ``return_GEDI_derivative=True``. Depending on
    ``config.ordinal_classification`` the model either classifies
    ``tf.sign`` of the label (``None``) or regresses the label
    (``'regression'``). Training images containing NaNs are zero-filled and
    masked out of the classification loss. Every ``config.validation_steps``
    steps the loop writes summaries, prints status and saves a checkpoint;
    the per-step losses are saved to ``config.tfrecord_dir`` on exit.

    Args:
        train_dir: Directory holding the normalization meta file and the
            training tfrecord. ``None`` uses ``config.tfrecord_dir``.
        validation_dir: ``None`` uses the validation tfrecord found in
            ``config.tfrecord_dir``; ``False`` disables validation; any other
            value is treated as the directory holding the validation
            tfrecord.

    Raises:
        RuntimeError: If ``config.ordinal_classification`` is neither
            ``None`` nor ``'regression'``.
        AssertionError: If a required input file is missing or the loss
            becomes NaN.
    """
    config = GEDIconfig()
    meta_name = '%s_%s' % (config.tvt_flags[0], config.max_file)
    if train_dir is None:  # Use globals
        train_data = os.path.join(
            config.tfrecord_dir,
            config.tf_record_names['train'])
        meta_data = np.load(os.path.join(config.tfrecord_dir, meta_name))
    else:
        # NOTE(review): assumes the training tfrecord lives next to the meta
        # file inside train_dir -- confirm against the tfrecord writer.
        train_data = os.path.join(train_dir, config.tf_record_names['train'])
        meta_data = np.load(os.path.join(train_dir, meta_name))

    # Prepare image normalization values
    if config.max_gedi is None:
        max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
        if max_value == 0:
            max_value = None
            print('Derived max value is 0')
        else:
            print('Normalizing with empirical max.')
        if 'min_array' in meta_data.keys():
            min_value = np.min(meta_data['min_array']).astype(np.float32)
            print('Normalizing with empirical min.')
        else:
            min_value = None
            print('Not normalizing with a min.')
    else:
        max_value = config.max_gedi
        min_value = config.min_gedi
    ratio = meta_data['ratio']
    if config.encode_time_of_death:
        tod = pd.read_csv(config.encode_time_of_death)
        # `.values` replaces `.as_matrix()`, which was removed from pandas.
        tod_data = tod['dead_tp'].values
        mask = np.isnan(tod_data).astype(int) + (
            tod['plate_well_neuron'] == 'empty').values.astype(int)
        tod_data = tod_data[mask == 0]
        tod_data = tod_data[tod_data > config.mask_timepoint_value]
        config.output_shape = len(np.unique(tod_data))
        # Keyword arguments work on both old and new scikit-learn releases.
        ratio = class_weight.compute_class_weight(
            class_weight='balanced',
            classes=np.sort(np.unique(tod_data)),
            y=tod_data)
        flip_ratio = False
    else:
        flip_ratio = True
    print('Ratio is: %s' % ratio)

    if validation_dir is None:  # Use globals
        validation_data = os.path.join(
            config.tfrecord_dir,
            config.tf_record_names['val'])
    elif validation_dir is False:
        validation_data = False  # Do not use validation data during training
    else:
        # NOTE(review): assumes the validation tfrecord keeps its configured
        # file name inside validation_dir -- confirm.
        validation_data = os.path.join(
            validation_dir,
            config.tf_record_names['val'])
    use_validation = validation_data is not False

    # Make output directories if they do not exist
    dt_stamp = re.split(
        r'\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(
        config.train_checkpoint, dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    dir_list = [
        config.train_checkpoint, config.train_summaries,
        config.results, out_dir]
    for d in dir_list:
        make_dir(d)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Prepare data on CPU
    assert os.path.exists(train_data)
    if use_validation:
        assert os.path.exists(validation_data)
    assert os.path.exists(config.vgg16_weight_path)
    with tf.device('/cpu:0'):
        train_images, train_labels, train_gedi_images = inputs(
            train_data,
            config.train_batch,
            im_shape,
            config.model_image_size[:2],
            max_value=max_value,
            min_value=min_value,
            train=config.data_augmentations,
            num_epochs=config.epochs,
            normalize=config.normalize,
            return_gedi=config.include_GEDI_in_tfrecords,
            return_extra_gfp=config.extra_image,
            return_GEDI_derivative=True)
        if config.include_GEDI_in_tfrecords:
            extra_im_name = 'GEDI at current timepoint'
        else:
            extra_im_name = 'next gfp timepoint'
        tf.summary.image('train images', train_images)
        tf.summary.image('train %s' % extra_im_name, train_gedi_images)
        if use_validation:
            val_images, val_labels, val_gedi_images = inputs(
                validation_data,
                config.validation_batch,
                im_shape,
                config.model_image_size[:2],
                max_value=max_value,
                min_value=min_value,
                num_epochs=config.epochs,
                normalize=config.normalize,
                return_gedi=config.include_GEDI_in_tfrecords,
                return_extra_gfp=config.extra_image,
                return_GEDI_derivative=True)
            tf.summary.image('validation images', val_images)
            tf.summary.image(
                'validation %s' % extra_im_name, val_gedi_images)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn') as scope:
            if config.ordinal_classification is None:
                vgg_output = 2  # Sign of derivative (inf norm)
                # NOTE(review): tf.sign yields values in {-1, 0, 1} while the
                # head has 2 outputs -- confirm the label range.
                train_labels = tf.cast(tf.sign(train_labels), tf.int32)
                if use_validation:
                    val_labels = tf.cast(tf.sign(val_labels), tf.int32)
            elif config.ordinal_classification == 'regression':
                vgg_output = 1
            else:
                raise RuntimeError(
                    'config.ordinal_classification must be None (sign '
                    'classification) or "regression".')
            vgg = vgg16.model_struct()
            train_mode = tf.get_variable(name='training', initializer=True)

            # Mask NAN images from loss: count NaNs per image, keep images
            # with none, then zero-fill NaNs so the forward pass stays finite
            image_nan = tf.reduce_sum(
                tf.cast(tf.is_nan(train_images), tf.float32),
                reduction_indices=[1, 2, 3])
            image_mask = tf.cast(tf.equal(image_nan, 0.), tf.float32)
            train_images = tf.where(
                tf.is_nan(train_images),
                tf.zeros_like(train_images),
                train_images)
            train_gedi_images = tf.where(
                tf.is_nan(train_gedi_images),
                tf.zeros_like(train_gedi_images),
                train_gedi_images)
            # Repeat the images three times along the channel axis
            train_images = tf.concat(
                [train_images, train_images, train_images], axis=3)
            if use_validation:
                val_images = tf.concat(
                    [val_images, val_images, val_images], axis=3)
            vgg.build(
                train_images, output_shape=vgg_output,
                train_mode=train_mode, batchnorm=config.batchnorm_layers)
            # Prepare the cost function
            if config.ordinal_classification is None:
                cost = softmax_cost(
                    vgg.fc8,
                    train_labels,
                    mask=image_mask)
            elif config.ordinal_classification == 'regression':
                cost = tf.nn.l2_loss(tf.squeeze(vgg.fc8) - train_labels)

            tf.summary.scalar("cce cost", cost)

            # Weight decay
            if config.wd_layers is not None:
                _, l2_wd_layers = fine_tune_prepare_layers(
                    tf.trainable_variables(), config.wd_layers)
                l2_wd_layers = [
                    x for x in l2_wd_layers if 'biases' not in x.name]
                # tf.add_n rejects an empty list, so skip when nothing matched
                if l2_wd_layers:
                    cost += (config.wd_penalty * tf.add_n(
                        [tf.nn.l2_loss(x) for x in l2_wd_layers]))

            # Optimize
            train_op = tf.train.AdamOptimizer(config.new_lr).minimize(cost)
            if config.ordinal_classification is None:
                train_accuracy = class_accuracy(
                    vgg.prob, train_labels)  # training accuracy
            elif config.ordinal_classification == 'regression':
                # For regression the "accuracy" reported is the L2 loss
                train_accuracy = tf.nn.l2_loss(
                    tf.squeeze(vgg.fc8) - train_labels)
            tf.summary.scalar("training accuracy", train_accuracy)

            # Setup validation op
            if use_validation:
                scope.reuse_variables()
                # Validation graph is the same as training except no batchnorm
                val_vgg = vgg16.model_struct(
                    fine_tune_layers=config.fine_tune_layers)
                val_vgg.build(val_images, output_shape=vgg_output)
                # Calculate validation accuracy
                if config.ordinal_classification is None:
                    val_accuracy = class_accuracy(val_vgg.prob, val_labels)
                elif config.ordinal_classification == 'regression':
                    val_accuracy = tf.nn.l2_loss(
                        tf.squeeze(val_vgg.fc8) - val_labels)
                tf.summary.scalar("validation accuracy", val_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(
        tf.global_variables(), max_to_keep=config.keep_checkpoints)
    summary_op = tf.summary.merge_all()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Need to initialize both of these if supplying num_epochs to inputs
    sess.run(tf.group(tf.global_variables_initializer(),
             tf.local_variables_initializer()))
    summary_dir = os.path.join(
        config.train_summaries, config.which_dataset + '_' + dt_stamp)
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Restore model if requested
    if config.restore_path is not None:
        print('-' * 60)
        print('Restoring from a previous model: %s' % config.restore_path)
        print('-' * 60)
        saver.restore(sess, config.restore_path)

    # Start training loop
    np.save(out_dir + 'meta_info', config)
    step, losses = 0, []
    val_acc = None  # Stays None when validation is disabled

    try:
        while not coord.should_stop():
            start_time = time.time()
            _, loss_value, train_acc = sess.run(
                [train_op, cost, train_accuracy])
            losses.append(loss_value)
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
            if step % config.validation_steps == 0:
                if use_validation:
                    # Evaluate only: fetching train_op here would apply an
                    # extra, unlogged gradient update.
                    val_acc = sess.run(val_accuracy)

                # Summaries
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Training status and validation accuracy; the 0. entries
                # are placeholders for metrics this script does not compute
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; '
                    '%.3f sec/batch) | Training accuracy = %s | '
                    'Training %s = %s | Training class loss = %s | '
                    'Validation accuracy = %s | Validation %s = %s | '
                    'logdir = %s')
                print(format_str % (
                    datetime.now(), step, loss_value,
                    config.train_batch / duration, float(duration),
                    train_acc, extra_im_name, 0.,
                    0., val_acc, extra_im_name,
                    0., summary_dir))

                # Save a checkpoint at every validation step
                saver.save(
                    sess, os.path.join(
                        config.train_checkpoint,
                        'model_' + str(step) + '.ckpt'), global_step=step)

            else:
                # Training status
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Training %s = %s | Training class loss = %s')
                print(format_str % (datetime.now(), step, loss_value,
                                    config.train_batch / duration,
                                    float(duration), train_acc,
                                    extra_im_name, 0.,
                                    0.))
            # End iteration
            step += 1

    except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
    finally:
        coord.request_stop()
        np.save(os.path.join(config.tfrecord_dir, 'training_loss'), losses)
    coord.join(threads)
    sess.close()
コード例 #3
0
ファイル: train_matching.py プロジェクト: jdlamstein/GEDI3
def train_model(train_dir=None, validation_dir=None):
    """Train the two-frame matching classifier and checkpoint it periodically.

    Two image frames are each passed through a GEDI VGG16 network (its
    ``fc7`` activations are used) and through a ``matching_gedi`` network.
    The two per-frame vectors are concatenated or subtracted, depending on
    ``config.matching_combine``, and fed to a 2-way linear softmax layer
    that is optimized with Adam.

    Args:
        train_dir: Directory holding the training tfrecords and the
            normalization meta file. ``None`` uses ``config.tfrecord_dir``.
        validation_dir: ``None`` reads the validation tfrecords from
            ``config.tfrecord_dir``; ``False`` disables validation; any
            other value is treated as the directory holding the validation
            tfrecords.

    Raises:
        AssertionError: If a required input path does not exist or the
            training loss becomes NaN.
        RuntimeError: If ``config.matching_combine`` is neither
            ``'concatenate'`` nor ``'subtract'``.

    Side effects:
        Creates the checkpoint/summary/result directories, writes
        TensorBoard summaries and model checkpoints, saves the config as
        ``meta_info`` in the run's result directory, and saves the list of
        per-step losses as ``training_loss`` in ``config.tfrecord_dir``.
    """
    config = GEDIconfig()

    # The tfrecords and the meta file are looked up in the same directory,
    # so `train_data` is always bound (it used to be defined only when
    # `train_dir` was None, which made the later path check a NameError).
    if train_dir is None:  # Use globals
        data_root = config.tfrecord_dir
    else:
        data_root = train_dir
    train_data = os.path.join(data_root, config.tf_record_names['train'])
    meta_data = np.load(
        os.path.join(data_root,
                     '%s_%s' % (config.tvt_flags[0], config.max_file)))

    # Prepare image normalization values
    if config.max_gedi is None:
        max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
        if max_value == 0:
            max_value = None
            print('Derived max value is 0')
        else:
            print('Normalizing with empirical max.')
        if 'min_array' in meta_data.keys():
            min_value = np.min(meta_data['min_array']).astype(np.float32)
            print('Normalizing with empirical min.')
        else:
            min_value = None
            print('Not normalizing with a min.')
    else:
        max_value = config.max_gedi
        min_value = config.min_gedi
    ratio = meta_data['ratio']
    print('Ratio is: %s' % ratio)

    # Resolve the validation set. `False` is the explicit "no validation"
    # marker that the rest of this function tests for.
    if validation_dir is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir,
                                       config.tf_record_names['val'])
    elif validation_dir is False:
        validation_data = False  # Do not use validation data during training
    else:
        validation_data = os.path.join(validation_dir,
                                       config.tf_record_names['val'])
    use_validation = validation_data is not False

    # Make output directories if they do not exist
    dt_stamp = re.split(
        r'\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(config.train_checkpoint,
                                           dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    for directory in (config.train_checkpoint, config.train_summaries,
                      config.results, out_dir):
        make_dir(directory)
    # im_shape = get_image_size(config)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Fail early on missing inputs.
    assert os.path.exists(train_data), train_data
    if use_validation:
        assert os.path.exists(validation_data), validation_data
    # NOTE(review): the models below load config.gedi_weight_path, not
    # config.vgg16_weight_path -- confirm which path should be checked.
    assert os.path.exists(config.vgg16_weight_path)
    if config.matching_combine not in ('concatenate', 'subtract'):
        raise RuntimeError(
            'Unsupported matching_combine: %s' % config.matching_combine)

    def combine_frames(first, second):
        """Merge two per-frame vectors as set by config.matching_combine."""
        if config.matching_combine == 'concatenate':
            return tf.concat([first, second], axis=-1)
        if config.matching_combine == 'subtract':
            return first - second
        raise NotImplementedError

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        train_images_0, train_images_1, train_labels, train_times = inputs(
            train_data,
            config.train_batch,
            im_shape,
            config.model_image_size[:2],
            max_value=max_value,
            min_value=min_value,
            train=config.data_augmentations,
            num_epochs=config.epochs,
            normalize=config.normalize,
            return_filename=True)
        tf.summary.image('train image frame 0', train_images_0)
        tf.summary.image('train image frame 1', train_images_1)
        if use_validation:
            val_images_0, val_images_1, val_labels, val_times = inputs(
                validation_data,
                config.validation_batch,
                im_shape,
                config.model_image_size[:2],
                max_value=max_value,
                min_value=min_value,
                num_epochs=config.epochs,
                normalize=config.normalize,
                return_filename=True)
            tf.summary.image('validation image frame 0', val_images_0)
            tf.summary.image('validation image frame 1', val_images_1)

    # Prepare model on GPU
    with tf.device('/gpu:0'):
        with tf.variable_scope('gedi'):
            # Build training GEDI model for frame 0
            vgg_train_mode = tf.get_variable(name='vgg_training',
                                             initializer=False)
            gedi_model_0 = vgg16.model_struct(
                vgg16_npy_path=config.gedi_weight_path, trainable=False)
            gedi_model_0.build(prep_images_for_gedi(train_images_0),
                               output_shape=2,
                               train_mode=vgg_train_mode)
            gedi_scores_0 = gedi_model_0.fc7

        with tf.variable_scope('match'):
            # Build matching model for frame 0
            model_0 = matching_gedi.model_struct()
            model_0.build(train_images_0)

            # Build frame 0 vector
            frame_0 = tf.concat([gedi_scores_0, model_0.output], axis=-1)

        # Build GEDI model for frame 1 (shares the frame-0 variables)
        with tf.variable_scope('gedi', reuse=True):
            gedi_model_1 = vgg16.model_struct(
                vgg16_npy_path=config.gedi_weight_path, trainable=False)
            gedi_model_1.build(prep_images_for_gedi(train_images_1),
                               output_shape=2,
                               train_mode=vgg_train_mode)
            gedi_scores_1 = gedi_model_1.fc7

        with tf.variable_scope('match', reuse=True):
            # Build matching model for frame 1
            model_1 = matching_gedi.model_struct()
            model_1.build(train_images_1)

        # Build frame 1 vector
        frame_1 = tf.concat([gedi_scores_1, model_1.output], axis=-1)

        with tf.variable_scope('output'):
            # Concatenate or subtract
            output_scores = combine_frames(frame_0, frame_1)

            # Build output layer
            output_shape = [int(output_scores.get_shape()[-1]), 2]
            output_weights = tf.get_variable(
                name='output_weights',
                shape=output_shape,
                initializer=tf.contrib.layers.xavier_initializer(
                    uniform=False))
            output_bias = tf.get_variable(name='output_bias',
                                          initializer=tf.truncated_normal(
                                              [output_shape[-1]], .0, .001))
            decision_logits = tf.nn.bias_add(
                tf.matmul(output_scores, output_weights), output_bias)
            train_soft_decisions = tf.nn.softmax(decision_logits)
            cost = softmax_cost(decision_logits, train_labels)
            # The summary records the classification loss alone; the
            # regularization terms are added to `cost` afterwards.
            tf.summary.scalar("cce loss", cost)
            cost += tf.nn.l2_loss(output_weights)

            # Weight decay
            if config.wd_layers is not None:
                _, l2_wd_layers = fine_tune_prepare_layers(
                    tf.trainable_variables(), config.wd_layers)
                l2_wd_layers = [
                    x for x in l2_wd_layers if 'biases' not in x.name
                ]
                if len(l2_wd_layers) > 0:
                    cost += (config.wd_penalty *
                             tf.add_n([tf.nn.l2_loss(x)
                                       for x in l2_wd_layers]))

        # Optimize
        train_op = tf.train.AdamOptimizer(config.new_lr).minimize(cost)
        train_accuracy = class_accuracy(train_soft_decisions,
                                        train_labels)  # training accuracy
        tf.summary.scalar("training accuracy", train_accuracy)

        # Setup validation op
        if use_validation:
            with tf.variable_scope('gedi', reuse=tf.AUTO_REUSE):  # FIX THIS
                # Original author's note: validation graph is the same as
                # training except no batchnorm.
                # NOTE(review): the same vgg_train_mode variable is passed
                # here as for training -- confirm the note still holds.
                val_gedi_model_0 = vgg16.model_struct(
                    vgg16_npy_path=config.gedi_weight_path)
                val_gedi_model_0.build(prep_images_for_gedi(val_images_0),
                                       output_shape=2,
                                       train_mode=vgg_train_mode)
                val_gedi_scores_0 = val_gedi_model_0.fc7

                # Build GEDI model for frame 1
                val_gedi_model_1 = vgg16.model_struct(
                    vgg16_npy_path=config.gedi_weight_path)
                val_gedi_model_1.build(prep_images_for_gedi(val_images_1),
                                       output_shape=2,
                                       train_mode=vgg_train_mode)
                val_gedi_scores_1 = val_gedi_model_1.fc7

            with tf.variable_scope('match', reuse=tf.AUTO_REUSE):
                # Build matching model for frame 0
                val_model_0 = matching_gedi.model_struct()
                val_model_0.build(val_images_0)

                # Build matching model for frame 1
                val_model_1 = matching_gedi.model_struct()
                val_model_1.build(val_images_1)

            # Build frame 0 and frame 1 vectors
            val_frame_0 = tf.concat([val_gedi_scores_0, val_model_0.output],
                                    axis=-1)
            val_frame_1 = tf.concat([val_gedi_scores_1, val_model_1.output],
                                    axis=-1)

            # Concatenate or subtract
            val_output_scores = combine_frames(val_frame_0, val_frame_1)

            with tf.variable_scope('output', reuse=tf.AUTO_REUSE):
                # Reuse the *trained* output layer. The variable names must
                # match the training ones ('output_weights', 'output_bias');
                # any other name makes AUTO_REUSE create a new, randomly
                # initialized variable that is never trained.
                val_output_weights = tf.get_variable(
                    name='output_weights',
                    shape=output_shape,
                    trainable=False,
                    initializer=tf.contrib.layers.xavier_initializer(
                        uniform=False))
                val_output_bias = tf.get_variable(
                    name='output_bias',
                    trainable=False,
                    initializer=tf.truncated_normal([output_shape[-1]], .0,
                                                    .001))
                val_decision_logits = tf.nn.bias_add(
                    tf.matmul(val_output_scores, val_output_weights),
                    val_output_bias)
                val_soft_decisions = tf.nn.softmax(val_decision_logits)

            # Calculate validation accuracy
            val_accuracy = class_accuracy(val_soft_decisions, val_labels)
            tf.summary.scalar("validation accuracy", val_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=config.keep_checkpoints)
    summary_op = tf.summary.merge_all()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Need to initialize both of these if supplying num_epochs to inputs
    sess.run(
        tf.group(tf.global_variables_initializer(),
                 tf.local_variables_initializer()))
    summary_dir = os.path.join(config.train_summaries,
                               config.which_dataset + '_' + dt_stamp)
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Start training loop
    np.save(out_dir + 'meta_info', config)
    step, losses = 0, []  # val_max = 0
    try:
        while not coord.should_stop():
            start_time = time.time()
            # Validation accuracy is evaluated only on validation steps
            # (below); fetching it here as well would pull an extra
            # validation batch on every training step for a value that is
            # never used.
            _, loss_value, train_acc = sess.run(
                [train_op, cost, train_accuracy])
            losses.append(loss_value)
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model loss = NaN'

            if step % config.validation_steps == 0:
                if use_validation:
                    val_acc = sess.run(val_accuracy)
                else:
                    val_acc = None  # Nothing to report without validation

                # Summaries
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Training status and validation accuracy
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Validation accuracy = %s | '
                              'logdir = %s')
                print(format_str %
                      (datetime.now(), step,
                       loss_value, config.train_batch / duration,
                       float(duration), train_acc, val_acc, summary_dir))

                # A checkpoint is stored on every validation step; the
                # "best so far" gate (val_acc >= val_max) is disabled.
                saver.save(sess,
                           os.path.join(config.train_checkpoint,
                                        'model_' + str(step) + '.ckpt'),
                           global_step=step)

            else:
                # Training status. The tuple must supply one value per
                # placeholder (seven), including the training accuracy.
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Training loss = %s')
                print(format_str %
                      (datetime.now(), step, loss_value, config.train_batch /
                       duration, float(duration), train_acc, loss_value))
            # End iteration
            step += 1

    except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
    finally:
        coord.request_stop()
        np.save(os.path.join(config.tfrecord_dir, 'training_loss'), losses)
    coord.join(threads)
    sess.close()
コード例 #4
0
ファイル: train_vgg16_lstm.py プロジェクト: jdlamstein/GEDI3
def train_vgg16(train_dir=None, validation_dir=None):
    """Build the VGG16 + LSTM training graph (unfinished scaffolding).

    Loads the normalization meta data, creates the timestamped output
    directories, builds the tfrecord input pipelines on the CPU and then
    sets up a ``tf.while_loop`` that is meant to run VGG16 followed by an
    LSTM cell once per timepoint (``len(config.channel)`` iterations),
    summing the per-timepoint losses for the optimizer.

    NOTE(review): this function cannot run to completion as written. The
    LSTM loop reads ``images``, ``label``, ``state``, ``vgg``,
    ``output_shape`` and ``batchnorm_layers``, none of which is bound in
    this function, so it raises ``NameError`` unless they exist at module
    level. No session or training loop is set up here.

    Args:
        train_dir: Directory holding ``train.tfrecords`` and the
            normalization meta file. ``None`` uses ``config.tfrecord_dir``.
        validation_dir: ``None`` reads ``val.tfrecords`` from
            ``config.tfrecord_dir``; ``False`` disables validation; any
            other value is treated as the directory holding
            ``val.tfrecords``.
    """
    config = GEDIconfig()

    # The tfrecords and the meta file are looked up in the same directory,
    # so `train_data` is bound on both paths.
    if train_dir is None:  # Use globals
        data_root = config.tfrecord_dir
    else:
        data_root = train_dir
    train_data = os.path.join(data_root, 'train.tfrecords')
    meta_data = np.load(
        os.path.join(
            data_root, config.tvt_flags[0] + '_' + config.max_file))

    # Prepare image normalization values
    max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
    if max_value == 0:
        max_value = None
        print('Derived max value is 0')
    else:
        print('Normalizing with empirical max.')
    if 'min_array' in meta_data.keys():
        min_value = np.min(meta_data['min_array']).astype(np.float32)
        print('Normalizing with empirical min.')
    else:
        min_value = None
        print('Not normalizing with a min.')
    ratio = meta_data['ratio']
    print('Ratio is: %s' % ratio)

    # `False` is the explicit "no validation" marker tested further down.
    if validation_dir is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir, 'val.tfrecords')
    elif validation_dir is False:
        validation_data = False  # Do not use validation data during training
    else:
        validation_data = os.path.join(validation_dir, 'val.tfrecords')

    # Make output directories if they do not exist
    dt_stamp = re.split(
        r'\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(
        config.train_checkpoint, dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    for directory in (config.train_checkpoint, config.train_summaries,
                      config.results, out_dir):
        make_dir(directory)
    # im_shape = get_image_size(config)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Prepare data on CPU
    with tf.device('/cpu:0'):
        train_images, train_labels = inputs(
            train_data,
            config.train_batch,
            im_shape,
            config.model_image_size[:2],
            max_value=max_value,
            min_value=min_value,
            train=config.data_augmentations,
            num_epochs=config.epochs,
            normalize=config.normalize)
        tf.summary.image('train images', train_images)
        if validation_data is not False:
            val_images, val_labels = inputs(
                validation_data,
                config.validation_batch,
                im_shape,
                config.model_image_size[:2],
                max_value=max_value,
                min_value=min_value,
                num_epochs=config.epochs,
                normalize=config.normalize)
            tf.summary.image('validation images', val_images)

    # Prepare model on GPU
    num_timepoints = len(config.channel)
    with tf.device('/gpu:0'):
        with tf.variable_scope('cnn') as scope:

            # Prepare the loss function
            if config.balance_cost:
                def cost_fun(yhat, y):
                    return softmax_cost(yhat, y, ratio=ratio)
            else:
                def cost_fun(yhat, y):
                    return softmax_cost(yhat, y)

            def cond(i, *unused_loop_vars):
                # tf.while_loop calls cond with every loop variable; only
                # the iteration counter decides whether to continue.
                return tf.less(i, num_timepoints)

            def body(
                    i,
                    images,
                    label,
                    cell,
                    state,
                    output,
                    loss,
                    vgg,
                    train_mode,
                    output_shape,
                    batchnorm_layers,
                    cost_fun,
                    score_layer='fc7'):
                # One timepoint: VGG16 on frame i, then one LSTM step.
                vgg.build(
                    images[i],
                    output_shape=config.output_shape,
                    train_mode=train_mode,
                    batchnorm=config.batchnorm_layers)
                it_output, state = cell(vgg[score_layer], state)
                output = output.write(i, it_output)
                # Score this timepoint's activity; `output` is the
                # TensorArray accumulator, not a tensor of predictions.
                it_loss = cost_fun(it_output, label)
                loss = loss.write(i, it_loss)
                # Return one value per loop variable, in the same order.
                return (i + 1, images, label, cell, state, output, loss,
                        vgg, train_mode, output_shape, batchnorm_layers,
                        cost_fun)

            # Prepare LSTM loop
            train_mode = tf.get_variable(name='training', initializer=True)
            cell = lstm_layers(layer_units=config.lstm_units)
            # Output array for LSTM loop -- one slot per timepoint
            output = tf.TensorArray(tf.float32, num_timepoints)
            # Loss for LSTM loop -- one slot per timepoint
            loss = tf.TensorArray(tf.float32, num_timepoints)
            i = tf.constant(0)  # iterator for LSTM loop
            # NOTE(review): images, label, state, vgg, output_shape and
            # batchnorm_layers are not bound anywhere in this function.
            # NOTE(review): cell, vgg and cost_fun are Python objects;
            # presumably they should be captured by closure rather than
            # passed as tf.while_loop loop variables -- confirm.
            loop_vars = [i, images, label, cell, state, output, loss, vgg,
                         train_mode, output_shape, batchnorm_layers,
                         cost_fun]

            # Run the lstm
            processed_list = tf.while_loop(
                cond=cond,
                body=body,
                loop_vars=loop_vars,
                back_prop=True,
                swap_memory=False)
            output_cell = processed_list[3]
            output_state = processed_list[4]
            output_activity = processed_list[5]
            cost = processed_list[6]

            # Optimize. `cost` is the TensorArray of per-timepoint losses;
            # stack it into a tensor before reducing to a scalar.
            combined_cost = tf.reduce_sum(cost.stack())
            tf.summary.scalar('cost', combined_cost)
            # TODO: Need to make sure lstm weights are being trained
            other_opt_vars, ft_opt_vars = fine_tune_prepare_layers(
                tf.trainable_variables(), config.fine_tune_layers)
            # The optimizer gets the summed loss, not the TensorArray.
            if config.optimizer == 'adam':
                train_op = ft_non_optimized(
                    combined_cost, other_opt_vars, ft_opt_vars,
                    tf.train.AdamOptimizer, config.hold_lr, config.new_lr)
            elif config.optimizer == 'sgd':
                train_op = ft_non_optimized(
                    combined_cost, other_opt_vars, ft_opt_vars,
                    tf.train.GradientDescentOptimizer,
                    config.hold_lr, config.new_lr)

            train_accuracy = class_accuracy(
                vgg.prob, train_labels)  # training accuracy
            tf.summary.scalar("training accuracy", train_accuracy)

            # Setup validation op
            if validation_data is not False:
                scope.reuse_variables()
                # Original author's note: validation graph is the same as
                # training except no batchnorm.
                val_vgg = vgg16.Vgg16(
                    vgg16_npy_path=config.vgg16_weight_path,
                    fine_tune_layers=config.fine_tune_layers)
                val_vgg.build(val_images, output_shape=config.output_shape)
                # Calculate validation accuracy
                val_accuracy = class_accuracy(val_vgg.prob, val_labels)
                tf.summary.scalar("validation accuracy", val_accuracy)
コード例 #5
0
def train_model(train_dir=None,
                validation_dir=None,
                debug=True,
                resume_ckpt=None,
                resume_meta=None,
                margin=.4):
    config = GEDIconfig()
    if resume_meta is not None:
        config = np.load(resume_meta).item()
    assert margin is not None, 'Need a margin for the loss.'
    if train_dir is None:  # Use globals
        train_data = os.path.join(config.tfrecord_dir,
                                  config.tf_record_names['train'])
        meta_data = np.load(
            os.path.join(config.tfrecord_dir,
                         '%s_%s' % (config.tvt_flags[0], config.max_file)))
    else:
        meta_data = np.load(
            os.path.join(train_dir,
                         '%s_%s' % (config.tvt_flags[0], config.max_file)))

    # Prepare image normalization values
    if config.max_gedi is None:
        max_value = np.nanmax(meta_data['max_array']).astype(np.float32)
        if max_value == 0:
            max_value = None
            print('Derived max value is 0')
        else:
            print('Normalizing with empirical max.')
        if 'min_array' in meta_data.keys():
            # min_value = np.min(meta_data['min_array']).astype(np.float32)
            print('Normalizing with empirical min.')
        else:
            # min_value = None
            print('Not normalizing with a min.')
    else:
        max_value = config.max_gedi
        # min_value = config.min_gedi
    ratio = meta_data['ratio']
    print('Ratio is: %s' % ratio)

    if validation_dir is None:  # Use globals
        validation_data = os.path.join(config.tfrecord_dir,
                                       config.tf_record_names['val'])
    elif validation_dir is False:
        pass  # Do not use validation data during training

    # Make output directories if they do not exist
    dt_stamp = re.split(
        '\.', str(datetime.now()))[0].\
        replace(' ', '_').replace(':', '_').replace('-', '_')
    dt_dataset = config.which_dataset + '_' + dt_stamp + '/'
    config.train_checkpoint = os.path.join(config.train_checkpoint,
                                           dt_dataset)  # timestamp this run
    out_dir = os.path.join(config.results, dt_dataset)
    dir_list = [
        config.train_checkpoint, config.train_summaries, config.results,
        out_dir
    ]
    [tf_fun.make_dir(d) for d in dir_list]
    # im_shape = get_image_size(config)
    im_shape = config.gedi_image_size

    print('-' * 60)
    print('Training model:' + dt_dataset)
    print('-' * 60)

    # Prepare data on CPU
    assert os.path.exists(train_data)
    assert os.path.exists(validation_data)
    assert os.path.exists(config.vgg16_weight_path)
    with tf.device('/cpu:0'):
        train_images, train_labels, train_times = inputs(
            train_data,
            config.train_batch,
            im_shape,
            config.model_image_size,
            # max_value=max_value,
            # min_value=min_value,
            train=config.data_augmentations,
            num_epochs=config.epochs,
            normalize=config.normalize,
            return_filename=True)
        val_images, val_labels, val_times = inputs(
            validation_data,
            config.validation_batch,
            im_shape,
            config.model_image_size,
            # max_value=max_value,
            # min_value=min_value,
            num_epochs=config.epochs,
            normalize=config.normalize,
            return_filename=True)
        train_image_list, val_image_list = [], []
        for idx in range(int(train_images.get_shape()[1])):
            train_image_list += [tf.gather(train_images, idx, axis=1)]
            val_image_list += [tf.gather(val_images, idx, axis=1)]
            tf.summary.image('train_image_frame_%s' % idx,
                             train_image_list[idx])
            tf.summary.image('validation_image_frame_%s' % idx,
                             val_image_list[idx])

    # Prepare model on GPU
    config.l2_norm = False
    config.norm_axis = 0
    config.dist_fun = 'pearson'
    config.per_batch = False
    config.include_GEDI = True  # False
    config.output_shape = 32
    config.margin = margin
    with tf.device('/gpu:0'):
        with tf.variable_scope('match'):
            # Build matching model for frame 0
            model_0 = vgg16.model_struct(
                vgg16_npy_path=config.gedi_weight_path)  # ,
            frame_activity = []
            model_activity = model_0.build(train_image_list[0],
                                           output_shape=config.output_shape,
                                           include_GEDI=config.include_GEDI)
            if config.l2_norm:
                model_activity = tf_fun.l2_normalize(model_activity,
                                                     axis=config.norm_axis)
            frame_activity += [model_activity]

        with tf.variable_scope('match', reuse=tf.AUTO_REUSE):
            # Build matching model for other frames
            for idx in range(1, len(train_image_list)):
                model_activity = model_0.build(
                    train_image_list[idx],
                    output_shape=config.output_shape,
                    include_GEDI=config.include_GEDI)
                if config.l2_norm:
                    model_activity = tf_fun.l2_normalize(model_activity,
                                                         axis=config.norm_axis)
                frame_activity += [model_activity]

        if config.dist_fun == 'l2':
            pos = tf_fun.l2_dist(frame_activity[0], frame_activity[1], axis=1)
            neg = tf_fun.l2_dist(frame_activity[0], frame_activity[2], axis=1)
        elif config.dist_fun == 'pearson':
            pos = tf_fun.pearson_dist(frame_activity[0],
                                      frame_activity[1],
                                      axis=1)
            neg = tf_fun.pearson_dist(frame_activity[0],
                                      frame_activity[2],
                                      axis=1)
        else:
            raise NotImplementedError(config.dist_fun)
        if config.per_batch:
            loss = tf.maximum(tf.reduce_mean(pos - neg) + margin, 0.)
        else:
            loss = tf.reduce_mean(tf.maximum(pos - neg + margin, 0.))
        tf.summary.scalar('Triplet_loss', loss)

        # Weight decay
        if config.wd_layers is not None:
            _, l2_wd_layers = tf_fun.fine_tune_prepare_layers(
                tf.trainable_variables(), config.wd_layers)
            l2_wd_layers = [x for x in l2_wd_layers if 'biases' not in x.name]
            if len(l2_wd_layers) > 0:
                loss += (config.wd_penalty *
                         tf.add_n([tf.nn.l2_loss(x) for x in l2_wd_layers]))

        # Optimize
        train_op = tf.train.AdamOptimizer(config.new_lr).minimize(loss)
        train_accuracy = tf.reduce_mean(
            tf.cast(
                tf.equal(
                    tf.nn.relu(tf.sign(neg - pos)),  # 1 if pos < neg
                    tf.cast(tf.ones_like(train_labels), tf.float32)),
                tf.float32))
        tf.summary.scalar('training_accuracy', train_accuracy)

        # Setup validation op
        if validation_data is not False:
            with tf.variable_scope('match', tf.AUTO_REUSE) as match:
                # Build matching model for frame 0
                match.reuse_variables()
                val_model_0 = vgg16.model_struct(
                    vgg16_npy_path=config.gedi_weight_path)
                val_frame_activity = []
                model_activity = val_model_0.build(
                    val_image_list[0],
                    output_shape=config.output_shape,
                    include_GEDI=config.include_GEDI)
                if config.l2_norm:
                    model_activity = tf_fun.l2_normalize(model_activity,
                                                         axis=config.norm_axis)
                val_frame_activity += [model_activity]

                # Build matching model for other frames
                for idx in range(1, len(train_image_list)):
                    model_activity = val_model_0.build(
                        val_image_list[idx],
                        output_shape=config.output_shape,
                        include_GEDI=config.include_GEDI)
                    if config.l2_norm:
                        model_activity = tf_fun.l2_normalize(
                            model_activity, axis=config.norm_axis)
                    val_frame_activity += [model_activity]
            if config.dist_fun == 'l2':
                val_pos = tf_fun.l2_dist(val_frame_activity[0],
                                         val_frame_activity[1],
                                         axis=1)
                val_neg = tf_fun.l2_dist(val_frame_activity[0],
                                         val_frame_activity[2],
                                         axis=1)
            elif config.dist_fun == 'pearson':
                val_pos = tf_fun.pearson_dist(val_frame_activity[0],
                                              val_frame_activity[1],
                                              axis=1)
                val_neg = tf_fun.pearson_dist(val_frame_activity[0],
                                              val_frame_activity[2],
                                              axis=1)
            if config.per_batch:
                val_loss = tf.maximum(
                    tf.reduce_mean(val_pos - val_neg) + margin, 0.)
            else:
                val_loss = tf.reduce_mean(
                    tf.maximum(val_pos - val_neg + margin, 0.))
            tf.summary.scalar('Validation_triplet_loss', val_loss)

        # Calculate validation accuracy
        val_accuracy = tf.reduce_mean(
            tf.cast(
                tf.equal(tf.nn.relu(tf.sign(val_neg - val_pos)),
                         tf.cast(tf.ones_like(val_labels), tf.float32)),
                tf.float32))
        tf.summary.scalar('val_accuracy', val_accuracy)

    # Set up summaries and saver
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=config.keep_checkpoints)
    summary_op = tf.summary.merge_all()

    # Initialize the graph
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Need to initialize both of these if supplying num_epochs to inputs
    sess.run(
        tf.group(tf.global_variables_initializer(),
                 tf.local_variables_initializer()))
    summary_dir = os.path.join(config.train_summaries,
                               config.which_dataset + '_' + dt_stamp)
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Set up exemplar threading
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Train operations
    train_dict = {
        'train_op': train_op,
        'loss': loss,
        'pos': pos,
        'neg': neg,
        'train_accuracy': train_accuracy,
        'val_accuracy': val_accuracy,
    }
    val_dict = {'val_accuracy': val_accuracy}
    if debug:
        for idx in range(len(train_image_list)):
            train_dict['train_im_%s' % idx] = train_image_list[idx]
        for idx in range(len(val_image_list)):
            val_dict['val_im_%s' % idx] = val_image_list[idx]

    # Resume training if requested
    if resume_ckpt is not None:
        print('*' * 50)
        print('Resuming training from: %s' % resume_ckpt)
        print('*' * 50)
        saver.restore(sess, resume_ckpt)

    # Start training loop
    np.save(out_dir + 'meta_info', config)
    step, losses = 0, []
    top_vals = np.asarray([0])
    try:
        # print response
        while not coord.should_stop():
            start_time = time.time()
            train_values = sess.run(train_dict.values())
            it_train_dict = {
                k: v
                for k, v in zip(train_dict.keys(), train_values)
            }
            losses += [it_train_dict['loss']]
            duration = time.time() - start_time
            if np.isnan(it_train_dict['loss']).sum():
                assert not np.isnan(it_train_dict['loss']),\
                    'Model loss = NaN'

            if step % config.validation_steps == 0:
                if validation_data is not False:
                    val_accs = []
                    for vit in range(config.validation_iterations):
                        val_values = sess.run(val_dict.values())
                        it_val_dict = {
                            k: v
                            for k, v in zip(val_dict.keys(), val_values)
                        }
                        val_accs += [it_val_dict['val_accuracy']]
                    val_acc = np.nanmean(val_accs)
                else:
                    val_acc -= 1  # Store every checkpoint

                # Summaries
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Training status and validation accuracy
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | '
                              'Validation accuracy = %s | '
                              'logdir = %s')
                print(format_str %
                      (datetime.now(), step,
                       it_train_dict['loss'], config.train_batch / duration,
                       float(duration), it_train_dict['train_accuracy'],
                       it_train_dict['val_accuracy'], summary_dir))

                # Save the model checkpoint if it's the best yet
                top_vals = top_vals[:config.num_keep_checkpoints]
                check_val = val_acc > top_vals
                if check_val.sum():
                    saver.save(sess,
                               os.path.join(config.train_checkpoint,
                                            'model_' + str(step) + '.ckpt'),
                               global_step=step)
                    # Store the new validation accuracy
                    top_vals = np.append(top_vals, val_acc)

            else:
                # Training status
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; '
                              '%.3f sec/batch) | Training accuracy = %s | ')
                print(format_str %
                      (datetime.now(), step,
                       it_train_dict['loss'], config.train_batch / duration,
                       float(duration), it_train_dict['train_accuracy']))
            # End iteration
            step += 1

    except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (config.epochs, step))
    finally:
        coord.request_stop()
        np.save(os.path.join(config.tfrecord_dir, 'training_loss'), losses)
    coord.join(threads)
    sess.close()