def main(argv):
    """Print label range and class-frequency statistics for each dataset split.

    Reads ``FLAGS.samples // FLAGS.batch`` batches from each of the labeled,
    unlabeled and (repeated) test splits of ``FLAGS.dataset``. For each split
    it prints the smallest and largest label seen. It also prints, for every
    class, that class's count as a percentage of the most frequent class's
    count.

    Args:
        argv: Positional command-line arguments; unused.

    Raises:
        ValueError: If ``FLAGS.samples // FLAGS.batch`` is less than 1. In that
            case no batch would be read and the statistics would be undefined.
    """
    utils.setup_main()
    del argv
    utils.setup_tf()
    nbatch = FLAGS.samples // FLAGS.batch
    if nbatch < 1:
        # With zero batches, min()/max() below would fail on an empty list and
        # stats / stats.max() would divide 0 by 0; fail early with a clear error.
        raise ValueError('Need at least one batch: samples=%d, batch=%d.'
                         % (FLAGS.samples, FLAGS.batch))
    dataset = data.DATASETS()[FLAGS.dataset]()
    # Only the test split is repeated explicitly here so that it can supply
    # nbatch batches; presumably the train splits already repeat (TODO confirm).
    groups = [('labeled', dataset.train_labeled),
              ('unlabeled', dataset.train_unlabeled),
              ('test', dataset.test.repeat())]
    groups = [(name, ds.batch(FLAGS.batch).prefetch(16)
               .make_one_shot_iterator().get_next())
              for name, ds in groups]
    with tf.train.MonitoredSession() as sess:
        for group, batch_tensors in groups:
            stats = np.zeros(dataset.nclass, np.int32)
            batch_min, batch_max = [], []
            # Progress is reported in images: each step reads one batch.
            for _ in trange(nbatch, leave=False, unit='img',
                            unit_scale=FLAGS.batch, desc=group):
                v = sess.run(batch_tensors)['label']
                # Vectorized histogram update. Unlike `stats[v] += 1`, np.add.at
                # counts repeated labels in the same batch once per occurrence,
                # and an out-of-range label still raises IndexError.
                np.add.at(stats, v, 1)
                batch_min.append(v.min())
                batch_max.append(v.max())
            print(group)
            print(' Label range', min(batch_min), max(batch_max))
            # Each class count as a percentage of the most frequent class.
            print(
                ' Stats', ' '.join(['%.2f' % (100 * x) for x in (stats / stats.max())]))
def main(argv):
    """Print a pairwise overlap table of hashes across dataset splits.

    Gathers ``collect_hashes`` results for the labeled, unlabeled, validation
    and test evaluation splits of ``FLAGS.dataset``. It then prints a 4x4
    table in which the cell at (row, column) is ``len(row & column)`` of the
    two collected values (presumably sets of example hashes).

    Args:
        argv: Positional command-line arguments; unused.
    """
    del argv
    utils.setup_tf()
    dataset = DATASETS[FLAGS.dataset]()
    # (display name, attribute of `dataset` holding that split), in table order.
    splits = (('labeled', 'eval_labeled'),
              ('unlabeled', 'eval_unlabeled'),
              ('validation', 'valid'),
              ('test', 'test'))
    with tf.Session(config=utils.get_config()) as sess:
        # tuple() consumes the generator here, while the session is still open.
        hash_sets = tuple(collect_hashes(sess, name, getattr(dataset, attr))
                          for name, attr in splits)
    print('Overlap matrix (should be an almost perfect diagonal matrix with counts).')
    names = [name for name, _ in splits]
    row_fmt = '%-10s %10s %10s %10s %10s'
    print(row_fmt % tuple([''] + names))
    for name, row_set in zip(names, hash_sets):
        counts = [len(row_set & col_set) for col_set in hash_sets]
        print(row_fmt % tuple([name] + counts))
wd=FLAGS.wd, arch=FLAGS.arch, warmup_pos=FLAGS.warmup_pos, batch=FLAGS.batch, nclass=dataset.nclass, ema=FLAGS.ema, beta=FLAGS.beta, consistency_weight=FLAGS.consistency_weight, scales=FLAGS.scales or (log_width - 2), filters=FLAGS.filters, repeat=FLAGS.repeat) model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10) if __name__ == '__main__': utils.setup_tf() flags.DEFINE_float('consistency_weight', 50., 'Consistency weight.') flags.DEFINE_float( 'warmup_pos', 0.4, 'Relative position at which constraint loss warmup ends.') flags.DEFINE_float('wd', 0.02, 'Weight decay.') flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.') flags.DEFINE_float('beta', 0.5, 'Mixup beta.') flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.') flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.') flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.') FLAGS.set_default('augment', 'd.d.d') FLAGS.set_default('dataset', 'cifar10.3@250-5000') FLAGS.set_default('batch', 64) FLAGS.set_default('lr', 0.002)