def main(argv):
    """Build a CTAReMixMatch model from the command-line FLAGS and train it.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    utils.setup_main()
    del argv  # Unused.
    dataset = data.MANY_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name,
                             CTAReMixMatch.cta_name())
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        K=FLAGS.K,
        beta=FLAGS.beta,
        w_kl=FLAGS.w_kl,
        w_match=FLAGS.w_match,
        w_rot=FLAGS.w_rot,
        redux=FLAGS.redux,
        use_dm=FLAGS.use_dm,
        use_xe=FLAGS.use_xe,
        warmup_kimg=FLAGS.warmup_kimg,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = CTAReMixMatch(model_dir, dataset, **hparams)
    # Both flag values are shifted left by 10 bits, i.e. multiplied by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Build an FSBaseline model (ERM or WDRO variant) from FLAGS and train it.

    The model directory name depends on FLAGS.gamma: 'ERM' when gamma is
    None, 'WDRO_<gamma>' when gamma is positive.

    Args:
        argv: Positional command-line arguments; discarded.

    Raises:
        ValueError: If FLAGS.gamma is set but not strictly positive.
    """
    del argv  # Unused.
    dataset = DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)

    # Generating model directory
    if FLAGS.gamma is None:
        model_dir = 'ERM'
    elif FLAGS.gamma > 0:
        model_dir = 'WDRO_' + str(FLAGS.gamma)
    else:
        # Raise instead of `assert False`: asserts are stripped under -O, which
        # would leave model_dir unbound and fail later with a NameError.
        raise ValueError('Check the regularization parameter gamma')

    model = FSBaseline(
        os.path.join(FLAGS.train_dir, model_dir, dataset.name),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        smoothing=FLAGS.smoothing,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
        gamma=FLAGS.gamma)
    # (total # of data, ckpt size)
    model.train(FLAGS.nckpt * FLAGS.ckptsize, FLAGS.ckptsize)
def disc(self, x, x_lr, resolution, filters):
    """Build the 'disc' sub-graph and return the sum of two scores.

    One score is formed at an intermediate depth by projecting the features
    to the colour-channel count, multiplying with `x_lr` and summing; the
    other comes from a dense layer applied after the remaining blocks.

    Args:
        x: Input tensor; its 4th dimension is used as the colour count.
        x_lr: Tensor multiplied element-wise with the intermediate projection.
        resolution: Value whose integer log2 bounds the second block loop.
        filters: Base filter count, left-shifted per stage.

    Returns:
        `lr_disc + hr_disc`, where lr_disc has been reshaped to [-1, 1].
    """
    conv_args = dict(padding='same',
                     kernel_initializer=tf.glorot_uniform_initializer())
    lr_dims = utils.smart_shape(x_lr)[1:]
    lr_h, lr_w, lr_c = [tf.cast(dim, tf.float32) for dim in lr_dims]
    colors = utils.smart_shape(x)[3]
    log_res = utils.ilog2(resolution)

    with tf.variable_scope('disc', reuse=tf.AUTO_REUSE):
        net = d_optimized_resnet_block(x, filters)
        net = d_resnet_block(net, filters)
        for stage in range(1, self.log_scale - 1):
            net = d_resnet_block(net, filters << stage, downsample=True)
            net = d_resnet_block(net, filters << stage)
        net = d_resnet_block(net, filters << (self.log_scale - 1),
                             downsample=True)
        net = d_resnet_block(net, filters << (self.log_scale - 1))

        # Score tied to x_lr: project to `colors` channels, correlate, sum,
        # and scale by 1/sqrt(number of x_lr elements per example).
        projected = layers.conv2d_spectral_norm(net, colors, 3, **conv_args)
        lr_disc = projected * x_lr
        lr_total = tf.reduce_sum(lr_disc, [1, 2, 3])
        lr_disc = lr_total * tf.rsqrt(lr_h * lr_w * lr_c)
        lr_disc = tf.reshape(lr_disc, [-1, 1])

        for stage in range(self.log_scale, log_res - 2):
            net = d_resnet_block(net, filters << stage, downsample=True)
            net = d_resnet_block(net, filters << stage)

        # Sum over axes 1 and 2, then scale by one quarter.
        net = tf.reduce_sum(net, [1, 2]) * 0.25
        hr_disc = layers.dense_spectral_norm(
            net, 1, kernel_initializer=tf.glorot_uniform_initializer())
        return lr_disc + hr_disc
