Example #1
 def _build_decoder(self):
     layer_list = [
         blocks.BasicBlock(256,
                           ks=4,
                           strides=2,
                           transpose=True,
                           use_act=True),
         blocks.BasicBlock(64,
                           ks=4,
                           strides=2,
                           transpose=True,
                           use_act=True),
         blocks.BasicBlock(16,
                           ks=4,
                           strides=2,
                           transpose=True,
                           use_act=True),
         blocks.BasicBlock(3,
                           ks=4,
                           strides=2,
                           transpose=True,
                           use_act=False),
         layers.Act('Tanh')
     ]
     return self.sequential(layer_list)
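For orientation, the same four-stage decoder can be written with plain tf.keras layers. This is a minimal sketch that assumes blocks.BasicBlock with transpose=True wraps a strided transposed convolution plus normalization and activation; the wrapper itself is not shown, so treat the sketch as an illustration, not the actual implementation.

import tensorflow as tf

def build_decoder_sketch():
    # Four transposed-convolution stages (256 -> 64 -> 16 -> 3 filters,
    # kernel 4, stride 2), with the activation dropped on the last stage
    # and a final tanh, mirroring the layer_list above.
    model = tf.keras.Sequential()
    for filters, use_act in [(256, True), (64, True), (16, True), (3, False)]:
        model.add(tf.keras.layers.Conv2DTranspose(
            filters, kernel_size=4, strides=2, padding='same', use_bias=False))
        model.add(tf.keras.layers.BatchNormalization())
        if use_act:
            model.add(tf.keras.layers.ReLU())
    model.add(tf.keras.layers.Activation('tanh'))
    return model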
Example #2
    def __init__(self):
        super(AE, self).__init__()
        self.out_dim = FLAGS.out_dim
        self.filters_mode = FLAGS.filters_mode
        if self.filters_mode == 'small':  # cifar
            self.in_filters = 16
            self.in_ks = 3
            self.in_strides = 1
            self.groups = 3
            self.in_act_pool = False
        elif self.filters_mode == 'large':  # imagenet
            self.in_filters = 64
            self.in_ks = 7
            self.in_strides = 2
            self.groups = 4
            self.in_act_pool = True
        else:
            raise ValueError(f'unsupported filters_mode: {self.filters_mode}')

        self.latent_dim = FLAGS.latent_dim
        self.depth = FLAGS.depth
        self.down_block = blocks.ResBlock
        self.up_block = blocks.UpBlock

        self.head = blocks.BasicBlock(self.in_filters,
                                      self.in_ks,
                                      strides=self.in_strides,
                                      preact=False,
                                      use_norm=True,
                                      use_act=self.in_act_pool)
        if self.in_act_pool:
            # tf.keras has no generic `AveragePooling` layer; pick the class
            # that matches FLAGS.dim (AveragePooling1D/2D/3D).
            pool_cls = getattr(tf.keras.layers, 'AveragePooling%dD' % FLAGS.dim)
            self.in_pool = pool_cls(
                pool_size=(3, ) * FLAGS.dim, strides=2, padding=FLAGS.padding)

        encoder_blocks = []
        filters = self.in_filters
        for _ in range(self.groups):
            filters = filters * 2
            encoder_blocks.append(blocks.ResBlock(filters, strides=2))
        self.encoder = self.sequential(encoder_blocks)
        self.latent_encoder = blocks.BasicBlock(self.latent_dim,
                                                3,
                                                strides=1,
                                                preact=True,
                                                use_norm=True,
                                                use_act=False)
        decoder_blocks = []
        for _ in range(self.groups):
            filters = filters // 2
            decoder_blocks.append(
                blocks.UpBlock(filters, with_concat=False, up_size=2))
        self.decoder = self.sequential(decoder_blocks)
        self.tail = blocks.BasicBlock(self.out_dim,
                                      3,
                                      strides=1,
                                      preact=True,
                                      use_norm=True,
                                      use_act=True)
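The encoder above doubles the channel count once per group while the decoder halves it again, so the two halves mirror each other around the latent bottleneck. A quick sketch of that arithmetic for the 'small' (CIFAR) setting (in_filters=16, groups=3):

in_filters, groups = 16, 3

encoder_filters, f = [], in_filters
for _ in range(groups):      # encoder: double filters while downsampling by 2
    f *= 2
    encoder_filters.append(f)
print(encoder_filters)       # [32, 64, 128]

decoder_filters = []
for _ in range(groups):      # decoder: halve filters while upsampling by 2
    f //= 2
    decoder_filters.append(f)
print(decoder_filters)       # [64, 32, 16]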
Example #3
 def __init__(self,
              filters,
              strides=1,
              preact=True,
              last_norm=True,
              pool=False,
              shortcut_type='project'):
     super(ResBlock, self).__init__(filters, strides=strides)
     self.strides = strides
     self.last_norm = last_norm
     self.preact = preact
     self.shortcut_type = shortcut_type
     self.pool = pool
     if self.pool:
         self.pool_lay = layers.Pooling(pool_size=strides)
     if self.shortcut_type == 'project' and self.preact:  # such as preact-resnet
         self.bn1 = layers.Norm()
         self.act1 = layers.Act()
         self.bk1 = blocks.BasicBlock(filters,
                                      3,
                                      strides=strides if not pool else 1,
                                      preact=False,
                                      use_norm=False,
                                      use_act=False)
     elif self.shortcut_type == 'pad' and self.preact:  # such as pyramidnet
         self.bk1 = blocks.BasicBlock(filters,
                                      3,
                                      strides=strides if not pool else 1,
                                      preact=preact,
                                      use_norm=True,
                                      use_act=False)  # no act
     else:  # resnet or preact_resnet
         self.bk1 = blocks.BasicBlock(filters,
                                      3,
                                      strides=strides if not pool else 1,
                                      preact=preact,
                                      use_norm=True,
                                      use_act=True)
     self.bk2 = blocks.BasicBlock(filters,
                                  3,
                                  strides=1,
                                  preact=preact,
                                  use_norm=True,
                                  use_act=True,
                                  last_norm=last_norm)
     if self.shortcut_type == 'pad':
         if pool:
             raise ValueError
         self.shortcut = layers.ShortcutPoolingPadding(pool_size=strides)
     elif self.shortcut_type == 'project':
         self.shortcut = blocks.BasicBlock(
             filters,
             1,
             strides=strides if not pool else 1,
             preact=preact,
             use_norm=False if preact else True,
             use_act=False if preact else True,
             last_norm=False)
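As a point of reference, here is a minimal 2-D sketch of the pre-activation case with a projection shortcut (shortcut_type='project' and preact=True above), written with standard tf.keras layers instead of the repo's blocks/layers wrappers. It illustrates the pattern only and is not the actual implementation.

import tensorflow as tf

class PreactResBlock2D(tf.keras.layers.Layer):
    # Pre-activation residual block: BN -> ReLU precedes each convolution,
    # and the 1x1 projection shortcut is taken from the pre-activated input.
    def __init__(self, filters, strides=1):
        super().__init__()
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv1 = tf.keras.layers.Conv2D(filters, 3, strides=strides,
                                            padding='same', use_bias=False)
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(filters, 3, padding='same',
                                            use_bias=False)
        self.shortcut = tf.keras.layers.Conv2D(filters, 1, strides=strides,
                                               padding='same', use_bias=False)

    def call(self, x, training=False):
        h = tf.nn.relu(self.bn1(x, training=training))
        shortcut = self.shortcut(h)
        h = self.conv1(h)
        h = self.conv2(tf.nn.relu(self.bn2(h, training=training)))
        return h + shortcut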
Example #4
 def get_embedding_module(self):
     return self.sequential([
         blocks.BasicBlock(self.n_channels, 3),
         layers.Pooling(2),
         blocks.BasicBlock(self.n_channels, 3),
         layers.Pooling(2),
         blocks.BasicBlock(self.n_channels, 3),
         blocks.BasicBlock(self.n_channels, 3)
     ])
Example #5
 def get_relation_module(self):
     return self.sequential([
         blocks.BasicBlock(self.n_channels, 3),
         layers.Pooling(2),
         blocks.BasicBlock(self.n_channels, 3),
         layers.Pooling(2),
         tf.keras.layers.Flatten(),
         layers.Dense(8),
         tf.keras.layers.ReLU(),
         layers.Dense(1)
     ])
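Examples #4 and #5 are the two halves of a Relation Network style model: the embedding module produces feature maps and the relation module scores a pair of them with a single logit. How the pairs are formed is not shown in this listing, so the following composition, which pairs every query with every support embedding along the channel axis, is a hypothetical sketch.

import tensorflow as tf

def relation_scores_sketch(embedding_module, relation_module,
                           support_images, query_images):
    # Embed both sets, then pair every query with every support example by
    # concatenating their feature maps along the channel axis; the relation
    # module returns one score per pair.
    support_emb = embedding_module(support_images)   # (n_s, h, w, c)
    query_emb = embedding_module(query_images)       # (n_q, h, w, c)
    n_s = tf.shape(support_emb)[0]
    n_q = tf.shape(query_emb)[0]

    q = tf.repeat(query_emb, repeats=n_s, axis=0)    # (n_q * n_s, h, w, c)
    s = tf.tile(support_emb, [n_q, 1, 1, 1])         # (n_q * n_s, h, w, c)
    pairs = tf.concat([q, s], axis=-1)               # (n_q * n_s, h, w, 2c)

    scores = relation_module(pairs)                  # (n_q * n_s, 1)
    return tf.reshape(scores, [n_q, n_s])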
Example #6
 def _build_append_blocks(self):
     layer_list = [
         blocks.BasicBlock(64,
                           3,
                           strides=1,
                           preact=False,
                           use_norm=True,
                           use_act=False),
         blocks.BasicBlock(32,
                           3,
                           strides=1,
                           preact=False,
                           use_norm=True,
                           use_act=False),
         tf.keras.layers.Flatten(),
     ]
     return self.sequential(layer_list)
Example #7
 def _input_blocks(self):
     all_blocks = [
         blocks.BasicBlock(self.init_filters,
                           3,
                           strides=1,
                           preact=False,
                           use_norm=True,
                           use_act=True),
         blocks.ResBlock(self.init_filters)
     ]
     return self.sequential(all_blocks)
Example #8
 def __init__(self, n_filters, with_concat, up_size=2):
     super(UpBlock, self).__init__(n_filters)
     self.with_concat = with_concat
     self.upsampling = layers.UpSampling(up_size=up_size)
     if self.with_concat:
         self.concat = tf.keras.layers.Concatenate()
     self.project = blocks.BasicBlock(n_filters,
                                      3,
                                      use_norm=False,
                                      use_act=False)
     self.res_block = blocks.ResBlock(n_filters)
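A standalone sketch of the same up-sampling block in plain tf.keras (2-D case), with the optional skip-connection concatenation made explicit. The lightweight residual refinement at the end is a simplified stand-in for blocks.ResBlock, not a reproduction of it.

import tensorflow as tf

class UpBlockSketch(tf.keras.layers.Layer):
    # Upsample, optionally concatenate an encoder skip feature, project with
    # a 3x3 convolution, then apply a lightweight residual refinement.
    def __init__(self, n_filters, with_concat=True, up_size=2):
        super().__init__()
        self.with_concat = with_concat
        self.upsampling = tf.keras.layers.UpSampling2D(size=up_size)
        if with_concat:
            self.concat = tf.keras.layers.Concatenate()
        self.project = tf.keras.layers.Conv2D(n_filters, 3, padding='same')
        self.conv1 = tf.keras.layers.Conv2D(n_filters, 3, padding='same',
                                            activation='relu')
        self.conv2 = tf.keras.layers.Conv2D(n_filters, 3, padding='same')

    def call(self, x, skip=None):
        x = self.upsampling(x)
        if self.with_concat and skip is not None:
            x = self.concat([x, skip])
        x = self.project(x)
        return x + self.conv2(self.conv1(x))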
Example #9
    def __init__(self,
                 num_blocks,
                 feature_mode=False,
                 preact=False,
                 last_norm=False):
        """ By default, preact=False, last_norm=False means vanilla resnet.
        """
        super(ResNet, self).__init__()
        self.out_dim = FLAGS.out_dim
        self.depth = FLAGS.depth
        self.preact = preact
        self.last_norm = last_norm
        self.feature_mode = feature_mode
        if FLAGS.bottleneck:
            self.block = blocks.ResBottleneck
        else:
            self.block = blocks.ResBlock

        self.filters_mode = FLAGS.filters_mode
        if self.filters_mode == 'small':  # cifar
            self.in_ks = 3
            self.in_strides = 1
            self.in_act_pool = False
        elif self.filters_mode == 'large':  # imagenet
            self.in_ks = 7
            self.in_strides = 2
            self.in_act_pool = True
        else:
            raise ValueError(f'unsupported filters_mode: {self.filters_mode}')

        self.in_filters = FLAGS.in_filters if FLAGS.in_filters else 64
        self.head = blocks.BasicBlock(self.in_filters,
                                      self.in_ks,
                                      strides=self.in_strides,
                                      preact=False,
                                      use_norm=True,
                                      use_act=self.in_act_pool)

        if self.in_act_pool:
            # tf.keras has no generic `AveragePooling` layer; pick the class
            # that matches FLAGS.dim (AveragePooling1D/2D/3D).
            pool_cls = getattr(tf.keras.layers, 'AveragePooling%dD' % FLAGS.dim)
            self.in_pool = pool_cls(
                pool_size=(3, ) * FLAGS.dim, strides=2, padding=FLAGS.padding)

        self.all_groups = [
            self._build_group(self.in_filters, num_blocks[0], strides=1)
        ]
        for i in range(1, len(num_blocks)):
            self.all_groups.append(
                self._build_group(self.in_filters * (2**i),
                                  num_blocks[i],
                                  strides=2))
        if not self.feature_mode:
            self.tail = ResNetTail(self.preact, self.last_norm)
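`_build_group` is not part of this listing; the usual ResNet convention, sketched below as an assumption, is that only the first block of a group applies the requested stride while the remaining blocks keep stride 1.

    def _build_group(self, filters, num_blocks, strides):
        # Hypothetical sketch: downsample (if strides > 1) in the first block
        # of the group only; later blocks keep the spatial resolution.
        group = [self.block(filters, strides=strides)]
        group.extend(self.block(filters, strides=1)
                     for _ in range(num_blocks - 1))
        return self.sequential(group)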
Example #10
 def __init__(self, init_filters=32, depth=4):
     super(UNet, self).__init__()
     self.out_dim = FLAGS.out_dim
     self.init_filters = init_filters
     self.depth = depth
     self.head = self._input_blocks()
     self.down_group = self._get_down_blocks()
     self.up_group = self._get_up_blocks()
     self.out = blocks.BasicBlock(self.out_dim,
                                  1,
                                  use_norm=False,
                                  use_bias=True,
                                  use_act=False)
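Neither the down/up builders nor call() appear in this listing, so the forward pass below is only a hypothetical sketch of how such a U-Net usually wires its skip connections: one cached feature map per depth level, consumed in reverse order by the up blocks. It assumes down_group and up_group are plain lists of blocks and that each up block accepts a skip tensor.

 def call(self, x, training=False):
     # Hypothetical sketch: cache one feature map per encoder level and pass
     # it to the matching up block as a skip connection.
     x = self.head(x)
     skips = []
     for down in self.down_group:
         skips.append(x)
         x = down(x)
     for up, skip in zip(self.up_group, reversed(skips)):
         x = up(x, skip)
     return self.out(x)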
Example #11
    def __init__(self, preact=False, last_norm=False):
        """ By default, preact=False, last_norm=False means vanilla resnet.
        """
        super(HResResNetTail, self).__init__()
        self.out_dim = FLAGS.out_dim
        self.preact = preact
        self.last_norm = last_norm

        self.tail = blocks.BasicBlock(FLAGS.out_dim,
                                      1,
                                      strides=1,
                                      preact=self.preact,
                                      use_norm=self.preact,
                                      use_act=self.preact,
                                      use_bias=True)
Example #12
    def __init__(self,
                 num_blocks,
                 feature_mode=False,
                 preact=False,
                 last_norm=False):
        """ By default, preact=False, last_norm=False means vanilla resnet.
        """
        super(HResResNet, self).__init__()
        self.out_dim = FLAGS.out_dim
        self.depth = FLAGS.depth
        self.preact = preact
        self.last_norm = last_norm
        self.feature_mode = feature_mode
        if FLAGS.bottleneck:
            self.block = blocks.ResBottleneck
        else:
            self.block = blocks.ResBlock

        self.filters_mode = FLAGS.filters_mode
        if self.filters_mode == 'small':  # cifar
            self.in_ks = 3
        elif self.filters_mode == 'large':  # imagenet
            self.in_ks = 7
        else:
            raise ValueError(f'unsupported filters_mode: {self.filters_mode}')

        self.in_filters = FLAGS.in_filters if FLAGS.in_filters else 64
        self.head = blocks.BasicBlock(self.in_filters,
                                      self.in_ks,
                                      strides=1,
                                      preact=False,
                                      use_norm=True,
                                      use_act=True)

        self.all_groups = [
            self._build_group(self.in_filters, num_blocks[0], strides=1)
        ]
        for i in range(1, len(num_blocks)):
            self.all_groups.append(
                self._build_group(self.in_filters * (2**i),
                                  num_blocks[i],
                                  strides=1))
        if not self.feature_mode:
            self.tail = HResResNetTail(self.preact, self.last_norm)
Example #13
    def __init__(self):
        super(PyramidNet, self).__init__()
        self.out_dim = FLAGS.out_dim
        self.filters_mode = FLAGS.filters_mode
        if self.filters_mode == 'small':  # cifar
            self.in_filters = 16
            self.in_ks = 3
            self.in_strides = 1
            self.groups = 3
            self.in_act_pool = False
        elif self.filters_mode == 'large':  # imagenet
            self.in_filters = 64
            self.in_ks = 7
            self.in_strides = 2
            self.groups = 4
            self.in_act_pool = True
        else:
            raise ValueError(f'unsupported filters_mode: {self.filters_mode}')
        self.depth = FLAGS.depth
        self.model_alpha = FLAGS.model_alpha
        if FLAGS.bottleneck:
            self.block = blocks.ResBottleneck
            total_blocks = int((self.depth - 2) / 3)
        else:
            self.block = blocks.ResBlock
            total_blocks = int((self.depth - 2) / 2)

        avg_blocks = int(total_blocks / self.groups)
        group_blocks = [avg_blocks for _ in range(self.groups - 1)] + [
            total_blocks - avg_blocks * (self.groups - 1),
        ]

        self.block_ratio = self.model_alpha / total_blocks
        if FLAGS.amp:
            LOGGER.warning(
                'layers_per_block = 3 if FLAGS.bottleneck else 2 and '
                'total_blocks = (FLAGS.depth - 2) / layers_per_block; set '
                'FLAGS.model_alpha = total_blocks * 8 so that the channel '
                'counts stay multiples of 8.')
        self.block_counter = 0

        self.head = blocks.BasicBlock(self.in_filters,
                                      self.in_ks,
                                      strides=self.in_strides,
                                      preact=False,
                                      use_norm=True,
                                      use_act=self.in_act_pool)
        if self.in_act_pool:
            # tf.keras has no generic `AveragePooling` layer; pick the class
            # that matches FLAGS.dim (AveragePooling1D/2D/3D).
            pool_cls = getattr(tf.keras.layers, 'AveragePooling%dD' % FLAGS.dim)
            self.in_pool = pool_cls(
                pool_size=(3, ) * FLAGS.dim, strides=2, padding=FLAGS.padding)

        self.all_groups = [
            self._build_pyramid_group(group_blocks[0], strides=1)
        ]
        for b in range(1, self.groups):
            self.all_groups.append(
                self._build_pyramid_group(group_blocks[b], strides=2))

        self.bn = layers.Norm()
        self.act = layers.Act()
        self.gpool = layers.GlobalPooling()
        self.flatten = tf.keras.layers.Flatten()
        self.dense = layers.Dense(self.out_dim, activation=None, use_bias=True)
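`_build_pyramid_group` is not shown either; the PyramidNet rule it presumably implements (stated here as an assumption) is additive widening, where every residual block adds block_ratio = model_alpha / total_blocks channels. A small standalone sketch of that schedule:

def pyramid_filter_schedule(in_filters, model_alpha, total_blocks):
    # Additive channel growth: block k gets roughly
    # in_filters + k * (model_alpha / total_blocks) channels.
    block_ratio = model_alpha / total_blocks
    return [int(in_filters + block_ratio * k)
            for k in range(1, total_blocks + 1)]

# e.g. depth=110 without bottleneck -> total_blocks = (110 - 2) // 2 = 54;
# with in_filters=16 and model_alpha=48 the widths grow from 16 toward 64.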