Example #1
import numpy as np
import tensorflow as tf
from absl import flags
from tqdm import trange

from libml import data, utils  # assumed project-local modules, not shown in this listing

FLAGS = flags.FLAGS


def main(argv):
    utils.setup_main()
    del argv  # Unused.
    utils.setup_tf()
    nbatch = FLAGS.samples // FLAGS.batch
    dataset = data.DATASETS()[FLAGS.dataset]()
    groups = [('labeled', dataset.train_labeled),
              ('unlabeled', dataset.train_unlabeled),
              ('test', dataset.test.repeat())]
    groups = [(name, ds.batch(
        FLAGS.batch).prefetch(16).make_one_shot_iterator().get_next())
              for name, ds in groups]
    with tf.train.MonitoredSession() as sess:
        for group, train_data in groups:
            stats = np.zeros(dataset.nclass, np.int32)  # per-class label counts
            minmax = [], []  # per-batch label minima and maxima
            for _ in trange(nbatch,
                            leave=False,
                            unit='img',
                            unit_scale=FLAGS.batch,
                            desc=group):
                v = sess.run(train_data)['label']  # label vector for one batch
                for u in v:
                    stats[u] += 1
                minmax[0].append(v.min())
                minmax[1].append(v.max())
            print(group)
            print('  Label range', min(minmax[0]), max(minmax[1]))
            # Class frequencies relative to the most frequent class, in percent.
            print(
                '  Stats',
                ' '.join(['%.2f' % (100 * x) for x in (stats / stats.max())]))
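This example assumes a dataset registry shaped roughly as below: DATASETS() returns a mapping from spec strings to zero-argument factories, and each dataset object exposes tf.data pipelines whose elements are dicts with a 'label' key. A minimal sketch with purely illustrative names (the real libml.data module is not part of this listing):

import collections

import tensorflow as tf

# Hypothetical stand-in for the registry; field names are inferred from the
# attributes the example reads (train_labeled, train_unlabeled, test, nclass).
DataSet = collections.namedtuple(
    'DataSet', 'name train_labeled train_unlabeled test nclass')


def _toy_cifar10():
    labels = tf.data.Dataset.from_tensor_slices({'label': list(range(10))})
    return DataSet('cifar10', labels.repeat(), labels.repeat(), labels, 10)


def DATASETS():
    # Spec string -> factory; real specs look like 'cifar10.3@250-5000'.
    return {'cifar10': _toy_cifar10}

With this shape, each sess.run(train_data) above yields a dict whose 'label' entry is a NumPy vector of FLAGS.batch class indices.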
Example #2

import tensorflow as tf
from absl import flags

from libml import utils  # assumed project-local helpers
from libml.data import DATASETS  # assumed import path; a plain dict here, unlike Example #1's callable

FLAGS = flags.FLAGS


def main(argv):
    del argv  # Unused.
    utils.setup_tf()
    dataset = DATASETS[FLAGS.dataset]()
    with tf.Session(config=utils.get_config()) as sess:
        hashes = (collect_hashes(sess, 'labeled', dataset.eval_labeled),
                  collect_hashes(sess, 'unlabeled', dataset.eval_unlabeled),
                  collect_hashes(sess, 'validation', dataset.valid),
                  collect_hashes(sess, 'test', dataset.test))
    print('Overlap matrix (should be an almost perfect diagonal matrix with counts).')
    groups = 'labeled unlabeled validation test'.split()
    fmt = '%-10s %10s %10s %10s %10s'
    print(fmt % tuple([''] + groups))
    # Pairwise intersection sizes; large off-diagonal counts indicate leakage
    # between splits.
    for p, x in enumerate(hashes):
        overlaps = [len(x & y) for y in hashes]
        print(fmt % tuple([groups[p]] + overlaps))
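collect_hashes is referenced above but not included in this excerpt. A minimal sketch, assuming each pipeline element is a dict with an 'image' array and that hashing the raw example bytes is enough to detect overlap; the function name comes from the example, everything else is an assumption:

import hashlib

import tensorflow as tf
from tqdm import tqdm


def collect_hashes(sess, group, ds):
    # Hash every example in the split; identical examples in two splits then
    # show up as a non-empty set intersection in the overlap matrix.
    it = ds.batch(1).prefetch(16).make_one_shot_iterator().get_next()
    hashes = set()
    with tqdm(leave=False, desc='Hashing %s' % group, unit='img') as t:
        try:
            while True:
                v = sess.run(it)['image']
                hashes.add(hashlib.sha256(v.tobytes()).digest())
                t.update()
        except tf.errors.OutOfRangeError:
            pass
    return hashes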
Example #3

import os

from absl import flags

from libml import data, utils  # assumed project-local helpers

FLAGS = flags.FLAGS


def main(argv):
    # The start of this example was truncated in the listing. Everything down
    # to the keyword arguments preserved below is reconstructed and assumed,
    # including the MeanTeacher class name suggested by the consistency, EMA,
    # and mixup flags defined under __main__.
    del argv  # Unused.
    dataset = data.DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)  # assumed helper: log2 of image width
    model = MeanTeacher(
        os.path.join(FLAGS.train_dir, dataset.name),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        warmup_pos=FLAGS.warmup_pos,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        consistency_weight=FLAGS.consistency_weight,
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    # kimg = 1024 images; << 10 converts the kimg flags to image counts.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)


if __name__ == '__main__':
    utils.setup_tf()
    flags.DEFINE_float('consistency_weight', 50., 'Consistency weight.')
    flags.DEFINE_float(
        'warmup_pos', 0.4,
        'Relative position at which constraint loss warmup ends.')
    flags.DEFINE_float('wd', 0.02, 'Weight decay.')
    flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
    flags.DEFINE_float('beta', 0.5, 'Mixup beta.')
    flags.DEFINE_integer('scales', 0,
                         'Number of 2x2 downscalings in the classifier.')
    flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
    flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
    FLAGS.set_default('augment', 'd.d.d')
    FLAGS.set_default('dataset', 'cifar10.3@250-5000')
    FLAGS.set_default('batch', 64)
    FLAGS.set_default('lr', 0.002)
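The listing is cut off here; an absl-flags script of this shape would normally finish by dispatching to main, roughly as below (assumed, not shown in the excerpt):

from absl import app

app.run(main)  # parse flags, then call main(argv)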