def main(argv):
    """Build a MixupGrad model (plain or WDRO variant) from FLAGS and train it.

    The model directory name depends on FLAGS.gamma: 'MIXUP' when gamma is
    None, 'WDRO_MIX_<gamma>' when gamma is positive.

    Args:
        argv: Positional command-line arguments; discarded.

    Raises:
        ValueError: If FLAGS.gamma is set but not strictly positive.
    """
    del argv  # Unused.
    dataset = DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)

    # Generating model directory
    if FLAGS.gamma is None:
        model_dir = 'MIXUP'
    elif FLAGS.gamma > 0:
        model_dir = 'WDRO_MIX_' + str(FLAGS.gamma)
    else:
        # Raise instead of `assert False`: asserts are stripped under -O, which
        # would leave model_dir unbound and fail later with a NameError.
        raise ValueError('Check the penalty parameter gamma')

    model = MixupGrad(
        os.path.join(FLAGS.train_dir, model_dir, dataset.name),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        gamma=FLAGS.gamma,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    # (total # of data, epoch size)
    model.train(FLAGS.nckpt * FLAGS.ckptsize, FLAGS.ckptsize)
def main(argv):
    """Build a MixMatch_LinearGrow model from FLAGS and run its grow-training.

    Only dataset names whose prefix (text before the first '.') is one of
    cifar10, cifar100, svhn or svhn_extra pass the initial assertion.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    del argv  # Unused.
    dataset_family = FLAGS.dataset.split('.')[0]
    assert dataset_family in ['cifar10', 'cifar100', 'svhn', 'svhn_extra']
    dataset = DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)

    # Directory name: text before '@', then '_train', then whatever follows
    # the last 'train' in the dataset name, then '_Grow'.
    name_head = dataset.name.split('@')[0]
    name_tail = dataset.name.split('train')[-1]
    model_dir = os.path.join(FLAGS.train_dir,
                             name_head + '_train' + name_tail + '_Grow')

    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        logit_norm=FLAGS.logit_norm,
        T=FLAGS.T,
        mixmode=FLAGS.mixmode,
        nu=FLAGS.nu,
        dbuf=FLAGS.dbuf,
        w_match=FLAGS.w_match,
        warmup_kimg=FLAGS.warmup_kimg,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
        growby=FLAGS.grow_by,
        growsize=FLAGS.grow_size,
    )
    model = MixMatch_LinearGrow(model_dir, dataset, **hparams)
    # The report interval flag is multiplied by 1024 via the shift.
    model.train_lineargrow(FLAGS.report_kimg << 10)
def main(argv):
    """Build a FixMatch model from the command-line FLAGS and train it.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name,
                             FixMatch.cta_name())
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
        size_unlabeled=dataset.size_unlabeled,
        alpha=FLAGS.alpha,
        inf_warm=FLAGS.inf_warm,
        inner_steps=FLAGS.inner_steps,
    )
    model = FixMatch(model_dir, dataset, **hparams)
    # Both flag values are shifted left by 10 bits, i.e. multiplied by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Build a RealMix model from the command-line FLAGS and train it.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    del argv
    # Num of augmentations to perform on each image and measure consistency
    # loss. Performance does not significantly increase with more
    # augmentations.
    assert FLAGS.nu == 2
    dataset = get_dataset()
    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        w_match=FLAGS.w_match,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
        tsa=FLAGS.tsa,
        ood_mask=FLAGS.percent_mask,
        augmentation=FLAGS.augment,
    )
    model = RealMix(model_dir, dataset, **hparams)

    # NOTE: a disabled (commented-out) inference path used to sit here. It was
    # gated on FLAGS.perform_inference, loaded "*.png" files from
    # FLAGS.inference_dir, rescaled them with `* (2.0 / 255) - 1.0`, restored
    # FLAGS.inference_ckpt via model.eval_mode and printed the classify_op
    # output for the first 10 images.

    summary_values = (FLAGS.dataset, FLAGS.nclass, FLAGS.img_size,
                      FLAGS.augment, FLAGS.tsa, FLAGS.wd, FLAGS.lr)
    print("Preparing to train the %s dataset with %d classes, img_size of %d, %s augmentation, %s tsa schedule, %f weight decay, and learning rate of %f using RealMix"
          % summary_values)
    # Both flag values are shifted left by 10 bits, i.e. multiplied by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def eval_mode(self, dataset):
    """Create the evaluation session for this model and return it.

    May only be called while `self.eval` is still None. The model function is
    bound with total_steps=1 and all three lod_* arguments set to
    ilog2(self.scale).

    Args:
        dataset: Forwarded to `self.model` as the `dataset` keyword.

    Returns:
        The newly created EvalSessionPro, also stored on `self.eval`.
    """
    assert self.eval is None
    lod = utils.ilog2(self.scale)
    model_fn = functools.partial(
        self.model,
        dataset=dataset,
        total_steps=1,
        lod_start=lod,
        lod_stop=lod,
        lod_max=lod,
    )
    self.eval = EvalSessionPro(model_fn, self.checkpoint_dir, **self.params)
    current_step = self.eval.sess.run(self.eval.global_step)
    print('Eval model %s at global_step %d' % (self.__class__.__name__,
                                               current_step))
    return self.eval
