Ejemplo n.º 1
0
 def __init__(self,
              filters,
              ks,
              strides=1,
              preact=False,
              use_norm=True,
              use_act=True,
              use_bias=False,
              last_norm=False,
              transpose=False):
     """Create the sub-layers of a single-convolution block.

     Args:
         filters: number of filters for ``layers.Conv`` (also forwarded to
             the base class constructor).
         ks: kernel size for ``layers.Conv``.
         strides: convolution stride (also forwarded to the base class).
         preact: stored as ``self.preact``; presumably selects a
             norm/act-before-conv ordering in ``call()`` — confirm.
         use_norm: if truthy, create ``self.bn`` (a ``layers.Norm``).
         use_act: if truthy, create ``self.act`` (a ``layers.Act``).
         use_bias: forwarded to ``layers.Conv``.
         last_norm: if truthy, create an extra ``self.last_bn``
             (a ``layers.Norm``).
         transpose: forwarded to ``layers.Conv``.
     """
     super(BasicBlock, self).__init__(filters, strides=strides)
     self.preact = preact
     self.use_norm, self.use_act = use_norm, use_act
     self.last_norm = last_norm

     if use_norm:
         # The norm's learnable scale is disabled whenever the configured
         # conv activation name contains 'relu' (case-insensitive).
         norm_scale = 'relu' not in FLAGS.conv_act.lower()
         self.bn = layers.Norm(scale=norm_scale)
     if use_act:
         self.act = layers.Act()

     conv_kwargs = dict(strides=strides, use_bias=use_bias, transpose=transpose)
     self.conv = layers.Conv(filters, ks, **conv_kwargs)

     if last_norm:
         self.last_bn = layers.Norm()
Ejemplo n.º 2
0
 def _build_encoder(self):
     """Return a ``self.sequential`` stack of four stride-2 convolutions.

     Every convolution uses kernel size 4; channel widths grow
     16 -> 64 -> 256 -> 512, with a ``layers.Norm`` between consecutive
     convolutions (none after the last one).
     """
     layer_list = []
     for idx, width in enumerate((16, 64, 256, 512)):
         if idx > 0:
             layer_list.append(layers.Norm())
         layer_list.append(layers.Conv(width, ks=4, strides=2))
     return self.sequential(layer_list)
Ejemplo n.º 3
0
 def _build_classifier(self):
     """Return the head: Dense(1024) -> Act('Relu') -> Norm -> Dense(1)."""
     hidden = layers.Dense(1024)
     relu = layers.Act('Relu')
     norm = layers.Norm()
     head_out = layers.Dense(1)
     return self.sequential([hidden, relu, norm, head_out])
