def main(argv):
    """Build a MixMatch model from command-line flags and train it.

    Args:
        argv: Leftover positional command-line arguments; not used.
    """
    utils.setup_main()
    del argv  # Unused.

    dataset_factory = PAIR_DATASETS()[FLAGS.dataset]
    dataset = dataset_factory()

    # When --scales is unset (or 0), derive it from log2 of the input width.
    log_width = utils.ilog2(dataset.width)
    n_scales = FLAGS.scales or (log_width - 2)

    # Checkpoints/logs for each dataset live in their own subdirectory.
    run_dir = os.path.join(FLAGS.train_dir, dataset.name)

    model = MixMatch(
        run_dir,
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        w_match=FLAGS.w_match,
        scales=n_scales,
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )

    # The *_kimg flags are multiplied by 1024 (<< 10) before being handed
    # to train(); presumably they count samples in units of 1024.
    total_samples = FLAGS.train_kimg << 10
    report_every = FLAGS.report_kimg << 10
    model.train(total_samples, report_every)
def main(argv):
    """Build a MeanTeacher model from command-line flags and train it.

    Args:
        argv: Leftover positional command-line arguments; not used.
    """
    utils.setup_main()
    del argv  # Unused.

    dataset_factory = PAIR_DATASETS()[FLAGS.dataset]
    dataset = dataset_factory()

    # When --scales is unset (or 0), derive it from log2 of the input width.
    log_width = utils.ilog2(dataset.width)
    n_scales = FLAGS.scales or (log_width - 2)

    # Checkpoints/logs for each dataset live in their own subdirectory.
    run_dir = os.path.join(FLAGS.train_dir, dataset.name)

    model = MeanTeacher(
        run_dir,
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        warmup_pos=FLAGS.warmup_pos,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        smoothing=FLAGS.smoothing,
        consistency_weight=FLAGS.consistency_weight,
        scales=n_scales,
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )

    # The *_kimg flags are multiplied by 1024 (<< 10) before being handed
    # to train(); presumably they count samples in units of 1024.
    total_samples = FLAGS.train_kimg << 10
    report_every = FLAGS.report_kimg << 10
    model.train(total_samples, report_every)
def main(argv):
    """Build a UDA model from command-line flags and train it.

    Args:
        argv: Leftover positional command-line arguments; not used.
    """
    utils.setup_main()
    del argv  # Unused.

    dataset_factory = PAIR_DATASETS()[FLAGS.dataset]
    dataset = dataset_factory()

    # When --scales is unset (or 0), derive it from log2 of the input width.
    log_width = utils.ilog2(dataset.width)
    n_scales = FLAGS.scales or (log_width - 2)

    # Checkpoints/logs for each dataset live in their own subdirectory.
    run_dir = os.path.join(FLAGS.train_dir, dataset.name)

    model = UDA(
        run_dir,
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        wu=FLAGS.wu,
        we=FLAGS.we,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        temperature=FLAGS.temperature,
        tsa=FLAGS.tsa,
        tsa_pos=FLAGS.tsa_pos,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        scales=n_scales,
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )

    # The *_kimg flags are multiplied by 1024 (<< 10) before being handed
    # to train(); presumably they count samples in units of 1024.
    total_samples = FLAGS.train_kimg << 10
    report_every = FLAGS.report_kimg << 10
    model.train(total_samples, report_every)