Example no. 1
0
def main(argv):
    """Entry point: build a FixMatch trainer from FLAGS and run training."""
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    # Checkpoints/logs go under train_dir/<dataset>/<cta augmentation name>.
    train_dir = os.path.join(FLAGS.train_dir, dataset.name,
                             FixMatch.cta_name())
    hyperparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        # Default the number of scales from the image width when unset.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
        size_unlabeled=dataset.size_unlabeled,
        alpha=FLAGS.alpha,
        inf_warm=FLAGS.inf_warm,
        inner_steps=FLAGS.inner_steps,
    )
    model = FixMatch(train_dir, dataset, **hyperparams)
    # kimg counters are in thousands of images; << 10 converts to images.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example no. 2
0
def main(argv):
    """Entry point: train FixMatch with RandAugment from command-line FLAGS."""
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    # Fall back to a width-derived scale count when FLAGS.scales is 0/unset.
    scales = FLAGS.scales or (log_width - 2)
    train_dir = os.path.join(FLAGS.train_dir, dataset.name)
    model = FixMatch_RA(train_dir,
                        dataset,
                        lr=FLAGS.lr,
                        wd=FLAGS.wd,
                        arch=FLAGS.arch,
                        batch=FLAGS.batch,
                        nclass=dataset.nclass,
                        wu=FLAGS.wu,
                        confidence=FLAGS.confidence,
                        uratio=FLAGS.uratio,
                        scales=scales,
                        filters=FLAGS.filters,
                        repeat=FLAGS.repeat)
    # kimg values are thousands of images; shift converts to image counts.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example no. 3
0
def main(argv):
    """Entry point: train a CTAMixMatch model configured from FLAGS."""
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    # Output directory includes the CTA augmentation name for this run.
    train_dir = os.path.join(FLAGS.train_dir, dataset.name,
                             CTAMixMatch.cta_name())
    options = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        w_match=FLAGS.w_match,
        # Derive scales from image width when not given explicitly.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = CTAMixMatch(train_dir, dataset, **options)
    # train/report budgets are in kimg; << 10 converts to raw image counts.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example no. 4
0
def main(argv):
    """Entry point: train the no-CutOut FixMatch ablation from FLAGS."""
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    # NOTE(review): the path uses FixMatch.cta_name(), not
    # AB_FixMatch_NoCutOut.cta_name() — presumably intentional so the ablation
    # shares the augmentation-derived directory naming; confirm upstream.
    train_dir = os.path.join(FLAGS.train_dir, dataset.name,
                             FixMatch.cta_name())
    model = AB_FixMatch_NoCutOut(
        train_dir,
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        # Width-derived default when FLAGS.scales is unset.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)  # 512 epochs (which is 524K parameter updates)
Example no. 5
0
def main(argv):
    """Entry point: train Translation Consistency Regularization from FLAGS."""
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    kwargs = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        warmup_pos=FLAGS.warmup_pos,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        smoothing=FLAGS.smoothing,
        consistency_weight=FLAGS.consistency_weight,
        tcr_augment=FLAGS.tcr_augment,
        # Architecture sizing: scales defaults from image width when unset.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = TranslationConsistencyRegularization(
        os.path.join(FLAGS.train_dir, dataset.name), dataset, **kwargs)
    # kimg budgets are thousands of images; << 10 converts to images.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example no. 6