Ejemplo n.º 4
0
 def __init__(self,
              filters,
              strides=1,
              preact=True,
              last_norm=True,
              pool=False,
              shortcut_type='project'):
     """Create a two-convolution (3x3 -> 3x3) residual block.

     Args:
         filters: output channels of both convolutions and of the
             projection shortcut.
         strides: downsampling factor of the block. With ``pool=False`` it is
             applied by the first convolution (and the projection shortcut);
             with ``pool=True`` the convolutions run at stride 1 and
             ``self.pool_lay`` (``layers.Pooling(pool_size=strides)``) is
             created instead.
         preact: forwarded to the inner ``BasicBlock``s.
         last_norm: forwarded to the second ``BasicBlock``.
         pool: downsample with pooling instead of strided convolutions.
             Not supported together with ``shortcut_type='pad'``.
         shortcut_type: ``'project'`` (1x1 ``BasicBlock`` projection) or
             ``'pad'`` (``layers.ShortcutPoolingPadding``). Any other value
             leaves ``self.shortcut`` unset — presumably treated as an
             identity shortcut in ``call()``; confirm.

     Raises:
         ValueError: if ``shortcut_type == 'pad'`` and ``pool`` is truthy.
     """
     # Validate before building any layers.
     if shortcut_type == 'pad' and pool:
         raise ValueError(
             "ResBlock: shortcut_type='pad' does not support pool=True; "
             "use shortcut_type='project' or pool=False.")
     super(ResBlock, self).__init__(filters, strides=strides)
     self.strides = strides
     self.last_norm = last_norm
     self.preact = preact
     self.shortcut_type = shortcut_type
     self.pool = pool
     # When pooling does the downsampling, the convolutions run at stride 1.
     conv_strides = 1 if pool else strides
     if self.pool:
         self.pool_lay = layers.Pooling(pool_size=strides)
     if self.shortcut_type == 'project' and self.preact:  # such as preact-resnet
         # Norm/act are standalone layers here and bk1 has neither, presumably
         # so call() can feed the same pre-activated tensor to bk1 and the
         # projection shortcut — confirm.
         self.bn1 = layers.Norm()
         self.act1 = layers.Act()
         self.bk1 = blocks.BasicBlock(filters,
                                      3,
                                      strides=conv_strides,
                                      preact=False,
                                      use_norm=False,
                                      use_act=False)
     elif self.shortcut_type == 'pad' and self.preact:  # such as pyramidnet
         self.bk1 = blocks.BasicBlock(filters,
                                      3,
                                      strides=conv_strides,
                                      preact=preact,
                                      use_norm=True,
                                      use_act=False)  # no act
     else:  # vanilla resnet, or preact with any other shortcut type
         self.bk1 = blocks.BasicBlock(filters,
                                      3,
                                      strides=conv_strides,
                                      preact=preact,
                                      use_norm=True,
                                      use_act=True)
     self.bk2 = blocks.BasicBlock(filters,
                                  3,
                                  strides=1,
                                  preact=preact,
                                  use_norm=True,
                                  use_act=True,
                                  last_norm=last_norm)
     if self.shortcut_type == 'pad':
         self.shortcut = layers.ShortcutPoolingPadding(pool_size=strides)
     elif self.shortcut_type == 'project':
         # In pre-activation mode the input is already normalized/activated,
         # so the projection itself carries no norm or activation.
         self.shortcut = blocks.BasicBlock(
             filters,
             1,
             strides=conv_strides,
             preact=preact,
             use_norm=not preact,
             use_act=not preact,
             last_norm=False)
Ejemplo n.º 5
0
 def __init__(self):
     """Create the sub-layers of ``AEClassifier``.

     Takes no arguments. Builds a leading ``layers.Norm``, a
     ``tf.keras.layers.Reshape`` to ``(2, 2, 32 * 16)``, and the sub-networks
     returned by ``self._build_classifier()`` and ``self._build_decoder()``.
     """
     super(AEClassifier, self).__init__()
     self.first_norm = layers.Norm()
     # Keras Reshape excludes the batch dimension: each sample becomes a
     # (2, 2, 512) tensor; the input must therefore hold 2048 values per sample.
     self.reshape = tf.keras.layers.Reshape((2, 2, 32 * 16))
     self.classifier = self._build_classifier()
     self.decoder = self._build_decoder()
Ejemplo n.º 6
0
    def __init__(self, preact=False, last_norm=False):
        """Create the layers of the network tail.

        By default, preact=False, last_norm=False means vanilla resnet.

        Args:
            preact: if truthy, also create ``self.preact_lastnorm``
                (a ``layers.Norm``) and ``self.preact_act`` (a ``layers.Act``).
            last_norm: stored as ``self.last_norm``; not consulted here.

        The final ``self.dense`` has ``FLAGS.out_dim`` units, no activation
        and a bias.
        """
        super(ResNetTail, self).__init__()
        out_dim = FLAGS.out_dim
        self.out_dim = out_dim
        self.preact, self.last_norm = preact, last_norm

        if preact:
            self.preact_lastnorm = layers.Norm()
            self.preact_act = layers.Act()

        self.gpool = layers.GlobalPooling()
        self.flatten = tf.keras.layers.Flatten()
        self.dense = layers.Dense(out_dim, activation=None, use_bias=True)
