Example #1
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.MANY_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = CTAReMixMatch(os.path.join(FLAGS.train_dir, dataset.name,
                                       CTAReMixMatch.cta_name()),
                          dataset,
                          lr=FLAGS.lr,
                          wd=FLAGS.wd,
                          arch=FLAGS.arch,
                          batch=FLAGS.batch,
                          nclass=dataset.nclass,
                          K=FLAGS.K,
                          beta=FLAGS.beta,
                          w_kl=FLAGS.w_kl,
                          w_match=FLAGS.w_match,
                          w_rot=FLAGS.w_rot,
                          redux=FLAGS.redux,
                          use_dm=FLAGS.use_dm,
                          use_xe=FLAGS.use_xe,
                          warmup_kimg=FLAGS.warmup_kimg,
                          scales=FLAGS.scales or (log_width - 2),
                          filters=FLAGS.filters,
                          repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
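All of these entry points measure training length in kimg (thousands of images), so the left shift by 10 converts the flag into an absolute image count. A minimal sketch of the arithmetic (the flag value here is made up):

train_kimg = 1 << 14               # hypothetical flag value: 16384 kimg
train_images = train_kimg << 10    # << 10 multiplies by 1024
assert train_images == train_kimg * 1024  # 16,777,216 images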
Example #2
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = FixMatch(
        os.path.join(FLAGS.train_dir, dataset.name, FixMatch.cta_name()),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
        size_unlabeled=dataset.size_unlabeled,
        alpha=FLAGS.alpha,
        inf_warm=FLAGS.inf_warm,
        inner_steps=FLAGS.inner_steps,
    )
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
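The recurring `scales=FLAGS.scales or (log_width - 2)` default derives the number of network down-scaling stages from the image width: `utils.ilog2` is an integer log base 2, so a 32-pixel-wide dataset gives `log_width = 5` and 3 scales. A sketch of the helper, assuming it matches the `libml/utils.py` found in these codebases:

import numpy as np

def ilog2(x):
    """Integer log base 2, e.g. ilog2(32) == 5."""
    return int(np.ceil(np.log2(x)))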
Example #3
def main(argv):
    utils.setup_main()
    del argv
    utils.setup_tf()
    nbatch = FLAGS.samples // FLAGS.batch
    dataset = data.DATASETS()[FLAGS.dataset]()
    groups = [('labeled', dataset.train_labeled),
              ('unlabeled', dataset.train_unlabeled),
              ('test', dataset.test.repeat())]
    groups = [(name, ds.batch(
        FLAGS.batch).prefetch(16).make_one_shot_iterator().get_next())
              for name, ds in groups]
    with tf.train.MonitoredSession() as sess:
        for group, train_data in groups:
            stats = np.zeros(dataset.nclass, np.int32)
            minmax = [], []
            for _ in trange(nbatch,
                            leave=False,
                            unit='img',
                            unit_scale=FLAGS.batch,
                            desc=group):
                v = sess.run(train_data)['label']
                for u in v:
                    stats[u] += 1
                minmax[0].append(v.min())
                minmax[1].append(v.max())
            print(group)
            print('  Label range', min(minmax[0]), max(minmax[1]))
            print(
                '  Stats',
                ' '.join(['%.2f' % (100 * x) for x in (stats / stats.max())]))
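The final line of this script prints each class count as a percentage of the largest class, so a perfectly balanced split prints 100.00 for every class. The same arithmetic in isolation, with made-up labels:

import numpy as np

labels = np.array([0, 0, 1, 2, 2, 2])     # hypothetical label stream
stats = np.bincount(labels, minlength=3)  # per-class counts: [2, 1, 3]
print(' '.join('%.2f' % (100 * x) for x in stats / stats.max()))
# -> 66.67 33.33 100.00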
Example #4
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = Mixup(os.path.join(FLAGS.train_dir, dataset.name),
                  dataset,
                  lr=FLAGS.lr,
                  wd=FLAGS.wd,
                  arch=FLAGS.arch,
                  batch=FLAGS.batch,
                  nclass=dataset.nclass,
                  ema=FLAGS.ema,
                  beta=FLAGS.beta,
                  scales=FLAGS.scales or (log_width - 2),
                  filters=FLAGS.filters,
                  repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #5
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = FixMatch_RA(os.path.join(FLAGS.train_dir, dataset.name),
                        dataset,
                        lr=FLAGS.lr,
                        wd=FLAGS.wd,
                        arch=FLAGS.arch,
                        batch=FLAGS.batch,
                        nclass=dataset.nclass,
                        wu=FLAGS.wu,
                        confidence=FLAGS.confidence,
                        uratio=FLAGS.uratio,
                        scales=FLAGS.scales or (log_width - 2),
                        filters=FLAGS.filters,
                        repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #6
def main(argv):
    utils.setup_main()
    del argv
    utils.setup_tf()
    dataset = data.DATASETS()[FLAGS.dataset]()
    with tf.Session(config=utils.get_config()) as sess:
        hashes = (collect_hashes(sess, 'labeled', dataset.eval_labeled),
                  collect_hashes(sess, 'unlabeled', dataset.eval_unlabeled),
                  collect_hashes(sess, 'validation', dataset.valid),
                  collect_hashes(sess, 'test', dataset.test))
    print(
        'Overlap matrix (should be an almost perfect diagonal matrix with counts).'
    )
    groups = 'labeled unlabeled validation test'.split()
    fmt = '%-10s %10s %10s %10s %10s'
    print(fmt % tuple([''] + groups))
    for p, x in enumerate(hashes):
        overlaps = [len(x & y) for y in hashes]
        print(fmt % tuple([groups[p]] + overlaps))
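`collect_hashes` is defined elsewhere in the script; the sketch below is a hypothetical stand-in that captures the idea: hash every image in a split so that set intersections count how many images two splits share.

import hashlib

import tensorflow as tf

def collect_hashes(sess, name, ds):
    # 'name' labels progress output in the real script (assumption).
    batch = ds.batch(256).prefetch(16).make_one_shot_iterator().get_next()
    hashes = set()
    try:
        while True:
            for img in sess.run(batch)['image']:
                hashes.add(hashlib.md5(img.tobytes()).hexdigest())
    except tf.errors.OutOfRangeError:  # finite split exhausted
        pass
    return hashes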
Example #7
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = ICT(os.path.join(FLAGS.train_dir, dataset.name),
                dataset,
                lr=FLAGS.lr,
                wd=FLAGS.wd,
                arch=FLAGS.arch,
                warmup_pos=FLAGS.warmup_pos,
                batch=FLAGS.batch,
                nclass=dataset.nclass,
                ema=FLAGS.ema,
                beta=FLAGS.beta,
                consistency_weight=FLAGS.consistency_weight,
                scales=FLAGS.scales or (log_width - 2),
                filters=FLAGS.filters,
                repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #8
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = AB_FixMatch_NoCutOut(
        os.path.join(FLAGS.train_dir, dataset.name, FixMatch.cta_name()),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)  # 512 epochs (which is 524K parameter updates)
Example #9
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = UDA(os.path.join(FLAGS.train_dir, dataset.name),
                dataset,
                lr=FLAGS.lr,
                wd=FLAGS.wd,
                wu=FLAGS.wu,
                we=FLAGS.we,
                arch=FLAGS.arch,
                batch=FLAGS.batch,
                nclass=dataset.nclass,
                temperature=FLAGS.temperature,
                tsa=FLAGS.tsa,
                tsa_pos=FLAGS.tsa_pos,
                confidence=FLAGS.confidence,
                uratio=FLAGS.uratio,
                scales=FLAGS.scales or (log_width - 2),
                filters=FLAGS.filters,
                repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)  # 1024 epochs
Example #10
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = TranslationConsistencyRegularization(
        os.path.join(FLAGS.train_dir, dataset.name),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        warmup_pos=FLAGS.warmup_pos,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        smoothing=FLAGS.smoothing,
        consistency_weight=FLAGS.consistency_weight,
        tcr_augment=FLAGS.tcr_augment,
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #11
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = VAT(
        os.path.join(FLAGS.train_dir, dataset.name),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        warmup_pos=FLAGS.warmup_pos,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        smoothing=FLAGS.smoothing,
        vat=FLAGS.vat,
        vat_eps=FLAGS.vat_eps,
        entmin_weight=FLAGS.entmin_weight,
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #12
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    seedIndx = FLAGS.dataset.find('@')
    seed = int(FLAGS.dataset[seedIndx - 1])

    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = Frost(os.path.join(FLAGS.train_dir, dataset.name,
                               Frost.cta_name()),
                  dataset,
                  lr=FLAGS.lr,
                  wd=FLAGS.wd,
                  arch=FLAGS.arch,
                  batch=FLAGS.batch,
                  nclass=dataset.nclass,
                  wu=FLAGS.wu,
                  wclr=FLAGS.wclr,
                  mom=FLAGS.mom,
                  confidence=FLAGS.confidence,
                  balance=FLAGS.balance,
                  delT=FLAGS.delT,
                  uratio=FLAGS.uratio,
                  clrratio=FLAGS.clrratio,
                  temperature=FLAGS.temperature,
                  scales=FLAGS.scales or (log_width - 2),
                  filters=FLAGS.filters,
                  repeat=FLAGS.repeat)
    ###################### New code
    tic = time.perf_counter()
    if FLAGS.boot_factor > 1:
        numIter = 2
        numToLabel = [
            FLAGS.boot_factor, FLAGS.boot_factor * FLAGS.boot_factor, 0
        ]
        numImgs = [(FLAGS.train_kimg << 9), 3 * (FLAGS.train_kimg << 8),
                   (FLAGS.train_kimg << 10)]
        if FLAGS.boot_schedule == 1:
            steps = int((FLAGS.train_kimg << 10) / 3)
            numImgs = [steps, 2 * steps, 3 * steps]
        elif FLAGS.boot_schedule == 2:
            numIter = 3
            steps = FLAGS.train_kimg << 8
            numImgs = [steps, 2 * steps, 3 * steps, 4 * steps]
            numToLabel = [
                FLAGS.boot_factor, FLAGS.boot_factor**2, FLAGS.boot_factor**3,
                0
            ]

        datasetName = dataset.name[:dataset.name.find('.')]
        print("Dataset Name ", datasetName)
        letters = string.ascii_letters
        subfolder = ''.join(random.choice(letters) for i in range(8))
        FLAGS.data_subfolder = subfolder
        tf.gfile.MakeDirs(data.DATA_DIR + '/' + subfolder)
        if not tf.gfile.Exists(data.DATA_DIR + '/' + subfolder + '/' +
                               datasetName + '-unlabel.json'):
            infile = data.DATA_DIR + '/SSL2/' + datasetName + '-unlabel.'
            outfile = data.DATA_DIR + '/' + subfolder + '/' + datasetName + '-unlabel.'
            print("Copied from ", infile, "* to ", outfile + '*')
            tf.io.gfile.copy(infile + 'json', outfile + 'json')
            tf.io.gfile.copy(infile + 'tfrecord', outfile + 'tfrecord')

        for it in range(numIter):
            print(" Iiteration ", it, " until ", numImgs[it])
            model.train(numImgs[it], FLAGS.report_kimg << 10, numToLabel[it],
                        it)
            elapse = (time.perf_counter() - tic) / 3600
            print("After iteration ", it, " training time ", elapse, " hours")

            bootstrap = CreateSplit(
                os.path.join(FLAGS.train_dir, dataset.name, Frost.cta_name()))
            bootstrap.create_split(datasetName=datasetName,
                                   seed=seed,
                                   size=numToLabel[it] * dataset.nclass)

            target = datasetName + '.' + str(seed) + '@' + str(
                numToLabel[it] * dataset.nclass) + '-1'
            print("Target ", target)
            dataset = data.PAIR_DATASETS()[target]()
            log_width = utils.ilog2(dataset.width)
            model.updateDataset(dataset)

        print(" Iiteration 2 until ", numImgs[numIter])
        model.train(numImgs[numIter], FLAGS.report_kimg << 10, 0, numIter)
        tf.compat.v1.gfile.DeleteRecursively(data.DATA_DIR + '/' + subfolder)
    else:
        model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10, 0, 0)

    elapse = (time.perf_counter() - tic) / 3600
    print(f"Total training time {elapse:0.4f} hours")
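The bootstrap schedules above are easier to read as fractions of the full run (`FLAGS.train_kimg << 10` images). A quick check of the shift arithmetic, with a hypothetical `train_kimg = 1024`:

train_kimg = 1024                 # hypothetical flag value
total = train_kimg << 10          # full run: 1,048,576 images
# Default schedule: stop points at 1/2 and 3/4 of the run, then finish.
default = [train_kimg << 9, 3 * (train_kimg << 8), total]
assert [x / total for x in default] == [0.5, 0.75, 1.0]
# boot_schedule == 1: three equal thirds.
steps = int(total / 3)
thirds = [steps, 2 * steps, 3 * steps]
# boot_schedule == 2: four equal quarters, with one extra bootstrap iteration.
steps = train_kimg << 8
quarters = [steps, 2 * steps, 3 * steps, 4 * steps]
assert [x / total for x in quarters] == [0.25, 0.5, 0.75, 1.0]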