def __init__(self, resolution_start, resolution_stop, transition_kimg,
             training_kimg, stop_kimg):
    """Build `self.schedule`, a list of TrainPhase entries.

    For every level in [lod_start, lod_stop) a constant-level phase is added
    when `training_kimg` is non-zero, followed by a level -> level + 1 phase
    when `transition_kimg` is non-zero. A last phase at `lod_stop` is added
    if its end lies beyond the images already scheduled.

    Args:
        resolution_start: Value whose integer log2 becomes `self.lod_start`.
        resolution_stop: Value whose integer log2 becomes `self.lod_stop`.
        transition_kimg: Length of each transition phase, multiplied by 1024.
        training_kimg: Length of each constant phase, multiplied by 1024.
        stop_kimg: End of the last phase, multiplied by 1024; 0 means "one
            more constant-phase length after the scheduled phases".
    """
    self.transition_nimg = transition_kimg << 10
    self.training_nimg = training_kimg << 10
    self.lod_start = utils.ilog2(resolution_start)
    self.lod_stop = utils.ilog2(resolution_stop)
    self.schedule = []

    scheduled = 0  # Number of images covered by the phases added so far.
    for level in range(self.lod_start, self.lod_stop):
        if training_kimg:
            phase_end = scheduled + self.training_nimg
            self.schedule.append(
                TrainPhase(scheduled, phase_end, level, level))
            scheduled = phase_end
        if transition_kimg:
            phase_end = scheduled + self.transition_nimg
            self.schedule.append(
                TrainPhase(scheduled, phase_end, level, level + 1))
            scheduled = phase_end

    if stop_kimg == 0:
        stop_nimg = scheduled + self.training_nimg
    else:
        stop_nimg = stop_kimg << 10
    if stop_nimg > scheduled:
        self.schedule.append(
            TrainPhase(scheduled, stop_nimg, self.lod_stop, self.lod_stop))

    # Pin the first phase's start to image 0.
    # NOTE(review): this indexes schedule[0] unconditionally, so an empty
    # schedule would raise IndexError here - confirm that cannot happen.
    self.schedule[0].nimg_start = 0