Ejemplo n.º 7
0
    def init(self):
        """Build the networks, metrics, loss and optimizer for training.

        Side effects: sets ``self.predictor``, ``self.randnet_1`` and
        ``self.randnet_2`` (both with ``trainable = False``),
        ``self.norm_all``, ``self.loss_object`` and ``self.optim``. Iterates
        ``self.valid_dataset`` once to count its samples.

        Returns:
            A 4-tuple ``(models, input_shapes, train_metrics, valid_metrics)``:
            ``models`` maps ``'predictor'``, ``'randnet_1'`` and
            ``'randnet_2'`` to the corresponding networks; ``input_shapes``
            maps the same keys to the per-sample shape of ``x`` (batch
            dimension dropped) read from ``self.train_dataset.element_spec``;
            ``train_metrics`` holds a ``tf.keras.metrics.Mean`` under
            ``'loss'``; ``valid_metrics`` holds the ``'figure'`` and
            ``'tnr@95tpr'`` ``metrics.RNDMetrics``, both sized to the
            validation sample count.
        """
        # One full pass over the validation set to count its samples.
        valid_amount = sum(x_batch.shape[0] for x_batch, _ in self.valid_dataset)

        # element_spec is ((x, x_blur), y): [0][0] selects x; [1:] drops batch.
        input_shape = self.train_dataset.element_spec[0][0].shape[1:]
        self.predictor = Predictor()
        self.randnet_1 = RandomNet()  # Note: fixed, not trainable
        self.randnet_2 = RandomNet()  # Note: fixed, not trainable
        self.norm_all = layers.Norm(center=False, scale=False)
        self.randnet_1.trainable = False
        self.randnet_2.trainable = False
        train_metrics = {
            'loss': tf.keras.metrics.Mean('loss'),
        }
        valid_metrics = {
            'figure':
            metrics.RNDMetrics(mode=metrics.RNDMetrics.MODE_FIGURE,
                               amount=valid_amount),
            'tnr@95tpr':
            metrics.RNDMetrics(mode=metrics.RNDMetrics.MODE_TNR95TPR,
                               amount=valid_amount)
        }
        self.loss_object = losses.MSELoss()
        # `learning_rate` replaces the deprecated `lr` alias (same effect).
        # The initial 0.0 is presumably overwritten by a schedule elsewhere
        # — confirm.
        self.optim = tfa.optimizers.AdamW(weight_decay=FLAGS.wd,
                                          learning_rate=0.0,
                                          beta_1=FLAGS.adam_beta_1,
                                          beta_2=FLAGS.adam_beta_2,
                                          epsilon=FLAGS.adam_epsilon)
        if FLAGS.amp:
            self.optim = mixed_precision.LossScaleOptimizer(
                self.optim,
                loss_scale=tf.mixed_precision.experimental.DynamicLossScale(
                    initial_loss_scale=(2**15),
                    increment_period=20,
                    multiplier=2.0))
        return {
            'predictor': self.predictor,
            'randnet_1': self.randnet_1,
            'randnet_2': self.randnet_2
        }, {
            'predictor': input_shape,
            'randnet_1': input_shape,
            'randnet_2': input_shape
        }, train_metrics, valid_metrics
