def rebalanced_distorted_inputs(pos_sample_dir):
    """Construct distorted inputs mixing regular and augmented positive samples.

    Part of each batch is drawn from the regular data under
    ``FLAGS.data_dir/fbb_output``. The rest is drawn from the augmented
    (symmetry) positive samples in ``pos_sample_dir``. The two parts are
    concatenated along axis 0.

    Args:
      pos_sample_dir: string. Directory of positive data.

    Returns:
      A tuple ``(coefs, images, labels)``. Each element is the axis-0
      concatenation of the regular batch followed by the augmented batch.
      The requested sizes of the two parts always sum to ``FLAGS.batch_size``.

    Raises:
      ValueError: If ``FLAGS.data_dir`` is not set.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'fbb_output')

    # Split so the two parts always add up to FLAGS.batch_size. Using
    # int(batch_size / 2) for both parts silently dropped one example
    # whenever batch_size was odd.
    total_batch_size = int(FLAGS.batch_size)
    regular_batch_size = total_batch_size // 2
    augmented_batch_size = total_batch_size - regular_batch_size

    coefs1, images1, labels1 = nn_input.distorted_inputs(
        data_dir=data_dir, batch_size=regular_batch_size)
    coefs2, images2, labels2 = nn_input.distorted_inputs(
        data_dir=pos_sample_dir, batch_size=augmented_batch_size,
        input_type='aug')

    # NOTE: tf.concat(axis, values) is the pre-1.0 TensorFlow argument order;
    # TF >= 1.0 expects tf.concat(values, axis). Kept as-is for this codebase.
    return (tf.concat(0, [coefs1, coefs2]),
            tf.concat(0, [images1, images2]),
            tf.concat(0, [labels1, labels2]))
def distorted_inputs(shuffle=True, num_threads=16, nodistort=False):
    """Construct distorted training inputs via ``nn_input.distorted_inputs``.

    Thin wrapper that forwards the module-level ``run_config`` and
    ``FLAGS.batch_size`` together with the caller's options.

    Args:
      shuffle: bool. Forwarded unchanged as ``shuffle``.
      num_threads: int. Forwarded unchanged as ``num_threads``.
      nodistort: bool. Forwarded unchanged as ``nodistort``. Presumably this
        disables the random distortions when True (TODO: confirm in nn_input).

    Returns:
      Whatever ``nn_input.distorted_inputs`` returns, unchanged. This is
      presumably a tuple of batched tensors; verify against nn_input.
    """
    # NOTE(review): this wrapper performs no FLAGS.data_dir validation itself.
    # The data location presumably comes from run_config now; confirm.
    return nn_input.distorted_inputs(run_config=run_config,
                                     batch_size=FLAGS.batch_size,
                                     shuffle=shuffle,
                                     num_threads=num_threads,
                                     nodistort=nodistort)