Example #1
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.MANY_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = CTAReMixMatch(os.path.join(FLAGS.train_dir, dataset.name,
                                       CTAReMixMatch.cta_name()),
                          dataset,
                          lr=FLAGS.lr,
                          wd=FLAGS.wd,
                          arch=FLAGS.arch,
                          batch=FLAGS.batch,
                          nclass=dataset.nclass,
                          K=FLAGS.K,
                          beta=FLAGS.beta,
                          w_kl=FLAGS.w_kl,
                          w_match=FLAGS.w_match,
                          w_rot=FLAGS.w_rot,
                          redux=FLAGS.redux,
                          use_dm=FLAGS.use_dm,
                          use_xe=FLAGS.use_xe,
                          warmup_kimg=FLAGS.warmup_kimg,
                          scales=FLAGS.scales or (log_width - 2),
                          filters=FLAGS.filters,
                          repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
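Nearly every example on this page follows the same pattern: log_width = utils.ilog2(dataset.width) followed by scales=FLAGS.scales or (log_width - 2). As a minimal sketch (an assumption for illustration, not the repositories' actual implementation), utils.ilog2 is expected to behave like an integer base-2 logarithm of an exact power of two:

def ilog2(x):
    """Integer log base 2; x is assumed to be a positive power of two."""
    r = (x - 1).bit_length()
    assert (1 << r) == x, 'expected a power of two, got %r' % (x,)
    return r

# For a 32x32 dataset: log_width = ilog2(32) = 5, so the default
# scales = log_width - 2 = 3, i.e. three 2x reductions (32 -> 16 -> 8 -> 4).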
Example #2
def main(argv):
    del argv  # Unused.
    dataset = DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)

    # Generating model directory
    if FLAGS.gamma is None:
        model_dir = 'ERM'
    elif FLAGS.gamma > 0:
        model_dir = 'WDRO_' + str(FLAGS.gamma)
    else:
        raise ValueError('Check the regularization parameter gamma')

    model = FSBaseline(os.path.join(FLAGS.train_dir, model_dir, dataset.name),
                       dataset,
                       lr=FLAGS.lr,
                       wd=FLAGS.wd,
                       arch=FLAGS.arch,
                       batch=FLAGS.batch,
                       nclass=dataset.nclass,
                       ema=FLAGS.ema,
                       smoothing=FLAGS.smoothing,
                       scales=FLAGS.scales or (log_width - 2),
                       filters=FLAGS.filters,
                       repeat=FLAGS.repeat,
                       gamma=FLAGS.gamma)
    model.train(FLAGS.nckpt * FLAGS.ckptsize,
                FLAGS.ckptsize)  # (total number of images, checkpoint size)
Example #3
File: cgan.py Project: xvdp/lag
    def disc(self, x, x_lr, resolution, filters):
        conv_args = dict(padding='same',
                         kernel_initializer=tf.glorot_uniform_initializer())
        lr_h, lr_w, lr_c = [
            tf.cast(v, tf.float32) for v in utils.smart_shape(x_lr)[1:]
        ]
        colors = utils.smart_shape(x)[3]
        log_res = utils.ilog2(resolution)

        with tf.variable_scope('disc', reuse=tf.AUTO_REUSE):
            h = x
            h = d_optimized_resnet_block(h, filters)
            h = d_resnet_block(h, filters)
            for block in range(1, self.log_scale - 1):
                h = d_resnet_block(h, filters << block, downsample=True)
                h = d_resnet_block(h, filters << block)
            h = d_resnet_block(h,
                               filters << (self.log_scale - 1),
                               downsample=True)
            h = d_resnet_block(h, filters << (self.log_scale - 1))
            lr_disc = layers.conv2d_spectral_norm(h, colors, 3,
                                                  **conv_args) * x_lr
            lr_disc = tf.reduce_sum(lr_disc, [1, 2, 3]) * tf.rsqrt(
                lr_h * lr_w * lr_c)
            lr_disc = tf.reshape(lr_disc, [-1, 1])
            for block in range(self.log_scale, log_res - 2):
                h = d_resnet_block(h, filters << block, downsample=True)
            h = d_resnet_block(h, filters << block)
            h = tf.reduce_sum(h, [1, 2]) * (1 / 4.)
            hr_disc = layers.dense_spectral_norm(
                h, 1, kernel_initializer=tf.glorot_uniform_initializer())
            return lr_disc + hr_disc
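The lr_disc term above is effectively a normalized inner product between the discriminator features, projected back to the input color channels, and the low-resolution conditioning image x_lr, scaled by 1/sqrt(H*W*C). A small NumPy sketch of just that reduction (shapes are placeholders chosen for illustration):

import numpy as np

proj = np.random.randn(2, 8, 8, 3)   # stands in for the conv2d_spectral_norm(h, colors, 3) output
x_lr = np.random.randn(2, 8, 8, 3)   # the low-resolution conditioning input
lr_disc = (proj * x_lr).sum(axis=(1, 2, 3)) / np.sqrt(8 * 8 * 3)
lr_disc = lr_disc.reshape(-1, 1)     # one scalar score per example, shape (2, 1)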
Example #4
def main(argv):
    del argv  # Unused.
    dataset = DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)

    # Generating model directory
    if FLAGS.gamma is None:
        model_dir = 'MIXUP'
    elif FLAGS.gamma > 0:
        model_dir = 'WDRO_MIX_' + str(FLAGS.gamma)
    else:
        raise ValueError('Check the penalty parameter gamma')

    model = MixupGrad(os.path.join(FLAGS.train_dir, model_dir, dataset.name),
                      dataset,
                      lr=FLAGS.lr,
                      wd=FLAGS.wd,
                      arch=FLAGS.arch,
                      batch=FLAGS.batch,
                      nclass=dataset.nclass,
                      ema=FLAGS.ema,
                      beta=FLAGS.beta,
                      gamma=FLAGS.gamma,
                      scales=FLAGS.scales or (log_width - 2),
                      filters=FLAGS.filters,
                      repeat=FLAGS.repeat)
    model.train(FLAGS.nckpt * FLAGS.ckptsize,
                FLAGS.ckptsize)  # (total number of images, epoch size)
Example #5
def main(argv):
    del argv  # Unused.
    assert FLAGS.dataset.split('.')[0] in [
        'cifar10', 'cifar100', 'svhn', 'svhn_extra'
    ]
    dataset = DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = MixMatch_LinearGrow(
        os.path.join(FLAGS.train_dir,
                     dataset.name.split('@')[0] + '_train' + \
                     dataset.name.split('train')[-1] + '_Grow'),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        logit_norm=FLAGS.logit_norm,
        T=FLAGS.T,
        mixmode=FLAGS.mixmode,
        nu=FLAGS.nu,
        dbuf=FLAGS.dbuf,
        w_match=FLAGS.w_match,
        warmup_kimg=FLAGS.warmup_kimg,
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
        growby=FLAGS.grow_by,
        growsize=FLAGS.grow_size)
    model.train_lineargrow(FLAGS.report_kimg << 10)
Example #6
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = FixMatch(
        os.path.join(FLAGS.train_dir, dataset.name, FixMatch.cta_name()),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
        size_unlabeled=dataset.size_unlabeled,
        alpha=FLAGS.alpha,
        inf_warm=FLAGS.inf_warm,
        inner_steps=FLAGS.inner_steps,
    )
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #7
def main(argv):
    del argv

    # Num of augmentations to perform on each image and measure consistency loss.
    # Performance does not significantly increase with more augmentations.
    assert FLAGS.nu == 2

    dataset = get_dataset()

    log_width = utils.ilog2(dataset.width)
    model = RealMix(os.path.join(FLAGS.train_dir, dataset.name),
                    dataset,
                    lr=FLAGS.lr,
                    wd=FLAGS.wd,
                    arch=FLAGS.arch,
                    batch=FLAGS.batch,
                    nclass=dataset.nclass,
                    ema=FLAGS.ema,
                    beta=FLAGS.beta,
                    w_match=FLAGS.w_match,
                    scales=FLAGS.scales or (log_width - 2),
                    filters=FLAGS.filters,
                    repeat=FLAGS.repeat,
                    tsa=FLAGS.tsa,
                    ood_mask=FLAGS.percent_mask,
                    augmentation=FLAGS.augment)

    # if FLAGS.perform_inference:
    #     print("Performing inference...")
    #     assert FLAGS.inference_dir and FLAGS.inference_ckpt
    #     inference_dir = FLAGS.inference_dir
    #     inference_ckpt = FLAGS.inference_ckpt

    #     # images = model.session.run(memoize(default_parse(dataset([inference_dir]))).prefetch(10))

    #     if inference_dir[-1] != "/":
    #         inference_dir += "/"
    #     inference_img_paths = [path for path in glob.glob(inference_dir + "*.png")]
    #     images = np.asarray([plt.imread(img_path) for img_path in inference_img_paths])
    #     images = images * (2.0 / 255) - 1.0
    #     model.eval_mode(ckpt=inference_ckpt)
    #     # batch = FLAGS.batch
    #     feed_extra = None
    #     logits = [model.session.run(model.ops.classify_op, feed_dict={
    #         model.ops.x: images[0:10], **(feed_extra or {})})]

    #     print(np.asarray(logits).shape)
    #     print(logits)
    #     for i in range(10):
    #         print(np.amax(logits, axis=-1)[:, i], inference_img_paths[i])

    print("Preparing to train the %s dataset with %d classes, img_size of %d, %s augmentation, %s tsa schedule, %f weight decay, and learning rate of %f using RealMix" \
            % (FLAGS.dataset, FLAGS.nclass, FLAGS.img_size, FLAGS.augment, FLAGS.tsa, FLAGS.wd, FLAGS.lr))
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #8
    def eval_mode(self, dataset):
        assert self.eval is None
        log_scale = utils.ilog2(self.scale)
        model = functools.partial(self.model,
                                  dataset=dataset,
                                  total_steps=1,
                                  lod_start=log_scale,
                                  lod_stop=log_scale,
                                  lod_max=log_scale)
        self.eval = EvalSessionPro(model, self.checkpoint_dir, **self.params)
        print('Eval model %s at global_step %d' %
              (self.__class__.__name__,
               self.eval.sess.run(self.eval.global_step)))
        return self.eval
Example #9
    def __init__(self, resolution_start, resolution_stop, transition_kimg,
                 training_kimg, stop_kimg):
        self.transition_nimg = transition_kimg << 10
        self.training_nimg = training_kimg << 10
        self.lod_start = utils.ilog2(resolution_start)
        self.lod_stop = utils.ilog2(resolution_stop)
        self.schedule = []
        nimg_cur = 0
        for lod in range(self.lod_start, self.lod_stop):
            if training_kimg:
                self.schedule.append(
                    TrainPhase(nimg_cur, nimg_cur + self.training_nimg,
                               lod, lod))
                nimg_cur += self.training_nimg
            if transition_kimg:
                self.schedule.append(
                    TrainPhase(nimg_cur, nimg_cur + self.transition_nimg,
                               lod, lod + 1))
                nimg_cur += self.transition_nimg
        stop_nimg = (nimg_cur + self.training_nimg if stop_kimg == 0
                     else stop_kimg << 10)
        if stop_nimg > nimg_cur:
            self.schedule.append(
                TrainPhase(nimg_cur, stop_nimg, self.lod_stop, self.lod_stop))
        self.schedule[0].nimg_start = 0
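To make the schedule concrete, here is a worked example of the phases this constructor produces, assuming TrainPhase simply records (nimg_start, nimg_stop, lod_start, lod_stop); the real class may carry more state:

# resolution_start=4, resolution_stop=32, training_kimg=200, transition_kimg=100:
# lod_start = ilog2(4) = 2, lod_stop = ilog2(32) = 5,
# training_nimg = 200 << 10 = 204800, transition_nimg = 100 << 10 = 102400.
#
# lod 2: train [0, 204800),       then transition 2->3 [204800, 307200)
# lod 3: train [307200, 512000),  then transition 3->4 [512000, 614400)
# lod 4: train [614400, 819200),  then transition 4->5 [819200, 921600)
# lod 5: final stable phase [921600, 1126400) when stop_kimg == 0.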
Example #10
def main(argv):
    del argv  # Unused.
    if FLAGS.dataset in DATASETS:
        dataset = DATASETS[FLAGS.dataset]()
    elif FLAGS.custom_dataset:
        print("Preparing to train the " + FLAGS.dataset + " dataset.")
        label_size = [int(size) for size in FLAGS.label_size]
        valid_size = [int(size) for size in FLAGS.valid_size]

        if FLAGS.augment == "cifar10":
            augmentation = augment_cifar10
        else:
            augmentation = augment_custom

        DATASETS.update([
            DataSet.creator(
                FLAGS.dataset.split(".")[0],
                seed,
                label,
                valid,
                [augmentation, stack_augment(augmentation)],
                nclass=FLAGS.nclass,
                height=FLAGS.img_size,
                width=FLAGS.img_size)
            for seed, label, valid in itertools.product(
                range(2), label_size, valid_size)
        ])
        dataset = DATASETS[FLAGS.dataset]()
    else:
        raise ValueError('Dataset %s not found; use --custom_dataset to register a new one.'
                         % FLAGS.dataset)

    log_width = utils.ilog2(dataset.width)
    model = VAT(os.path.join(FLAGS.train_dir, dataset.name),
                dataset,
                lr=FLAGS.lr,
                wd=FLAGS.wd,
                arch=FLAGS.arch,
                warmup_pos=FLAGS.warmup_pos,
                batch=FLAGS.batch,
                nclass=dataset.nclass,
                ema=FLAGS.ema,
                smoothing=FLAGS.smoothing,
                vat=FLAGS.vat,
                vat_eps=FLAGS.vat_eps,
                entmin_weight=FLAGS.entmin_weight,
                scales=FLAGS.scales or (log_width - 2),
                filters=FLAGS.filters,
                repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #11
def main(argv):
    del argv  # Unused.
    dataset = DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = FSBaseline(os.path.join(FLAGS.train_dir, dataset.name),
                       dataset,
                       lr=FLAGS.lr,
                       wd=FLAGS.wd,
                       arch=FLAGS.arch,
                       batch=FLAGS.batch,
                       nclass=dataset.nclass,
                       ema=FLAGS.ema,
                       smoothing=FLAGS.smoothing,
                       scales=FLAGS.scales or (log_width - 2),
                       filters=FLAGS.filters,
                       repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #12
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = Mixup(os.path.join(FLAGS.train_dir, dataset.name),
                  dataset,
                  lr=FLAGS.lr,
                  wd=FLAGS.wd,
                  arch=FLAGS.arch,
                  batch=FLAGS.batch,
                  nclass=dataset.nclass,
                  ema=FLAGS.ema,
                  beta=FLAGS.beta,
                  scales=FLAGS.scales or (log_width - 2),
                  filters=FLAGS.filters,
                  repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #13
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = FixMatch_RA(os.path.join(FLAGS.train_dir, dataset.name),
                        dataset,
                        lr=FLAGS.lr,
                        wd=FLAGS.wd,
                        arch=FLAGS.arch,
                        batch=FLAGS.batch,
                        nclass=dataset.nclass,
                        wu=FLAGS.wu,
                        confidence=FLAGS.confidence,
                        uratio=FLAGS.uratio,
                        scales=FLAGS.scales or (log_width - 2),
                        filters=FLAGS.filters,
                        repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #14
def main(argv):
    del argv  # Unused.
    dataset = data.DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = NST(os.path.join(FLAGS.train_dir, dataset.name),
                dataset,
                lr=FLAGS.lr,
                wd=FLAGS.wd,
                arch=FLAGS.arch,
                warmup_pos=FLAGS.warmup_pos,
                batch=FLAGS.batch,
                nclass=dataset.nclass,
                ema=FLAGS.ema,
                smoothing=FLAGS.smoothing,
                consistency_weight=FLAGS.consistency_weight,
                scales=FLAGS.scales or (log_width - 2),
                filters=FLAGS.filters,
                repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #15
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = AB_FixMatch_NoCutOut(
        os.path.join(FLAGS.train_dir, dataset.name, FixMatch.cta_name()),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        confidence=FLAGS.confidence,
        uratio=FLAGS.uratio,
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)  # 512 epochs (which is 524K parameter updates)
Example #16
def main(argv):
    del argv  # Unused.
    # assert FLAGS.nu == 2
    dataset = DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = MixMatch(os.path.join(FLAGS.train_dir, dataset.name),
                     dataset,
                     lr=FLAGS.lr,
                     wd=FLAGS.wd,
                     arch=FLAGS.arch,
                     batch=FLAGS.batch,
                     nclass=dataset.nclass,
                     nu=FLAGS.nu,
                     ema=FLAGS.ema,
                     num_final_layers=FLAGS.num_final_layers,
                     beta=FLAGS.beta,
                     w_match=FLAGS.w_match,
                     scales=FLAGS.scales or (log_width - 2),
                     filters=FLAGS.filters,
                     repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #17
def main(argv):
    del argv  # Unused.

    assert FLAGS.dataset in DATASETS or FLAGS.custom_dataset, \
        "Please specify a dataset which is in data.py or use --custom_dataset."

    if not FLAGS.custom_dataset:
        dataset = DATASETS[FLAGS.dataset]()
    else:
        print("Preparing to train the " + FLAGS.dataset + " dataset.")
        valid_size = [int(size) for size in FLAGS.valid_size]

        augmentation = data.augment_cifar10

        # Do not name your dataset using a "-", otherwise the following line will not work for a custom dataset.
        DATASETS.update([
            DataSetFS.creator(FLAGS.dataset.split("-")[0],
                              [FLAGS.train_record], [FLAGS.test_record],
                              valid,
                              augmentation,
                              nclass=FLAGS.nclass,
                              height=FLAGS.img_size,
                              width=FLAGS.img_size) for valid in valid_size
        ])
        dataset = DATASETS[FLAGS.dataset]()

    log_width = utils.ilog2(dataset.width)
    model = FSBaseline(os.path.join(FLAGS.train_dir, dataset.name),
                       dataset,
                       lr=FLAGS.lr,
                       wd=FLAGS.wd,
                       arch=FLAGS.arch,
                       batch=FLAGS.batch,
                       nclass=dataset.nclass,
                       ema=FLAGS.ema,
                       smoothing=FLAGS.smoothing,
                       scales=FLAGS.scales or (log_width - 2),
                       filters=FLAGS.filters,
                       repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #18
def main(argv):
    del argv  # Unused.
    dataset = DATASETS[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = AblationMixMatch(os.path.join(FLAGS.train_dir, dataset.name),
                             dataset,
                             lr=FLAGS.lr,
                             wd=FLAGS.wd,
                             arch=FLAGS.arch,
                             batch=FLAGS.batch,
                             nclass=dataset.nclass,
                             ema=FLAGS.ema,
                             beta=FLAGS.beta,
                             use_ema_guess=FLAGS.use_ema_guess,
                             T=FLAGS.T,
                             mixmode=FLAGS.mixmode,
                             nu=FLAGS.nu,
                             w_match=FLAGS.w_match,
                             warmup_kimg=FLAGS.warmup_kimg,
                             scales=FLAGS.scales or (log_width - 2),
                             filters=FLAGS.filters,
                             repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #19
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = TranslationConsistencyRegularization(
        os.path.join(FLAGS.train_dir, dataset.name),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        warmup_pos=FLAGS.warmup_pos,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        smoothing=FLAGS.smoothing,
        consistency_weight=FLAGS.consistency_weight,
        tcr_augment=FLAGS.tcr_augment,
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #20
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    dataset = PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = UDA(os.path.join(FLAGS.train_dir, dataset.name),
                dataset,
                lr=FLAGS.lr,
                wd=FLAGS.wd,
                wu=FLAGS.wu,
                we=FLAGS.we,
                arch=FLAGS.arch,
                batch=FLAGS.batch,
                nclass=dataset.nclass,
                temperature=FLAGS.temperature,
                tsa=FLAGS.tsa,
                tsa_pos=FLAGS.tsa_pos,
                confidence=FLAGS.confidence,
                uratio=FLAGS.uratio,
                scales=FLAGS.scales or (log_width - 2),
                filters=FLAGS.filters,
                repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)  # 1024 epochs
Example #21
def main(argv):
    del argv  # Unused.
    assert FLAGS.nu == 2
    # print(DATASETS)
    dataset_json_path = './tfrecord/{0}-class.json'.format(FLAGS.dataset)
    with open(dataset_json_path, 'r') as f:
        label_dict = json.load(f)

    dataset = DataSet_2.creator(
        FLAGS.dataset,
        0,
        0,
        1, [augment_cifar10, stack_augment(augment_cifar10)],
        colors=3,
        nclass=len(label_dict),
        height=64,
        width=64)()

    log_width = utils.ilog2(dataset.width)
    model = MixMatch(os.path.join(FLAGS.train_dir, dataset.name),
                     dataset,
                     lr=FLAGS.lr,
                     wd=FLAGS.wd,
                     arch=FLAGS.arch,
                     batch=FLAGS.batch,
                     nclass=dataset.nclass,
                     ema=FLAGS.ema,
                     beta=FLAGS.beta,
                     w_match=FLAGS.w_match,
                     scales=FLAGS.scales or (log_width - 2),
                     filters=FLAGS.filters,
                     repeat=FLAGS.repeat)
    # model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
    # model.train(FLAGS.train_kimg+1, FLAGS.train_kimg+1)
    print(FLAGS.train_kimg, FLAGS.report_kimg)
    model.train(FLAGS.train_kimg, FLAGS.report_kimg)
Example #22
    def disc(self, x, resolution, filters):
        conv_args = dict(
            padding='same',
            kernel_initializer=tf.random_normal_initializer(stddev=0.02))
        log_res = utils.ilog2(resolution)

        def f(stage):
            return min(filters << stage, 512)

        with tf.variable_scope('disc', reuse=tf.AUTO_REUSE):
            y = x
            for r in range(log_res - 2):
                y = tf.layers.conv2d(y, f(r), 3, **conv_args)
                if r > 0:
                    y = tf.layers.batch_normalization(y, training=True)
                y = tf.nn.leaky_relu(y)
                y = tf.layers.conv2d(y, f(r), 3, strides=2, **conv_args)
                y = tf.nn.leaky_relu(
                    tf.layers.batch_normalization(y, training=True))

            # single image = 4 x 4 x (filters << (log(resolution) - 3))
            y = tf.layers.dense(y, 1024, activation=tf.nn.leaky_relu)
            y = tf.layers.dense(y, 1)
            return y
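The loop bound log_res - 2 above comes straight from utils.ilog2: each iteration ends with a stride-2 convolution, so a resolution x resolution input is reduced to 4x4 before the dense layers. A quick sanity check of that arithmetic in plain Python, with an assumed resolution of 64:

resolution = 64
log_res = resolution.bit_length() - 1   # ilog2(64) == 6 for an exact power of two
size = resolution
for r in range(log_res - 2):            # four stride-2 convolutions
    size //= 2
assert size == 4                        # 64 -> 32 -> 16 -> 8 -> 4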
Example #23
    def model(self, dataset, scale, blocks, filters, decay_start, decay_stop, lr_decay,
              adv_weight, pcp_weight, layer_name, **kwargs):
        del kwargs
        x = tf.placeholder(tf.float32, [None, dataset.colors, dataset.height, dataset.width], 'x')
        y = tf.placeholder(tf.float32, [None, dataset.colors, None, None], 'y')

        log_scale = utils.ilog2(scale)
        cur_lr = tf.train.exponential_decay(FLAGS.lr, tf.train.get_global_step() - decay_start,
                                            decay_stop - decay_start, lr_decay)
        utils.HookReport.log_tensor(cur_lr, 'lr')

        def sres(x0, train):
            conv_args = dict(padding='same', data_format='channels_first',
                             kernel_initializer=tf.random_normal_initializer(stddev=0.02))

            with tf.variable_scope("sres", reuse=tf.AUTO_REUSE) as vs:
                x1 = x = tf.layers.conv2d(x0, filters, 3, activation=tf.nn.relu, **conv_args)

                # Residuals
                for i in range(blocks):
                    xb = tf.layers.conv2d(x, filters, 3, **conv_args)
                    xb = tf.layers.batch_normalization(xb, axis=1, training=train)
                    xb = tf.nn.relu(xb)
                    xb = tf.layers.conv2d(xb, filters, 3, **conv_args)
                    xb = tf.layers.batch_normalization(xb, axis=1, training=train)
                    x += xb

                x = tf.layers.conv2d(x, filters, 3, **conv_args)
                x = tf.layers.batch_normalization(x, axis=1, training=train)
                x += x1

                # Upsampling
                for _ in range(log_scale):
                    x = tf.layers.conv2d(x, filters * 4, 3, activation=tf.nn.relu, **conv_args)
                    x = layers.channels_to_space(x)

                x = tf.layers.conv2d(x, x0.shape[1], 1, activation=tf.nn.tanh, **conv_args)
                return x

        def disc(x):
            conv_args = dict(padding='same', data_format='channels_first',
                             kernel_initializer=tf.random_normal_initializer(stddev=0.02))

            with tf.variable_scope('disc', reuse=tf.AUTO_REUSE):
                y = tf.layers.conv2d(x, filters, 4, strides=2, activation=tf.nn.leaky_relu, **conv_args)
                y = tf.layers.conv2d(y, filters * 2, 4, strides=2, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
                y = tf.layers.conv2d(y, filters * 4, 4, strides=2, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
                y = tf.layers.conv2d(y, filters * 8, 4, strides=2, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
                if dataset.width > 32:
                    y = tf.layers.conv2d(y, filters * 16, 4, strides=2, **conv_args)
                    y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
                if dataset.width > 64:
                    y = tf.layers.conv2d(y, filters * 32, 4, strides=2, **conv_args)
                    y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
                    y = tf.layers.conv2d(y, filters * 16, 1, **conv_args)
                    y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
                if dataset.width > 32:
                    y = tf.layers.conv2d(y, filters * 8, 1, **conv_args)
                    y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
                y7 = y
                y = tf.layers.conv2d(y, filters * 2, 1, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
                y = tf.layers.conv2d(y, filters * 2, 3, **conv_args)
                y = tf.nn.leaky_relu(tf.layers.batch_normalization(y, axis=1, training=True))
                y = tf.layers.conv2d(y, filters * 8, 3, **conv_args)
                y8 = tf.nn.leaky_relu(y7 + tf.layers.batch_normalization(y, axis=1, training=True))
                logits = tf.layers.conv2d(y8, 1, 3, **conv_args)
                return tf.reshape(logits, [-1, 1])

        def tower(real):
            lores = self.downscale(real)
            fake = sres(lores, True)
            disc_real = disc(real)
            disc_fake = disc(fake)

            with tf.variable_scope('VGG', reuse=tf.AUTO_REUSE):
                vgg19 = vgg.Vgg19()
                real_embed = vgg19.build(layer_name, real, channels_last=False)
                fake_embed = vgg19.build(layer_name, fake, channels_last=False)

            loss_gmse = tf.losses.mean_squared_error(fake, real)
            loss_gpcp = tf.losses.mean_squared_error(real_embed, fake_embed)
            loss_ggan = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=tf.ones_like(disc_fake))
            loss_dreal = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=tf.ones_like(disc_real))
            loss_dfake = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=tf.zeros_like(disc_fake))
            return (loss_gmse, loss_gpcp,
                    tf.reduce_mean(loss_ggan), tf.reduce_mean(loss_dreal), tf.reduce_mean(loss_dfake))

        loss_gmse, loss_gpcp, loss_ggan, loss_dreal, loss_dfake = utils.para_mean(tower, x)
        loss_disc = loss_dreal + loss_dfake
        loss_gen = (loss_gmse
                    + pcp_weight * loss_gpcp
                    + adv_weight * loss_ggan)

        utils.HookReport.log_tensor(loss_dreal, 'dreal')
        utils.HookReport.log_tensor(loss_dfake, 'dfake')
        utils.HookReport.log_tensor(loss_gmse, 'gmse')
        utils.HookReport.log_tensor(pcp_weight * loss_gpcp, 'gpcp')
        utils.HookReport.log_tensor(adv_weight * loss_ggan, 'ggan')
        utils.HookReport.log_tensor(loss_gen, 'gen')
        utils.HookReport.log_tensor(tf.sqrt(loss_gmse) * 127.5, 'rmse')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_d = tf.train.AdamOptimizer(cur_lr, 0.9).minimize(
                loss_disc, var_list=utils.model_vars('disc'),
                colocate_gradients_with_ops=True)
            train_g = tf.train.AdamOptimizer(cur_lr, 0.9).minimize(
                loss_gen, var_list=utils.model_vars('sres'),
                colocate_gradients_with_ops=True,
                global_step=tf.train.get_global_step())

        return EasyDict(x=x, y=y, sres_op=sres(y, False), eval_op=sres(self.downscale(x), False),
                        train_op=tf.group(train_d, train_g))
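In sres() above, every pass through the upsampling loop maps filters * 4 channels to filters channels at twice the spatial size (a pixel-shuffle step via layers.channels_to_space), so the low-resolution input is enlarged by a factor of 2 ** log_scale. A small sketch of the size bookkeeping under an assumed scale of 4:

scale = 4                            # upscaling factor (assumed; set by the surrounding model)
log_scale = scale.bit_length() - 1   # ilog2(4) == 2
height = 8                           # e.g. a 32x32 image downscaled by a factor of 4 (assumed)
for _ in range(log_scale):
    height *= 2                      # channels_to_space doubles height and width
assert height == 32                  # two 2x steps recover the original resolution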
Example #24
def main(argv):
    del argv  # Unused.
    assert FLAGS.dataset in DATASETS or FLAGS.custom_dataset, \
        "Please specify a dataset which is in data.py or use --custom_dataset."

    if not FLAGS.custom_dataset:
        dataset = DATASETS[FLAGS.dataset]()
    else:
        print("Preparing to train the " + FLAGS.dataset + " dataset.")
        valid_size = [int(size) for size in FLAGS.valid_size]

        if FLAGS.augment == "cifar10":
            augmentation = data.augment_cifar10
        else:
            augmentation = data.augment_color

        # Do not name your dataset using a "-", otherwise the following line will not work for a custom dataset.
        DATASETS.update([
            DataSetFS.creator(FLAGS.dataset.split("-")[0],
                              [FLAGS.train_record], [FLAGS.test_record],
                              valid,
                              augmentation,
                              nclass=FLAGS.nclass,
                              height=FLAGS.img_size,
                              width=FLAGS.img_size) for valid in valid_size
        ])
        dataset = DATASETS[FLAGS.dataset]()

    log_width = utils.ilog2(dataset.width)
    model = FSMixup(os.path.join(FLAGS.train_dir, dataset.name),
                    dataset,
                    lr=FLAGS.lr,
                    wd=FLAGS.wd,
                    arch=FLAGS.arch,
                    batch=FLAGS.batch,
                    nclass=dataset.nclass,
                    ema=FLAGS.ema,
                    beta=FLAGS.beta,
                    scales=FLAGS.scales or (log_width - 2),
                    filters=FLAGS.filters,
                    repeat=FLAGS.repeat)

    if FLAGS.perform_inference:
        print("Performing inference...")
        assert FLAGS.inference_dir
        assert FLAGS.inference_ckpt
        inference_dir = FLAGS.inference_dir
        inference_ckpt = FLAGS.inference_ckpt

        if inference_dir[-1] != "/":
            inference_dir += "/"
        inference_img_paths = [
            path for path in glob.glob(inference_dir + "*.jpg")
        ]
        images = np.asarray(
            [plt.imread(img_path) for img_path in inference_img_paths])
        images = images * (2.0 / 255) - 1.0
        model.eval_mode(ckpt=inference_ckpt)
        batch = FLAGS.batch
        feed_extra = None
        logits = np.concatenate([
            model.session.run(model.ops.classify_op,
                              feed_dict={
                                  model.ops.x: images[x:x + batch],
                                  **(feed_extra or {})
                              }) for x in range(0, images.shape[0], batch)
        ],
                                axis=0)
        class_dict = model.get_class_mapping()
        class_names = [value for key, value in class_dict.items()]
        gt_classes = []

        for i, path in enumerate(inference_img_paths):
            gt_classes.append(class_names.index(path.split('_')[-1][:-4]))

        gt_classes = np.asarray(gt_classes)
        print("Overall Acc: ", (logits.argmax(1) == gt_classes).mean() * 100)

        np.save('predictions_fs_mixup.npy', logits.argmax(1))
    else:
        print("Preparing to train the " + FLAGS.dataset + " dataset.")
        model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
Example #25
    def log_scale(self):
        return utils.ilog2(self.scale)
Example #26
def main(argv):
    utils.setup_main()
    del argv  # Unused.
    seedIndx = FLAGS.dataset.find('@')
    seed = int(FLAGS.dataset[seedIndx - 1])  # Assumes a single-digit seed right before the '@'.

    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = Frost(os.path.join(FLAGS.train_dir, dataset.name,
                               Frost.cta_name()),
                  dataset,
                  lr=FLAGS.lr,
                  wd=FLAGS.wd,
                  arch=FLAGS.arch,
                  batch=FLAGS.batch,
                  nclass=dataset.nclass,
                  wu=FLAGS.wu,
                  wclr=FLAGS.wclr,
                  mom=FLAGS.mom,
                  confidence=FLAGS.confidence,
                  balance=FLAGS.balance,
                  delT=FLAGS.delT,
                  uratio=FLAGS.uratio,
                  clrratio=FLAGS.clrratio,
                  temperature=FLAGS.temperature,
                  scales=FLAGS.scales or (log_width - 2),
                  filters=FLAGS.filters,
                  repeat=FLAGS.repeat)
    ###################### New code
    tic = time.perf_counter()
    if FLAGS.boot_factor > 1:
        numIter = 2
        numToLabel = [
            FLAGS.boot_factor, FLAGS.boot_factor * FLAGS.boot_factor, 0
        ]
        numImgs = [(FLAGS.train_kimg << 9), 3 * (FLAGS.train_kimg << 8),
                   (FLAGS.train_kimg << 10)]
        if FLAGS.boot_schedule == 1:
            steps = int((FLAGS.train_kimg << 10) / 3)
            numImgs = [steps, 2 * steps, 3 * steps]
        elif FLAGS.boot_schedule == 2:
            numIter = 3
            steps = FLAGS.train_kimg << 8
            numImgs = [steps, 2 * steps, 3 * steps, 4 * steps]
            numToLabel = [
                FLAGS.boot_factor, FLAGS.boot_factor**2, FLAGS.boot_factor**3,
                0
            ]

        datasetName = dataset.name[:dataset.name.find('.')]
        print("Dataset Name ", datasetName)
        letters = string.ascii_letters
        subfolder = ''.join(random.choice(letters) for i in range(8))
        FLAGS.data_subfolder = subfolder
        tf.gfile.MakeDirs(data.DATA_DIR + '/' + subfolder)
        if not tf.gfile.Exists(data.DATA_DIR + '/' + subfolder + '/' +
                               datasetName + '-unlabel.json'):
            infile = data.DATA_DIR + '/SSL2/' + datasetName + '-unlabel.'
            outfile = data.DATA_DIR + '/' + subfolder + '/' + datasetName + '-unlabel.'
            print("Copied from ", infile, "* to ", outfile + '*')
            tf.io.gfile.copy(infile + 'json', outfile + 'json')
            tf.io.gfile.copy(infile + 'tfrecord', outfile + 'tfrecord')

        for it in range(numIter):
            print(" Iiteration ", it, " until ", numImgs[it])
            model.train(numImgs[it], FLAGS.report_kimg << 10, numToLabel[it],
                        it)
            elapse = (time.perf_counter() - tic) / 3600
            print("After iteration ", it, " training time ", elapse, " hours")

            bootstrap = CreateSplit(
                os.path.join(FLAGS.train_dir, dataset.name, Frost.cta_name()))
            bootstrap.create_split(datasetName=datasetName,
                                   seed=seed,
                                   size=numToLabel[it] * dataset.nclass)

            target = datasetName + '.' + str(seed) + '@' + str(
                numToLabel[it] * dataset.nclass) + '-1'
            print("Target ", target)
            dataset = data.PAIR_DATASETS()[target]()
            log_width = utils.ilog2(dataset.width)
            model.updateDataset(dataset)

        print(" Iiteration 2 until ", numImgs[numIter])
        model.train(numImgs[numIter], FLAGS.report_kimg << 10, 0, numIter)
        tf.compat.v1.gfile.DeleteRecursively(data.DATA_DIR + '/' + subfolder)
    else:
        model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10, 0, 0)

    elapse = (time.perf_counter() - tic) / 3600
    print(f"Total training time {elapse:0.4f} hours")