0
def main(argv):
    """Train Frost, optionally bootstrapping the labeled set in stages.

    When FLAGS.boot_factor > 1, training is split into iterations: after each
    one, a new (larger) labeled split is generated from the model's output via
    CreateSplit, the dataset is swapped in, and training resumes.  Otherwise a
    single plain training run is performed.
    """
    utils.setup_main()
    del argv  # Unused.
    # FLAGS.dataset is formatted '<name>.<seed>@<size>-<valid>' (this format is
    # rebuilt below as datasetName + '.' + str(seed) + '@' + ...).  Parse the
    # seed as the full integer between the last '.' preceding '@' and the '@'.
    # BUG FIX: the previous code read only the single character at
    # seedIndx - 1, which silently mis-parsed multi-digit seeds.
    seedIndx = FLAGS.dataset.find('@')
    if seedIndx >= 0:
        seed = int(FLAGS.dataset[FLAGS.dataset.rfind('.', 0, seedIndx) +
                                 1:seedIndx])
    else:
        # Legacy fallback for names without '@' (preserves prior behavior).
        seed = int(FLAGS.dataset[seedIndx - 1])

    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = Frost(os.path.join(FLAGS.train_dir, dataset.name,
                               Frost.cta_name()),
                  dataset,
                  lr=FLAGS.lr,
                  wd=FLAGS.wd,
                  arch=FLAGS.arch,
                  batch=FLAGS.batch,
                  nclass=dataset.nclass,
                  wu=FLAGS.wu,
                  wclr=FLAGS.wclr,
                  mom=FLAGS.mom,
                  confidence=FLAGS.confidence,
                  balance=FLAGS.balance,
                  delT=FLAGS.delT,
                  uratio=FLAGS.uratio,
                  clrratio=FLAGS.clrratio,
                  temperature=FLAGS.temperature,
                  scales=FLAGS.scales or (log_width - 2),
                  filters=FLAGS.filters,
                  repeat=FLAGS.repeat)
    ###################### New code
    tic = time.perf_counter()
    if FLAGS.boot_factor > 1:
        # Bootstrap schedule: numToLabel[i] * nclass labels are generated
        # after iteration i; numImgs[i] is the cumulative image budget at
        # which iteration i stops.
        numIter = 2
        numToLabel = [
            FLAGS.boot_factor, FLAGS.boot_factor * FLAGS.boot_factor, 0
        ]
        # Default schedule: 1/2, 3/4, then the full training budget.
        numImgs = [(FLAGS.train_kimg << 9), 3 * (FLAGS.train_kimg << 8),
                   (FLAGS.train_kimg << 10)]
        if FLAGS.boot_schedule == 1:
            # Evenly spaced thirds of the total budget.
            steps = int((FLAGS.train_kimg << 10) / 3)
            numImgs = [steps, 2 * steps, 3 * steps]
        elif FLAGS.boot_schedule == 2:
            # Three bootstrap iterations in quarters of the budget.
            numIter = 3
            steps = FLAGS.train_kimg << 8
            numImgs = [steps, 2 * steps, 3 * steps, 4 * steps]
            numToLabel = [
                FLAGS.boot_factor, FLAGS.boot_factor**2, FLAGS.boot_factor**3,
                0
            ]

        # Work in a randomly named scratch subfolder so concurrent runs do not
        # clobber each other's generated splits.
        datasetName = dataset.name[:dataset.name.find('.')]
        print("Dataset Name ", datasetName)
        letters = string.ascii_letters
        subfolder = ''.join(random.choice(letters) for i in range(8))
        FLAGS.data_subfolder = subfolder
        tf.gfile.MakeDirs(data.DATA_DIR + '/' + subfolder)
        if not tf.gfile.Exists(data.DATA_DIR + '/' + subfolder + '/' +
                               datasetName + '-unlabel.json'):
            # Seed the scratch folder with the unlabeled split files.
            infile = data.DATA_DIR + '/SSL2/' + datasetName + '-unlabel.'
            outfile = data.DATA_DIR + '/' + subfolder + '/' + datasetName + '-unlabel.'
            print("Copied from ", infile, "* to ", outfile + '*')
            tf.io.gfile.copy(infile + 'json', outfile + 'json')
            tf.io.gfile.copy(infile + 'tfrecord', outfile + 'tfrecord')

        for it in range(numIter):
            print(" Iiteration ", it, " until ", numImgs[it])
            model.train(numImgs[it], FLAGS.report_kimg << 10, numToLabel[it],
                        it)
            elapse = (time.perf_counter() - tic) / 3600
            print("After iteration ", it, " training time ", elapse, " hours")

            # Generate the next, larger labeled split from model predictions.
            bootstrap = CreateSplit(
                os.path.join(FLAGS.train_dir, dataset.name, Frost.cta_name()))
            bootstrap.create_split(datasetName=datasetName,
                                   seed=seed,
                                   size=numToLabel[it] * dataset.nclass)

            target = datasetName + '.' + str(seed) + '@' + str(
                numToLabel[it] * dataset.nclass) + '-1'
            print("Target ", target)
            dataset = data.PAIR_DATASETS()[target]()
            log_width = utils.ilog2(dataset.width)
            model.updateDataset(dataset)

        # Final iteration on the last bootstrapped split, then clean up.
        print(" Iiteration 2 until ", numImgs[numIter])
        model.train(numImgs[numIter], FLAGS.report_kimg << 10, 0, numIter)
        tf.compat.v1.gfile.DeleteRecursively(data.DATA_DIR + '/' + subfolder)
    else:
        # No bootstrapping: single run over the full budget.
        model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10, 0, 0)

    elapse = (time.perf_counter() - tic) / 3600
    print(f"Total training time {elapse:0.4f} hours")