Ejemplo n.º 8
0
    def __init__(self):
        """Build a PyramidNet configured from the global ``FLAGS``.

        Reads ``FLAGS.out_dim``, ``FLAGS.filters_mode`` (``'small'``: 16-filter
        3x3 stride-1 stem, 3 groups; ``'large'``: 64-filter 7x7 stride-2 stem
        with activation and 3-wide stride-2 average pooling, 4 groups),
        ``FLAGS.depth``, ``FLAGS.model_alpha``, ``FLAGS.bottleneck``,
        ``FLAGS.amp`` and, for the ``'large'`` stem, ``FLAGS.dim`` and
        ``FLAGS.padding``.

        Raises:
            ValueError: if ``FLAGS.filters_mode`` is neither ``'small'`` nor
                ``'large'``; if ``FLAGS.depth`` yields fewer than one residual
                block; or if the ``'large'`` stem is requested with
                ``FLAGS.dim`` not in ``{1, 2, 3}``.
        """
        super(PyramidNet, self).__init__()
        self.out_dim = FLAGS.out_dim
        self.filters_mode = FLAGS.filters_mode
        if self.filters_mode == 'small':  # cifar
            self.in_filters = 16
            self.in_ks = 3
            self.in_strides = 1
            self.groups = 3
            self.in_act_pool = False
        elif self.filters_mode == 'large':  # imagenet
            self.in_filters = 64
            self.in_ks = 7
            self.in_strides = 2
            self.groups = 4
            self.in_act_pool = True
        else:
            raise ValueError(
                f"FLAGS.filters_mode must be 'small' or 'large', "
                f"got {self.filters_mode!r}")
        self.depth = FLAGS.depth
        self.model_alpha = FLAGS.model_alpha
        # Conv layers per residual block: 3 with bottlenecks, 2 otherwise.
        if FLAGS.bottleneck:
            self.block = blocks.ResBottleneck
            layers_per_block = 3
        else:
            self.block = blocks.ResBlock
            layers_per_block = 2
        total_blocks = int((self.depth - 2) / layers_per_block)
        if total_blocks < 1:
            # Without this check block_ratio below divides by zero (or, for
            # negative depths, negative block counts are produced).
            raise ValueError(
                f"FLAGS.depth={self.depth} is too small: at least "
                f"{2 + layers_per_block} is needed for one residual block")

        # Split blocks evenly across groups; the last group takes the remainder.
        avg_blocks = total_blocks // self.groups
        group_blocks = [avg_blocks] * (self.groups - 1) + [
            total_blocks - avg_blocks * (self.groups - 1),
        ]

        # Per-block share of model_alpha; presumably the channel increment
        # used by _build_pyramid_group — confirm.
        self.block_ratio = self.model_alpha / total_blocks
        if FLAGS.amp:
            LOGGER.warning(
                'layers_per_block = 3 if FLAGS.bottleneck else 2, total_blocks=(FLAGS.depth-2)/layers_per_block, please set FLAGS.model_alpha=total_blocks*8 to make sure channels are equal to multiple of 8.'
            )
        self.block_counter = 0

        self.head = blocks.BasicBlock(self.in_filters,
                                      self.in_ks,
                                      strides=self.in_strides,
                                      preact=False,
                                      use_norm=True,
                                      use_act=self.in_act_pool)
        if self.in_act_pool:
            # tf.keras has no dimension-agnostic `AveragePooling` layer; pick
            # the 1D/2D/3D variant matching FLAGS.dim.
            pool_layers = {
                1: tf.keras.layers.AveragePooling1D,
                2: tf.keras.layers.AveragePooling2D,
                3: tf.keras.layers.AveragePooling3D,
            }
            if FLAGS.dim not in pool_layers:
                raise ValueError(
                    f'FLAGS.dim must be 1, 2 or 3, got {FLAGS.dim!r}')
            self.in_pool = pool_layers[FLAGS.dim](
                pool_size=(3, ) * FLAGS.dim, strides=2, padding=FLAGS.padding)

        self.all_groups = [
            self._build_pyramid_group(group_blocks[0], strides=1)
        ]
        for b in range(1, self.groups):
            self.all_groups.append(
                self._build_pyramid_group(group_blocks[b], strides=2))

        self.bn = layers.Norm()
        self.act = layers.Act()
        self.gpool = layers.GlobalPooling()
        self.flatten = tf.keras.layers.Flatten()
        self.dense = layers.Dense(self.out_dim, activation=None, use_bias=True)