from typing import Dict

import numpy as np
import tensorflow as tf
from tensorflow.python.framework.errors_impl import InvalidArgumentError

# Project-specific names used below (TinyImageNetPipeline, BaselineResNet, BaselineLESCIResNet,
# LESCIMetrics, save_activations, FLAGS) are assumed to come from the surrounding project.

def gather_activations(sess: tf.Session, pipeline: TinyImageNetPipeline, model: BaselineResNet,
                       mode: tf.estimator.ModeKeys, only_correct_ones: bool = True) -> None:
    """
    Feeds samples of the given mode through the given model and accumulates their activation values (by default only
    for correctly classified samples; see `only_correct_ones`). Writes the accumulated activations to .mat files.
    """
    values_per_file = 100000
    n = pipeline.get_num_samples(mode)

    pipeline.switch_to(mode)
    export_tensors = model.activations.copy()
    export_tensors['target_labels'] = model.labels

    def get_blank_export_vals() -> Dict:
        # one fresh, empty accumulator list per exported tensor
        return {key: [] for key in export_tensors}

    export_vals = get_blank_export_vals()

    skipped_ctr, file_ctr = 0, 0
    for i in range(n):
        # assumes a batch size of 1: the batch accuracy is treated as per-sample correctness
        sample_export_val, accuracy = sess.run([export_tensors, model.accuracy])
        if only_correct_ones and accuracy < 1:
            skipped_ctr += 1
            tf.logging.info("Skipping misclassified sample #{}".format(skipped_ctr))
        else:
            for key in sample_export_val.keys():
                export_vals[key].append(sample_export_val[key][0])  # unpack the batch and push into storage
        tf.logging.info("Progress: {}/{}".format(i + 1, n))
        if (i > 0 and i % values_per_file == 0) or (i+1) == n:
            save_activations(file_ctr, FLAGS.activations_export_file, export_vals)
            export_vals = get_blank_export_vals()
            file_ctr += 1
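

# `save_activations` is called above but not defined in this snippet. A minimal sketch of
# what it could look like, assuming the accumulated lists are written out with scipy's
# `savemat`; the file naming and key layout here are assumptions, not the project's code.
import scipy.io


def save_activations(file_ctr: int, export_file: str, export_vals: Dict) -> None:
    """Dump the accumulated activation lists into a numbered .mat file (sketch)."""
    mat_dict = {key: np.asarray(values) for key, values in export_vals.items()}
    scipy.io.savemat("{}_{}.mat".format(export_file, file_ctr), mat_dict)
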
def run_validation(model: BaselineLESCIResNet,
                   pipeline: TinyImageNetPipeline,
                   mode: tf.estimator.ModeKeys,
                   verbose: bool = False,
                   batch_size: int = 100) -> LESCIMetrics:
    """
    Feeds all validation/train samples through the model and reports classification accuracy and loss.
    :return: a LESCIMetrics tuple of the following (float-) values:
            - accuracy: mean overall accuracy
            - loss: mean overall loss
            - accuracy_projection: accuracy mean of the projected samples
            - accuracy_identity: accuracy mean of the identity-mapped samples
            - percentage_identity_mapped: percentage of inputs that have been identity-mapped
    """
    tf.logging.info("Running evaluation on with mode {}.".format(mode))
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # dynamic GPU memory allocation
    sess = tf.Session(config=config)

    with sess.as_default():
        try:
            sess.run(init, feed_dict=model.init_feed_dict)
        except InvalidArgumentError:
            tf.logging.info(
                "Could not execute the init op; falling back to restoring all variables."
            )
        model.restore(sess)

        pipeline.switch_to(mode)

        if verbose:
            tf.logging.info("Starting evaluation")
        vals = []
        acc_mean_val, loss_mean_val, acc_proj_mean_val, acc_id_mean_val, id_mapped_mean_val = 0., 0., 0., 0., 0.
        num_samples = pipeline.get_num_samples(mode)
        n = num_samples // batch_size
        missed_samples = num_samples % batch_size
        if missed_samples > 0 and verbose:
            tf.logging.warning(
                "Omitting {} samples because the batch size ({}) is not a divisor of the number of "
                "samples ({}).".format(missed_samples, batch_size,
                                       num_samples))

        fetches = [
            model.accuracy, model.loss, model.accuracy_projection,
            model.accuracy_identity, model.percentage_identity_mapped
        ]
        for i in range(n):
            vals.append(sess.run(fetches))
            # running mean of every metric over all batches evaluated so far
            acc_mean_val, loss_mean_val, acc_proj_mean_val, acc_id_mean_val, id_mapped_mean_val = np.mean(
                vals, axis=0)
            if verbose:
                tf.logging.info(
                    "[{:,}/{:,}]\tCurrent overall accuracy: {:.3f}\tprojection: {:.3f}\tid-mapping: {:.3f}"
                    "\tpercentage id-mapped: {:.3f}".format(
                        i + 1, n, acc_mean_val, acc_proj_mean_val, acc_id_mean_val,
                        id_mapped_mean_val))

        if verbose:
            tf.logging.info(
                "[Done] Mean: accuracy {:.3f}, projection accuracy {:.3f}, identity mapping accuracy "
                "{:.3f}, loss {:.3f}, id-mapped {:.3f}".format(
                    acc_mean_val, acc_proj_mean_val, acc_id_mean_val,
                    loss_mean_val, id_mapped_mean_val))
    return LESCIMetrics(acc_mean_val, loss_mean_val, acc_proj_mean_val,
                        acc_id_mean_val, id_mapped_mean_val)
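

# `LESCIMetrics` is not defined in this snippet. Based on the docstring of `run_validation`,
# it appears to be a simple record of the five returned metrics; a minimal sketch follows
# (field names taken from the docstring; the project's actual definition may differ and would
# need to be available before `run_validation` is defined, e.g. via an import).
from collections import namedtuple

LESCIMetrics = namedtuple("LESCIMetrics", [
    "accuracy",
    "loss",
    "accuracy_projection",
    "accuracy_identity",
    "percentage_identity_mapped",
])

# Example usage (assuming a constructed model and pipeline):
#   metrics = run_validation(model, pipeline, tf.estimator.ModeKeys.EVAL, verbose=True)
#   print(metrics.accuracy, metrics.percentage_identity_mapped)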