def main(argv):
    """Build a VAT model from FLAGS and train it.

    The dataset is taken from DATASETS when FLAGS.dataset is a known key.
    Otherwise, with --custom_dataset set, dataset creators are registered in
    DATASETS for every combination of seed in range(2), FLAGS.label_size and
    FLAGS.valid_size before the lookup.

    Args:
        argv: Positional command-line arguments; discarded.

    Raises:
        ValueError: If FLAGS.dataset is not in DATASETS and --custom_dataset
            is not set.
    """
    del argv  # Unused.
    if FLAGS.dataset in DATASETS:
        dataset = DATASETS[FLAGS.dataset]()
    elif FLAGS.custom_dataset:
        print("Preparing to train the " + FLAGS.dataset + " dataset.")
        label_size = [int(size) for size in FLAGS.label_size]
        valid_size = [int(size) for size in FLAGS.valid_size]
        if FLAGS.augment == "cifar10":
            augmentation = augment_cifar10
        else:
            augmentation = augment_custom

        DATASETS.update([
            DataSet.creator(
                FLAGS.dataset.split(".")[0],
                seed,
                label,
                valid, [augmentation, stack_augment(augmentation)],
                nclass=FLAGS.nclass,
                height=FLAGS.img_size,
                width=FLAGS.img_size)
            for seed, label, valid in itertools.product(
                range(2), label_size, valid_size)
        ])
        dataset = DATASETS[FLAGS.dataset]()
    else:
        # Previously this case fell through and failed later with a NameError
        # on the unbound `dataset`; fail early with an actionable message.
        raise ValueError(
            "Unknown dataset %r: please specify a dataset which is in "
            "DATASETS or use --custom_dataset." % (FLAGS.dataset,))

    log_width = utils.ilog2(dataset.width)
    model = VAT(
        os.path.join(FLAGS.train_dir, dataset.name),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        warmup_pos=FLAGS.warmup_pos,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        smoothing=FLAGS.smoothing,
        vat=FLAGS.vat,
        vat_eps=FLAGS.vat_eps,
        entmin_weight=FLAGS.entmin_weight,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    # Both flag values are shifted left by 10 bits, i.e. multiplied by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Build an FSBaseline model from the command-line FLAGS and train it.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    del argv  # Unused.
    dataset = DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        smoothing=FLAGS.smoothing,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = FSBaseline(model_dir, dataset, **hparams)
    # Both flag values are shifted left by 10 bits, i.e. multiplied by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Build a Mixup model from the command-line FLAGS and train it.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    utils.setup_main()
    del argv  # Unused.
    dataset = data.DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = Mixup(model_dir, dataset, **hparams)
    # Both flag values are shifted left by 10 bits, i.e. multiplied by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Build a FixMatch_RA model from the command-line FLAGS and train it.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = FixMatch_RA(model_dir, dataset, **hparams)
    # Both flag values are shifted left by 10 bits, i.e. multiplied by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Build an NST model from the command-line FLAGS and train it.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    del argv  # Unused.
    dataset = data.DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        warmup_pos=FLAGS.warmup_pos,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        smoothing=FLAGS.smoothing,
        consistency_weight=FLAGS.consistency_weight,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = NST(model_dir, dataset, **hparams)
    # Both flag values are shifted left by 10 bits, i.e. multiplied by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Build an AB_FixMatch_NoCutOut model from FLAGS and train it.

    The model directory uses `FixMatch.cta_name()` as its last component.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name,
                             FixMatch.cta_name())
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = AB_FixMatch_NoCutOut(model_dir, dataset, **hparams)
    # 512 epochs (which is 524K parameter updates)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Build a MixMatch model from the command-line FLAGS and train it.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    del argv  # Unused.
    # The `FLAGS.nu == 2` check is intentionally disabled here; nu is passed
    # through to the model as-is.
    dataset = DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        nu=FLAGS.nu,
        ema=FLAGS.ema,
        num_final_layers=FLAGS.num_final_layers,
        beta=FLAGS.beta,
        w_match=FLAGS.w_match,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = MixMatch(model_dir, dataset, **hparams)
    # Both flag values are shifted left by 10 bits, i.e. multiplied by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Build an FSBaseline model, optionally on a custom dataset, and train it.

    With --custom_dataset set, one dataset creator per entry of
    FLAGS.valid_size is registered in DATASETS (built from FLAGS.train_record
    and FLAGS.test_record) before FLAGS.dataset is looked up.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    del argv  # Unused.
    dataset_is_known = FLAGS.dataset in DATASETS.keys()
    assert dataset_is_known or FLAGS.custom_dataset, "Please specify a dataset which is in data.py or use --custom_dataset."

    if FLAGS.custom_dataset:
        print("Preparing to train the " + FLAGS.dataset + " dataset.")
        valid_size = [int(size) for size in FLAGS.valid_size]
        augmentation = data.augment_cifar10
        # Do not name your dataset using a "-", otherwise the following line
        # will not work for a custom dataset.
        base_name = FLAGS.dataset.split("-")[0]
        creators = []
        for valid in valid_size:
            creators.append(
                DataSetFS.creator(base_name, [FLAGS.train_record],
                                  [FLAGS.test_record],
                                  valid,
                                  augmentation,
                                  nclass=FLAGS.nclass,
                                  height=FLAGS.img_size,
                                  width=FLAGS.img_size))
        DATASETS.update(creators)
    dataset = DATASETS[FLAGS.dataset]()

    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        smoothing=FLAGS.smoothing,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = FSBaseline(model_dir, dataset, **hparams)
    # Both flag values are shifted left by 10 bits, i.e. multiplied by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Build an AblationMixMatch model from FLAGS and train it.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    del argv  # Unused.
    dataset = DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        use_ema_guess=FLAGS.use_ema_guess,
        T=FLAGS.T,
        mixmode=FLAGS.mixmode,
        nu=FLAGS.nu,
        w_match=FLAGS.w_match,
        warmup_kimg=FLAGS.warmup_kimg,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = AblationMixMatch(model_dir, dataset, **hparams)
    # Both flag values are shifted left by 10 bits, i.e. multiplied by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Build a TranslationConsistencyRegularization model and train it.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        warmup_pos=FLAGS.warmup_pos,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        smoothing=FLAGS.smoothing,
        consistency_weight=FLAGS.consistency_weight,
        tcr_augment=FLAGS.tcr_augment,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = TranslationConsistencyRegularization(model_dir, dataset, **hparams)
    # Both flag values are shifted left by 10 bits, i.e. multiplied by 1024.
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Build a UDA model from the command-line FLAGS and train it.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    utils.setup_main()
    del argv  # Unused.
    dataset = PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        wu=FLAGS.wu,
        we=FLAGS.we,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        temperature=FLAGS.temperature,
        tsa=FLAGS.tsa,
        tsa_pos=FLAGS.tsa_pos,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = UDA(model_dir, dataset, **hparams)
    # 1024 epochs
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
def main(argv):
    """Train MixMatch on a dataset whose class list is read from a JSON file.

    The number of classes is the length of the object loaded from
    './tfrecord/<FLAGS.dataset>-class.json'; images are declared as 64x64
    with 3 colours.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    del argv  # Unused.
    assert FLAGS.nu == 2

    dataset_json_path = './tfrecord/{0}-class.json'.format(FLAGS.dataset)
    with open(dataset_json_path, 'r') as class_file:
        label_dict = json.load(class_file)

    augmenters = [augment_cifar10, stack_augment(augment_cifar10)]
    make_dataset = DataSet_2.creator(FLAGS.dataset,
                                     0,
                                     0,
                                     1,
                                     augmenters,
                                     colors=3,
                                     nclass=len(label_dict),
                                     height=64,
                                     width=64)
    dataset = make_dataset()

    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        w_match=FLAGS.w_match,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = MixMatch(model_dir, dataset, **hparams)

    # Unlike the usual `<< 10` scaling (left disabled in earlier revisions of
    # this function), the flag values are passed to train() unscaled here.
    print(FLAGS.train_kimg, FLAGS.report_kimg)
    model.train(FLAGS.train_kimg, FLAGS.report_kimg)
def disc(self, x, resolution, filters):
    """Build the 'disc' sub-graph: conv stages followed by two dense layers.

    Each of the `ilog2(resolution) - 2` stages applies a 3x3 convolution, a
    stride-2 3x3 convolution, leaky-ReLU activations and batch normalisation
    (skipped after the very first convolution).

    Args:
        x: Input tensor.
        resolution: Value whose integer log2 sets the number of stages.
        filters: Base filter count; doubled per stage and capped at 512.

    Returns:
        Output of the final `tf.layers.dense(..., 1)` layer.
    """
    conv_args = dict(
        padding='same',
        kernel_initializer=tf.random_normal_initializer(stddev=0.02))
    num_stages = utils.ilog2(resolution) - 2

    with tf.variable_scope('disc', reuse=tf.AUTO_REUSE):
        net = x
        for stage in range(num_stages):
            width = min(filters << stage, 512)
            net = tf.layers.conv2d(net, width, 3, **conv_args)
            if stage > 0:
                net = tf.layers.batch_normalization(net, training=True)
            net = tf.nn.leaky_relu(net)
            net = tf.layers.conv2d(net, width, 3, strides=2, **conv_args)
            net = tf.layers.batch_normalization(net, training=True)
            net = tf.nn.leaky_relu(net)

        # single image = 4 x 4 x (filters << (log(resolution) - 3))
        # NOTE(review): the dense layers are applied without an explicit
        # flatten - confirm that is intended.
        net = tf.layers.dense(net, 1024, activation=tf.nn.leaky_relu)
        net = tf.layers.dense(net, 1)
        return net
def model(self, dataset, scale, blocks, filters, decay_start, decay_stop, lr_decay, adv_weight, pcp_weight, layer_name, **kwargs):
    """Build the training/eval graph: generator `sres`, discriminator `disc`, losses and optimizers.

    Args:
        dataset: Provides `colors`, `height` and `width` for the placeholders;
            `width` also selects how many discriminator stages are built.
        scale: Its integer log2 is the number of upsampling steps in `sres`.
        blocks: Number of residual blocks in `sres`.
        filters: Base channel count for both networks.
        decay_start: Subtracted from the global step before the decay schedule.
        decay_stop: `decay_stop - decay_start` is the decay-steps argument.
        lr_decay: Decay rate passed to `tf.train.exponential_decay`.
        adv_weight: Weight of the adversarial term in the generator loss.
        pcp_weight: Weight of the VGG-embedding MSE term in the generator loss.
        layer_name: Layer passed to `vgg19.build` for the embedding loss.
        **kwargs: Ignored.

    Returns:
        EasyDict with placeholders `x`, `y`, `sres_op` (generator applied to
        `y`), `eval_op` (generator applied to downscaled `x`) and `train_op`
        (grouped discriminator and generator updates).
    """
    del kwargs
    # Channels-first (N, C, H, W) placeholders, matching data_format below.
    x = tf.placeholder(tf.float32, [None, dataset.colors, dataset.height, dataset.width], 'x')
    # `y` leaves the spatial dimensions unspecified; it only feeds sres_op.
    y = tf.placeholder(tf.float32, [None, dataset.colors, None, None], 'y')
    log_scale = utils.ilog2(scale)
    cur_lr = tf.train.exponential_decay(FLAGS.lr, tf.train.get_global_step() - decay_start, decay_stop - decay_start, lr_decay)
    utils.HookReport.log_tensor(cur_lr, 'lr')

    def sres(x0, train):
        # Generator: head conv, `blocks` residual blocks, a long skip from the
        # head output, `log_scale` upsampling steps, then a 1x1 tanh conv back
        # to the input channel count.
        conv_args = dict(padding='same', data_format='channels_first', kernel_initializer=tf.random_normal_initializer(stddev=0.02))
        with tf.variable_scope("sres", reuse=tf.AUTO_REUSE) as vs:
            x1 = x = tf.layers.conv2d(x0, filters, 3, activation=tf.nn.relu, **conv_args)
            # Residuals
            for i in range(blocks):
                xb = tf.layers.conv2d(x, filters, 3, **conv_args)
                xb = tf.layers.batch_normalization(xb, axis=1, training=train)
                xb = tf.nn.relu(xb)
                xb = tf.layers.conv2d(xb, filters, 3, **conv_args)
                xb = tf.layers.batch_normalization(xb, axis=1, training=train)
                x += xb
            x = tf.layers.conv2d(x, filters, 3, **conv_args)
            x = tf.layers.batch_normalization(x, axis=1, training=train)
            x += x1
            # Upsampling
            for _ in range(log_scale):
                # 4x the channels, then layers.channels_to_space rearranges them
                # (presumably into a 2x larger spatial grid - TODO confirm).
                x = tf.layers.conv2d(x, filters * 4, 3, activation=tf.nn.relu, **conv_args)
                x = layers.channels_to_space(x)
            x = tf.layers.conv2d(x, x0.shape[1], 1, activation=tf.nn.tanh, **conv_args)
            return x

    def disc(x):
        # Discriminator: stride-2 4x4 convs (extra stages for wider datasets),
        # 1x1 channel reductions, a residual block, and a final 1-channel conv
        # whose output is flattened to one logit per row.
        conv_args = dict(padding='same', data_format='channels_first', kernel_initializer=tf.random_normal_initializer(stddev=0.02))
        with tf.variable_scope('disc', reuse=tf.AUTO_REUSE):
            y = tf.layers.conv2d(x, filters, 4, strides=2, activation=tf.nn.leaky_relu, **conv_args)
            y = tf.layers.conv2d(y, filters * 2, 4, strides=2, **conv_args)
            y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            y = tf.layers.conv2d(y, filters * 4, 4, strides=2, **conv_args)
            y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            y = tf.layers.conv2d(y, filters * 8, 4, strides=2, **conv_args)
            y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            # NOTE(review): the nesting of the three width-dependent branches
            # below was reconstructed from whitespace-collapsed source; the
            # 1x1 `filters * 16` conv is assumed to belong to the `> 64`
            # branch - confirm against the original layout.
            if dataset.width > 32:
                y = tf.layers.conv2d(y, filters * 16, 4, strides=2, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            if dataset.width > 64:
                y = tf.layers.conv2d(y, filters * 32, 4, strides=2, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
                y = tf.layers.conv2d(y, filters * 16, 1, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            if dataset.width > 32:
                y = tf.layers.conv2d(y, filters * 8, 1, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            # y7 is the skip input of the residual block that ends in y8.
            y7 = y
            y = tf.layers.conv2d(y, filters * 2, 1, **conv_args)
            y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            y = tf.layers.conv2d(y, filters * 2, 3, **conv_args)
            y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
            y = tf.layers.conv2d(y, filters * 8, 3, **conv_args)
            y8 = tf.nn.leaky_relu(y7 + tf.layers.batch_normalization(y, axis=1, training=True))
            logits = tf.layers.conv2d(y8, 1, 3, **conv_args)
            return tf.reshape(logits, [-1, 1])

    def tower(real):
        # Per-tower losses: downscale the real batch, regenerate it with sres,
        # and score both real and generated images with disc.
        lores = self.downscale(real)
        fake = sres(lores, True)
        disc_real = disc(real)
        disc_fake = disc(fake)
        with tf.variable_scope('VGG', reuse=tf.AUTO_REUSE):
            vgg19 = vgg.Vgg19()
            real_embed = vgg19.build(layer_name, real, channels_last=False)
            fake_embed = vgg19.build(layer_name, fake, channels_last=False)
        loss_gmse = tf.losses.mean_squared_error(fake, real)
        loss_gpcp = tf.losses.mean_squared_error(real_embed, fake_embed)
        # Generator wants disc_fake classified as 1; the discriminator wants
        # real -> 1 and fake -> 0.
        loss_ggan = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=tf.ones_like(disc_fake))
        loss_dreal = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=tf.ones_like(disc_real))
        loss_dfake = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=tf.zeros_like(disc_fake))
        return (loss_gmse, loss_gpcp, tf.reduce_mean(loss_ggan), tf.reduce_mean(loss_dreal), tf.reduce_mean(loss_dfake))

    loss_gmse, loss_gpcp, loss_ggan, loss_dreal, loss_dfake = utils.para_mean(tower, x)
    loss_disc = loss_dreal + loss_dfake
    loss_gen = (loss_gmse + pcp_weight * loss_gpcp + adv_weight * loss_ggan)
    utils.HookReport.log_tensor(loss_dreal, 'dreal')
    utils.HookReport.log_tensor(loss_dfake, 'dfake')
    utils.HookReport.log_tensor(loss_gmse, 'gmse')
    utils.HookReport.log_tensor(pcp_weight * loss_gpcp, 'gpcp')
    utils.HookReport.log_tensor(adv_weight * loss_ggan, 'ggan')
    utils.HookReport.log_tensor(loss_gen, 'gen')
    # Presumably rescales RMSE from a [-1, 1] pixel range to 8-bit units - TODO confirm.
    utils.HookReport.log_tensor(tf.sqrt(loss_gmse) * 127.5, 'rmse')
    # Run the ops collected in UPDATE_OPS (where tf.layers.batch_normalization
    # registers its moving-average updates) before each optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_d = tf.train.AdamOptimizer(cur_lr, 0.9).minimize(loss_disc, var_list=utils.model_vars('disc'), colocate_gradients_with_ops=True)
        # Only the generator step increments the global step.
        train_g = tf.train.AdamOptimizer(cur_lr, 0.9).minimize(loss_gen, var_list=utils.model_vars('sres'), colocate_gradients_with_ops=True, global_step=tf.train.get_global_step())
    return EasyDict(x=x, y=y, sres_op=sres(y, False), eval_op=sres(self.downscale(x), False), train_op=tf.group(train_d, train_g))
def main(argv):
    """Build an FSMixup model, then either run inference or train.

    With --custom_dataset set, one dataset creator per entry of
    FLAGS.valid_size is registered in DATASETS before the lookup. With
    --perform_inference set, every "*.jpg" under FLAGS.inference_dir is
    classified using FLAGS.inference_ckpt, the accuracy is printed and the
    predicted class indices are saved to 'predictions_fs_mixup.npy'.

    Args:
        argv: Positional command-line arguments; discarded.
    """
    del argv  # Unused.
    dataset_is_known = FLAGS.dataset in DATASETS.keys()
    assert dataset_is_known or FLAGS.custom_dataset, "Please specify a dataset which is in data.py or use --custom_dataset."

    if FLAGS.custom_dataset:
        print("Preparing to train the " + FLAGS.dataset + " dataset.")
        valid_size = [int(size) for size in FLAGS.valid_size]
        use_cifar10 = FLAGS.augment == "cifar10"
        augmentation = data.augment_cifar10 if use_cifar10 else data.augment_color
        # Do not name your dataset using a "-", otherwise the following line
        # will not work for a custom dataset.
        base_name = FLAGS.dataset.split("-")[0]
        creators = []
        for valid in valid_size:
            creators.append(
                DataSetFS.creator(base_name, [FLAGS.train_record],
                                  [FLAGS.test_record],
                                  valid,
                                  augmentation,
                                  nclass=FLAGS.nclass,
                                  height=FLAGS.img_size,
                                  width=FLAGS.img_size))
        DATASETS.update(creators)
    dataset = DATASETS[FLAGS.dataset]()

    log_width = utils.ilog2(dataset.width)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    hparams = dict(
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
    )
    model = FSMixup(model_dir, dataset, **hparams)

    if not FLAGS.perform_inference:
        print("Preparing to train the " + FLAGS.dataset + " dataset.")
        # Both flag values are shifted left by 10 bits (multiplied by 1024).
        model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
        return

    print("Performing inference...")
    assert FLAGS.inference_dir
    assert FLAGS.inference_ckpt
    inference_dir = FLAGS.inference_dir
    inference_ckpt = FLAGS.inference_ckpt
    if inference_dir[-1] != "/":
        inference_dir += "/"

    inference_img_paths = list(glob.glob(inference_dir + "*.jpg"))
    images = np.asarray(
        [plt.imread(img_path) for img_path in inference_img_paths])
    # Map pixel values from [0, 255] onto [-1, 1].
    images = images * (2.0 / 255) - 1.0

    model.eval_mode(ckpt=inference_ckpt)
    batch = FLAGS.batch
    batch_logits = []
    for start in range(0, images.shape[0], batch):
        feed = {model.ops.x: images[start:start + batch]}
        batch_logits.append(
            model.session.run(model.ops.classify_op, feed_dict=feed))
    logits = np.concatenate(batch_logits, axis=0)

    class_names = list(model.get_class_mapping().values())
    # The expected class name is the text after the last '_' in the path,
    # minus its final four characters (the ".jpg" matched by the glob).
    gt_classes = np.asarray([
        class_names.index(path.split('_')[-1][:-4])
        for path in inference_img_paths
    ])
    print("Overall Acc: ", (logits.argmax(1) == gt_classes).mean() * 100)
    np.save('predictions_fs_mixup.npy', logits.argmax(1))
def log_scale(self):
    """Return `utils.ilog2` applied to `self.scale`."""
    scale = self.scale
    return utils.ilog2(scale)
def main(argv):
    """Train a Frost model, optionally with iterative label bootstrapping.

    With FLAGS.boot_factor > 1, training runs in `numIter` bootstrap rounds
    plus a final round. After each bootstrap round a new labeled split is
    created with CreateSplit and the model is switched to it. Unlabeled data
    is copied into a randomly named scratch sub-folder of data.DATA_DIR,
    which is deleted after the final round. Otherwise a single training run
    is performed. The total wall-clock time is printed at the end.

    Dataset names are expected to look like "<name>.<seed>@<size>-<valid>",
    the same format used below to build the `target` split names.

    Args:
        argv: Positional command-line arguments; discarded.

    Raises:
        ValueError: If no integer seed can be parsed from FLAGS.dataset.
    """
    utils.setup_main()
    del argv  # Unused.

    # Take the whole text between the last '.' before '@' and the '@' itself.
    # Reading only the single character before '@' silently truncated
    # multi-digit seeds (e.g. seed 12 was parsed as 2).
    seed = int(FLAGS.dataset.partition('@')[0].rpartition('.')[2])

    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = Frost(
        os.path.join(FLAGS.train_dir, dataset.name, Frost.cta_name()),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        wclr=FLAGS.wclr,
        mom=FLAGS.mom,
        confidence=FLAGS.confidence,
        balance=FLAGS.balance,
        delT=FLAGS.delT,
        uratio=FLAGS.uratio,
        clrratio=FLAGS.clrratio,
        temperature=FLAGS.temperature,
        # A falsy --scales falls back to a value derived from the image width.
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)

    tic = time.perf_counter()
    if FLAGS.boot_factor > 1:
        # Default schedule: two bootstrap rounds ending at 1/2 and 3/4 of the
        # total image budget, then a final round up to the full budget.
        numIter = 2
        numToLabel = [
            FLAGS.boot_factor, FLAGS.boot_factor * FLAGS.boot_factor, 0
        ]
        numImgs = [(FLAGS.train_kimg << 9), 3 * (FLAGS.train_kimg << 8),
                   (FLAGS.train_kimg << 10)]
        if FLAGS.boot_schedule == 1:
            # Three equal thirds of the budget.
            steps = int((FLAGS.train_kimg << 10) / 3)
            numImgs = [steps, 2 * steps, 3 * steps]
        elif FLAGS.boot_schedule == 2:
            # Three bootstrap rounds plus the final one, a quarter each.
            numIter = 3
            steps = FLAGS.train_kimg << 8
            numImgs = [steps, 2 * steps, 3 * steps, 4 * steps]
            numToLabel = [
                FLAGS.boot_factor, FLAGS.boot_factor**2, FLAGS.boot_factor**3,
                0
            ]

        datasetName = dataset.name[:dataset.name.find('.')]
        print("Dataset Name ", datasetName)

        # Per-run scratch folder with a random 8-letter name.
        letters = string.ascii_letters
        subfolder = ''.join(random.choice(letters) for i in range(8))
        FLAGS.data_subfolder = subfolder
        tf.gfile.MakeDirs(data.DATA_DIR + '/' + subfolder)
        if not tf.gfile.Exists(data.DATA_DIR + '/' + subfolder + '/' +
                               datasetName + '-unlabel.json'):
            infile = data.DATA_DIR + '/SSL2/' + datasetName + '-unlabel.'
            outfile = (data.DATA_DIR + '/' + subfolder + '/' + datasetName +
                       '-unlabel.')
            print("Copied from ", infile, "* to ", outfile + '*')
            tf.io.gfile.copy(infile + 'json', outfile + 'json')
            tf.io.gfile.copy(infile + 'tfrecord', outfile + 'tfrecord')

        for it in range(numIter):
            print(" Iiteration ", it, " until ", numImgs[it])
            model.train(numImgs[it], FLAGS.report_kimg << 10, numToLabel[it],
                        it)
            elapse = (time.perf_counter() - tic) / 3600
            print("After iteration ", it, " training time ", elapse, " hours")

            # Create the next labeled split and switch the model over to it.
            bootstrap = CreateSplit(
                os.path.join(FLAGS.train_dir, dataset.name, Frost.cta_name()))
            bootstrap.create_split(datasetName=datasetName,
                                   seed=seed,
                                   size=numToLabel[it] * dataset.nclass)
            target = datasetName + '.' + str(seed) + '@' + str(
                numToLabel[it] * dataset.nclass) + '-1'
            print("Target ", target)
            dataset = data.PAIR_DATASETS()[target]()
            log_width = utils.ilog2(dataset.width)
            model.updateDataset(dataset)

        # Final round. Report the real round index: the message used to
        # hard-code "2", which is wrong when boot_schedule == 2 (numIter == 3).
        print(" Iiteration ", numIter, " until ", numImgs[numIter])
        model.train(numImgs[numIter], FLAGS.report_kimg << 10, 0, numIter)
        tf.compat.v1.gfile.DeleteRecursively(data.DATA_DIR + '/' + subfolder)
    else:
        model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10, 0, 0)

    elapse = (time.perf_counter() - tic) / 3600
    print(f"Total training time {elapse:0.4f} hours")