def make_labeled_data_filter_fn(label_table):
    """Build the predicate that selects labeled examples.

    The class predicate is obtained from
    `tf_utils.filter_fn_from_comma_delimited(FLAGS.labeled_classes_filter)`.
    When `label_table` is truthy, that result is additionally combined with
    `label_table.lookup(fkey)` using `&`.

    Args:
      label_table: Object exposing `lookup(fkey)`, or a falsy value (such as
        `None`) to skip the per-key lookup. Presumably a table marking which
        keys count as labeled -- confirm against the caller.

    Returns:
      A callable taking `(_, label, fkey)`; its first argument is ignored.
    """
    keep_class = tf_utils.filter_fn_from_comma_delimited(
        FLAGS.labeled_classes_filter)

    if not label_table:
        # No table supplied: the decision depends on the label alone.
        def class_only(_, label, fkey):
            return keep_class(label)

        return class_only

    def class_and_table(_, label, fkey):
        return keep_class(label) & label_table.lookup(fkey)

    return class_and_table
def make_unlabeled_data_filter_fn():
    """Build the predicate that selects unlabeled examples.

    An example passes when both conditions hold (combined with `&`):
      * the class predicate built from `FLAGS.unlabeled_classes_filter`
        accepts its label, and
      * `tf_utils.hash_float(fkey)` is strictly below
        `FLAGS.unlabeled_data_random_fraction`.

    Returns:
      A callable taking `(_, label, fkey)`; its first argument is ignored.
    """
    keep_class = tf_utils.filter_fn_from_comma_delimited(
        FLAGS.unlabeled_classes_filter)

    def keep_example(_, label, fkey):
        class_ok = keep_class(label)
        # The fraction flag is looked up every time the predicate runs, not
        # once when the predicate is created.
        sampled = (
            tf_utils.hash_float(fkey) < FLAGS.unlabeled_data_random_fraction)
        return class_ok & sampled

    return keep_example
def make_labeled_data_filter():
    """Build the class predicate for labeled data.

    When `FLAGS.primary_dataset_name` is one of "cifar100", "tinyimagenet_32"
    or "cifar100_tinyimagenet", `FLAGS.labeled_classes_filter` is read as
    `start,end` and expanded into the comma-joined integers of
    `range(start, end)` (end exclusive). Only the first two comma-separated
    fields are used. For any other dataset the flag value is used unchanged.

    The resulting class specification is printed to stdout before the
    predicate is built.

    Returns:
      A callable taking `(image, label, index, fkey)` that returns the class
      predicate applied to `label`; the other arguments are ignored.
    """
    range_style_datasets = {
        "cifar100", "tinyimagenet_32", "cifar100_tinyimagenet"
    }

    if FLAGS.primary_dataset_name in range_style_datasets:
        bounds = FLAGS.labeled_classes_filter.split(',')
        class_ids = range(int(bounds[0]), int(bounds[1]))
        class_spec = ",".join(map(str, class_ids))
    else:
        class_spec = FLAGS.labeled_classes_filter

    print(class_spec)
    keep_class = tf_utils.filter_fn_from_comma_delimited(class_spec)

    def keep_example(image, label, index, fkey):
        return keep_class(label)

    return keep_example
def make_unlabeled_data_filter_fn():
    """Build the predicate that selects unlabeled examples.

    When `FLAGS.secondary_dataset_name` is one of "cifar100",
    "tinyimagenet_32" or "cifar100_tinyimagenet",
    `FLAGS.unlabeled_classes_filter` is read as `start,end` and expanded into
    the comma-joined integers of `range(start, end)` (end exclusive). Only the
    first two comma-separated fields are used. For any other dataset the flag
    value is used unchanged. The resulting class specification is printed to
    stdout.

    An example passes when the class predicate accepts its label and
    `tf_utils.hash_float(fkey)` is strictly below
    `FLAGS.unlabeled_data_random_fraction` (combined with `&`).

    Returns:
      A callable taking `(_, label, index, fkey)`; the first and third
      arguments are ignored.
    """
    range_style_datasets = {
        "cifar100", "tinyimagenet_32", "cifar100_tinyimagenet"
    }

    if FLAGS.secondary_dataset_name in range_style_datasets:
        bounds = FLAGS.unlabeled_classes_filter.split(',')
        class_ids = range(int(bounds[0]), int(bounds[1]))
        class_spec = ",".join(map(str, class_ids))
    else:
        class_spec = FLAGS.unlabeled_classes_filter

    print(class_spec)
    keep_class = tf_utils.filter_fn_from_comma_delimited(class_spec)

    def keep_example(_, label, index, fkey):
        class_ok = keep_class(label)
        # The fraction flag is looked up every time the predicate runs, not
        # once when the predicate is created.
        sampled = (
            tf_utils.hash_float(fkey) < FLAGS.unlabeled_data_random_fraction)
        return class_ok & sampled

    return keep_example
def make_labeled_data_filter():
    """Build the class predicate for labeled data.

    The predicate comes from
    `tf_utils.filter_fn_from_comma_delimited(FLAGS.labeled_classes_filter)`.

    Returns:
      A callable taking `(image, label, fkey)` that returns the class
      predicate applied to `label`; `image` and `fkey` are ignored.
    """
    keep_class = tf_utils.filter_fn_from_comma_delimited(
        FLAGS.labeled_classes_filter)

    def keep_example(image, label, fkey):
        return keep_class(label)

    return keep_example