Beispiel #1
0
def make_labeled_data_filter_fn(label_table):
    """Build a predicate that selects labeled examples by class.

    The class predicate is whatever
    ``tf_utils.filter_fn_from_comma_delimited`` returns for
    ``FLAGS.labeled_classes_filter``; it is applied to each example's label.

    Args:
        label_table: Optional object exposing ``lookup(fkey)``. When it is
            truthy, the result of ``label_table.lookup(fkey)`` is combined
            with the class check; otherwise only the class check is used.

    Returns:
        A callable taking ``(example, label, fkey)``. The first argument is
        ignored.
    """
    keep_class = tf_utils.filter_fn_from_comma_delimited(
        FLAGS.labeled_classes_filter)

    if not label_table:
        def class_only_filter(_, label, fkey):
            return keep_class(label)

        return class_only_filter

    def class_and_table_filter(_, label, fkey):
        # `&` does not short-circuit: the class check and the table lookup
        # are both always evaluated, in that order.
        return keep_class(label) & label_table.lookup(fkey)

    return class_and_table_filter
Beispiel #2
0
def make_unlabeled_data_filter_fn():
    """Build a predicate that selects a subset of unlabeled examples.

    An example passes when both of the following hold:

    * its label is accepted by the callable that
      ``tf_utils.filter_fn_from_comma_delimited`` returns for
      ``FLAGS.unlabeled_classes_filter``;
    * ``tf_utils.hash_float(fkey)`` is strictly less than
      ``FLAGS.unlabeled_data_random_fraction``.

    Returns:
        A callable taking ``(example, label, fkey)``. The first argument is
        ignored.
    """
    keep_class = tf_utils.filter_fn_from_comma_delimited(
        FLAGS.unlabeled_classes_filter)

    def unlabeled_filter(_, label, fkey):
        class_ok = keep_class(label)
        # The threshold is read from FLAGS each time the filter runs, not
        # when the filter is built.
        within_fraction = (
            tf_utils.hash_float(fkey) < FLAGS.unlabeled_data_random_fraction)
        # `&` (not `and`): both operands are always evaluated.
        return class_ok & within_fraction

    return unlabeled_filter
Beispiel #3
0
def make_labeled_data_filter():
    """Build a predicate that selects labeled examples by class.

    For the datasets ``cifar100``, ``tinyimagenet_32`` and
    ``cifar100_tinyimagenet`` (matched against
    ``FLAGS.primary_dataset_name``), the first two comma-separated fields of
    ``FLAGS.labeled_classes_filter`` are read as ``start,stop`` and expanded
    into every integer in ``range(start, stop)``. For any other dataset the
    flag value is used unchanged.

    The resulting comma-delimited string is printed to stdout and handed to
    ``tf_utils.filter_fn_from_comma_delimited``.

    Returns:
        A callable taking ``(image, label, index, fkey)`` that applies the
        class predicate to ``label`` and ignores the other arguments.
    """
    expand_range = FLAGS.primary_dataset_name in {
        "cifar100", "tinyimagenet_32", "cifar100_tinyimagenet"
    }
    class_spec = FLAGS.labeled_classes_filter
    if expand_range:
        bounds = class_spec.split(',')
        # Only the first two fields are used; `stop` is exclusive.
        start, stop = int(bounds[0]), int(bounds[1])
        class_spec = ",".join(map(str, range(start, stop)))
    print(class_spec)

    keep_class = tf_utils.filter_fn_from_comma_delimited(class_spec)

    def labeled_filter(image, label, index, fkey):
        return keep_class(label)

    return labeled_filter
Beispiel #4
0
def make_unlabeled_data_filter_fn():
    """Build a predicate that selects a subset of unlabeled examples.

    For the datasets ``cifar100``, ``tinyimagenet_32`` and
    ``cifar100_tinyimagenet`` (matched against
    ``FLAGS.secondary_dataset_name``), the first two comma-separated fields
    of ``FLAGS.unlabeled_classes_filter`` are read as ``start,stop`` and
    expanded into every integer in ``range(start, stop)``. For any other
    dataset the flag value is used unchanged. The resulting string is printed
    to stdout and handed to ``tf_utils.filter_fn_from_comma_delimited``.

    An example passes when its label is accepted by that class predicate and
    ``tf_utils.hash_float(fkey)`` is strictly less than
    ``FLAGS.unlabeled_data_random_fraction``.

    Returns:
        A callable taking ``(example, label, index, fkey)``. The first
        argument and ``index`` are ignored.
    """
    expand_range = FLAGS.secondary_dataset_name in {
        "cifar100", "tinyimagenet_32", "cifar100_tinyimagenet"
    }
    class_spec = FLAGS.unlabeled_classes_filter
    if expand_range:
        bounds = class_spec.split(',')
        # Only the first two fields are used; `stop` is exclusive.
        start, stop = int(bounds[0]), int(bounds[1])
        class_spec = ",".join(map(str, range(start, stop)))
    print(class_spec)

    keep_class = tf_utils.filter_fn_from_comma_delimited(class_spec)

    def unlabeled_filter(_, label, index, fkey):
        class_ok = keep_class(label)
        # The threshold is read from FLAGS each time the filter runs, not
        # when the filter is built.
        within_fraction = (
            tf_utils.hash_float(fkey) < FLAGS.unlabeled_data_random_fraction)
        # `&` (not `and`): both operands are always evaluated.
        return class_ok & within_fraction

    return unlabeled_filter
def make_labeled_data_filter():
    """Build a predicate that selects labeled examples by class.

    The class predicate is whatever
    ``tf_utils.filter_fn_from_comma_delimited`` returns for
    ``FLAGS.labeled_classes_filter``.

    Returns:
        A callable taking ``(image, label, fkey)`` that applies the class
        predicate to ``label`` and ignores the other arguments.
    """
    keep_class = tf_utils.filter_fn_from_comma_delimited(
        FLAGS.labeled_classes_filter)

    def labeled_filter(image, label, fkey):
        return keep_class(label)

    return labeled_filter