Example 1
0
def main(argv):
    """Print label statistics for each split of the selected dataset.

    Draws roughly ``FLAGS.samples`` images from the labeled, unlabeled and
    test splits and reports, per split, the observed label range and each
    class's frequency as a percentage of the most frequent class.
    """
    utils.setup_main()
    del argv  # Unused.
    utils.setup_tf()
    nbatch = FLAGS.samples // FLAGS.batch
    dataset = data.DATASETS()[FLAGS.dataset]()
    groups = [('labeled', dataset.train_labeled),
              ('unlabeled', dataset.train_unlabeled),
              # The test split is finite: repeat it so `nbatch` batches can
              # always be drawn from it.
              ('test', dataset.test.repeat())]
    groups = [(name, ds.batch(
        FLAGS.batch).prefetch(16).make_one_shot_iterator().get_next())
              for name, ds in groups]
    with tf.train.MonitoredSession() as sess:
        for group, train_data in groups:
            # Per-class occurrence counts for this split.
            stats = np.zeros(dataset.nclass, np.int32)
            # Per-batch minimum and maximum labels seen.
            minmax = [], []
            for _ in trange(nbatch,
                            leave=False,
                            unit='img',
                            unit_scale=FLAGS.batch,
                            desc=group):
                v = sess.run(train_data)['label']
                for u in v:
                    stats[u] += 1
                minmax[0].append(v.min())
                minmax[1].append(v.max())
            print(group)
            # Guard: when nbatch == 0 (FLAGS.samples < FLAGS.batch) the
            # minmax lists are empty and stats is all zeros, so min()/max()
            # and stats / stats.max() below would raise.
            if not minmax[0]:
                print('  No batches drawn (FLAGS.samples < FLAGS.batch).')
                continue
            print('  Label range', min(minmax[0]), max(minmax[1]))
            print(
                '  Stats',
                ' '.join(['%.2f' % (100 * x) for x in (stats / stats.max())]))
Example 2
0
def main(argv):
    """Build a Mixup model from command-line flags and train it."""
    utils.setup_main()
    del argv  # Unused.
    dataset = data.DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    # Default the number of scales from the image width unless overridden.
    scales = FLAGS.scales or (log_width - 2)
    train_dir = os.path.join(FLAGS.train_dir, dataset.name)
    model = Mixup(
        train_dir,
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        scales=scales,
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    # Flag values are in kibi-images; shift converts them to image counts.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example 3
0
def main(argv):
    """Report how many example hashes are shared between dataset splits."""
    utils.setup_main()
    del argv  # Unused.
    utils.setup_tf()
    dataset = data.DATASETS()[FLAGS.dataset]()
    names = 'labeled unlabeled validation test'.split()
    sources = (dataset.eval_labeled, dataset.eval_unlabeled, dataset.valid,
               dataset.test)
    with tf.Session(config=utils.get_config()) as sess:
        hashes = tuple(collect_hashes(sess, name, source)
                       for name, source in zip(names, sources))
    print(
        'Overlap matrix (should be an almost perfect diagonal matrix with counts).'
    )
    fmt = '%-10s %10s %10s %10s %10s'
    print(fmt % tuple([''] + names))
    # Each row counts hash intersections of one split against every split.
    for row_name, row_hashes in zip(names, hashes):
        counts = [len(row_hashes & other) for other in hashes]
        print(fmt % tuple([row_name] + counts))
def main(argv):
    """Instantiate a PseudoLabel model from flags and run training."""
    utils.setup_main()
    del argv  # Unused.
    dataset = data.DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    # Collect all hyper-parameters from flags; scales falls back to a
    # width-derived default when the flag is unset (zero).
    hyper = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        warmup_pos=FLAGS.warmup_pos,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        smoothing=FLAGS.smoothing,
        consistency_weight=FLAGS.consistency_weight,
        threshold=FLAGS.threshold,
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    model = PseudoLabel(os.path.join(FLAGS.train_dir, dataset.name), dataset,
                        **hyper)
    # Flag values are in kibi-images; shift converts them to image counts.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)