def _build_decoder(self):
    """Return the decoder as a ``self.sequential`` stack.

    The stack is four ``BasicBlock``s with ``ks=4``, ``strides=2`` and
    ``transpose=True``, using 256, 64, 16 and 3 filters in that order. Every
    block except the last has ``use_act=True``. The stack ends with a
    ``layers.Act('Tanh')`` layer.
    """
    # (filters, use_act) for each transposed block, in construction order.
    transposed_specs = ((256, True), (64, True), (16, True), (3, False))
    decoder_layers = [
        blocks.BasicBlock(n_filters, ks=4, strides=2, transpose=True,
                          use_act=with_act)
        for n_filters, with_act in transposed_specs
    ]
    decoder_layers.append(layers.Act('Tanh'))
    return self.sequential(decoder_layers)
def __init__(self):
    """Build the autoencoder from ``FLAGS``.

    Reads ``FLAGS.out_dim``, ``FLAGS.filters_mode``, ``FLAGS.latent_dim`` and
    ``FLAGS.depth``. When the stem pooling is enabled it also reads
    ``FLAGS.dim`` and ``FLAGS.padding``.

    ``filters_mode`` selects the stem configuration:
      * ``'small'`` (cifar): 16 filters, kernel size 3, stride 1, 3 groups,
        no stem activation and no stem pooling.
      * ``'large'`` (imagenet): 64 filters, kernel size 7, stride 2,
        4 groups, stem activation plus average pooling (pool size 3,
        stride 2).

    The encoder is ``groups`` stride-2 ``ResBlock``s, each doubling the
    filter count, followed by a ``latent_dim``-filter ``BasicBlock``. The
    decoder is ``groups`` ``UpBlock``s (``up_size=2``, no concatenation),
    each halving the filter count. ``tail`` is an ``out_dim``-filter
    ``BasicBlock``.

    Raises:
        ValueError: if ``FLAGS.filters_mode`` is neither 'small' nor 'large'.
    """
    super(AE, self).__init__()
    self.out_dim = FLAGS.out_dim
    self.filters_mode = FLAGS.filters_mode
    if self.filters_mode == 'small':  # cifar
        self.in_filters = 16
        self.in_ks = 3
        self.in_strides = 1
        self.groups = 3
        self.in_act_pool = False
    elif self.filters_mode == 'large':  # imagenet
        self.in_filters = 64
        self.in_ks = 7
        self.in_strides = 2
        self.groups = 4
        self.in_act_pool = True
    else:
        raise ValueError(
            "unknown filters_mode {!r}; expected 'small' or 'large'".format(
                self.filters_mode))
    self.latent_dim = FLAGS.latent_dim
    self.depth = FLAGS.depth
    self.down_block = blocks.ResBlock
    self.up_block = blocks.UpBlock
    self.head = blocks.BasicBlock(self.in_filters,
                                  self.in_ks,
                                  strides=self.in_strides,
                                  preact=False,
                                  use_norm=True,
                                  use_act=self.in_act_pool)
    if self.in_act_pool:
        # tf.keras only offers dimension-specific average pooling layers
        # (AveragePooling1D/2D/3D); there is no generic ``AveragePooling``.
        # Pick the one that matches FLAGS.dim.
        pool_cls = getattr(tf.keras.layers,
                           'AveragePooling{}D'.format(FLAGS.dim))
        self.in_pool = pool_cls(pool_size=(3, ) * FLAGS.dim,
                                strides=2,
                                padding=FLAGS.padding)

    # Encoder: each group doubles the filters and uses stride 2.
    encoder_blocks = []
    filters = self.in_filters
    for _ in range(self.groups):
        filters = filters * 2
        encoder_blocks.append(self.down_block(filters, strides=2))
    self.encoder = self.sequential(encoder_blocks)
    self.latent_encoder = blocks.BasicBlock(self.latent_dim,
                                            3,
                                            strides=1,
                                            preact=True,
                                            use_norm=True,
                                            use_act=False)

    # Decoder: mirror of the encoder, halving the filters each group.
    decoder_blocks = []
    for _ in range(self.groups):
        filters = filters // 2
        decoder_blocks.append(
            self.up_block(filters, with_concat=False, up_size=2))
    self.decoder = self.sequential(decoder_blocks)
    self.tail = blocks.BasicBlock(self.out_dim,
                                  3,
                                  strides=1,
                                  preact=True,
                                  use_norm=True,
                                  use_act=True)
def __init__(self, filters, strides=1, preact=True, last_norm=True,
             pool=False, shortcut_type='project'):
    """Residual block: two 3-kernel ``BasicBlock``s plus a shortcut layer.

    Args:
        filters: filter count used by every ``BasicBlock`` in the block.
        strides: stride of the first ``BasicBlock`` and of the projection
            shortcut. Also used as the pool size of the pooling layers.
        preact: forwarded to the ``BasicBlock``s. Together with
            ``shortcut_type`` it selects how ``bk1`` is configured.
        last_norm: forwarded to ``bk2``.
        pool: when truthy, creates ``layers.Pooling(pool_size=strides)`` as
            ``pool_lay``, and ``bk1`` and the projection shortcut use
            stride 1 instead of ``strides``.
        shortcut_type: ``'project'`` builds a kernel-size-1 ``BasicBlock``
            shortcut, and ``'pad'`` builds
            ``layers.ShortcutPoolingPadding``. Any other value creates no
            ``shortcut`` attribute here.

    Raises:
        ValueError: if ``shortcut_type == 'pad'`` and ``pool`` is truthy.
    """
    super(ResBlock, self).__init__(filters, strides=strides)
    self.strides = strides
    self.last_norm = last_norm
    self.preact = preact
    self.shortcut_type = shortcut_type
    self.pool = pool
    if pool:
        self.pool_lay = layers.Pooling(pool_size=strides)

    # When pooling is enabled the strided layers fall back to stride 1.
    conv_strides = 1 if pool else strides
    uses_projection = shortcut_type == 'project'
    uses_padding = shortcut_type == 'pad'

    if preact and uses_projection:
        # Such as preact-resnet: separate norm/act, then a plain conv.
        self.bn1 = layers.Norm()
        self.act1 = layers.Act()
        self.bk1 = blocks.BasicBlock(filters, 3, strides=conv_strides,
                                     preact=False, use_norm=False,
                                     use_act=False)
    elif preact and uses_padding:
        # Such as pyramidnet: normalized, but no activation.
        self.bk1 = blocks.BasicBlock(filters, 3, strides=conv_strides,
                                     preact=preact, use_norm=True,
                                     use_act=False)
    else:
        # resnet or preact_resnet
        self.bk1 = blocks.BasicBlock(filters, 3, strides=conv_strides,
                                     preact=preact, use_norm=True,
                                     use_act=True)
    self.bk2 = blocks.BasicBlock(filters, 3, strides=1, preact=preact,
                                 use_norm=True, use_act=True,
                                 last_norm=last_norm)

    if uses_padding:
        if pool:
            raise ValueError
        self.shortcut = layers.ShortcutPoolingPadding(pool_size=strides)
    elif uses_projection:
        # Norm/act on the projection are only enabled for non-preact blocks.
        self.shortcut = blocks.BasicBlock(filters, 1, strides=conv_strides,
                                          preact=preact,
                                          use_norm=not preact,
                                          use_act=not preact,
                                          last_norm=False)
def get_embedding_module(self):
    """Return the embedding network as a ``self.sequential`` stack.

    The stack holds four kernel-size-3 ``BasicBlock``s with
    ``self.n_channels`` filters. A ``layers.Pooling(2)`` follows the first
    and second blocks.
    """
    stack = []
    for position in range(4):
        stack.append(blocks.BasicBlock(self.n_channels, 3))
        # Only the first two convolutions are followed by pooling.
        if position < 2:
            stack.append(layers.Pooling(2))
    return self.sequential(stack)
def get_relation_module(self):
    """Return the relation network as a ``self.sequential`` stack.

    The stack has two (``BasicBlock(self.n_channels, 3)``,
    ``layers.Pooling(2)``) pairs, then ``Flatten``, ``Dense(8)``, ``ReLU``
    and ``Dense(1)``.
    """
    feature_part = []
    for _ in range(2):
        feature_part += [blocks.BasicBlock(self.n_channels, 3),
                         layers.Pooling(2)]
    scoring_part = [
        tf.keras.layers.Flatten(),
        layers.Dense(8),
        tf.keras.layers.ReLU(),
        layers.Dense(1),
    ]
    return self.sequential(feature_part + scoring_part)
def _build_append_blocks(self):
    """Return a ``self.sequential`` stack of two conv blocks and a Flatten.

    The conv blocks have 64 and 32 filters, kernel size 3 and stride 1. They
    are normalized, with no activation and no pre-activation.
    """
    conv_layers = [
        blocks.BasicBlock(n_filters, 3, strides=1, preact=False,
                          use_norm=True, use_act=False)
        for n_filters in (64, 32)
    ]
    return self.sequential(conv_layers + [tf.keras.layers.Flatten()])
def _input_blocks(self):
    """Return the input stem: a 3-kernel ``BasicBlock`` then a ``ResBlock``.

    Both blocks use ``self.init_filters`` filters. The first block is
    normalized and activated, with stride 1 and no pre-activation.
    """
    width = self.init_filters
    stem_conv = blocks.BasicBlock(width, 3, strides=1, preact=False,
                                  use_norm=True, use_act=True)
    stem_residual = blocks.ResBlock(width)
    return self.sequential([stem_conv, stem_residual])
def __init__(self, n_filters, with_concat, up_size=2):
    """Upsampling block: an upsampling layer followed by a ``ResBlock``.

    Args:
        n_filters: filter count for the projection and residual blocks.
        with_concat: when truthy, also create a ``Concatenate`` layer and a
            3-kernel projection ``BasicBlock`` (no norm, no activation).
        up_size: forwarded to ``layers.UpSampling``.
    """
    super(UpBlock, self).__init__(n_filters)
    self.with_concat = with_concat
    self.upsampling = layers.UpSampling(up_size=up_size)
    if with_concat:
        self.concat = tf.keras.layers.Concatenate()
        # NOTE(review): confirm `call` only uses `self.project` when
        # `with_concat` is True.
        self.project = blocks.BasicBlock(n_filters, 3, use_norm=False,
                                         use_act=False)
    self.res_block = blocks.ResBlock(n_filters)
def __init__(self, num_blocks, feature_mode=False, preact=False,
             last_norm=False):
    """Build a ResNet whose stem and width come from ``FLAGS``.

    By default, preact=False, last_norm=False means vanilla resnet.

    Args:
        num_blocks: number of residual blocks per group. Group ``i`` uses
            ``in_filters * 2**i`` filters; group 0 has stride 1 and later
            groups stride 2.
        feature_mode: when truthy, no ``tail`` is created.
        preact: forwarded to ``ResNetTail``.
        last_norm: forwarded to ``ResNetTail``.

    Reads ``FLAGS.out_dim``, ``FLAGS.depth``, ``FLAGS.bottleneck``,
    ``FLAGS.filters_mode``, ``FLAGS.in_filters`` (falls back to 64 when
    falsy), and, for the 'large' stem, ``FLAGS.dim`` / ``FLAGS.padding``.

    Raises:
        ValueError: if ``FLAGS.filters_mode`` is neither 'small' nor 'large'.
    """
    super(ResNet, self).__init__()
    self.out_dim = FLAGS.out_dim
    self.depth = FLAGS.depth
    self.preact = preact
    self.last_norm = last_norm
    self.feature_mode = feature_mode
    if FLAGS.bottleneck:
        self.block = blocks.ResBottleneck
    else:
        self.block = blocks.ResBlock
    self.filters_mode = FLAGS.filters_mode
    if self.filters_mode == 'small':  # cifar
        self.in_ks = 3
        self.in_strides = 1
        self.in_act_pool = False
    elif self.filters_mode == 'large':  # imagenet
        self.in_ks = 7
        self.in_strides = 2
        self.in_act_pool = True
    else:
        raise ValueError(
            "unknown filters_mode {!r}; expected 'small' or 'large'".format(
                self.filters_mode))
    self.in_filters = FLAGS.in_filters if FLAGS.in_filters else 64
    self.head = blocks.BasicBlock(self.in_filters,
                                  self.in_ks,
                                  strides=self.in_strides,
                                  preact=False,
                                  use_norm=True,
                                  use_act=self.in_act_pool)
    if self.in_act_pool:
        # tf.keras has no generic ``AveragePooling`` layer; choose
        # AveragePooling1D/2D/3D according to FLAGS.dim.
        pool_cls = getattr(tf.keras.layers,
                           'AveragePooling{}D'.format(FLAGS.dim))
        self.in_pool = pool_cls(pool_size=(3, ) * FLAGS.dim,
                                strides=2,
                                padding=FLAGS.padding)
    self.all_groups = [
        self._build_group(self.in_filters, num_blocks[0], strides=1)
    ]
    for i in range(1, len(num_blocks)):
        # Each later group doubles the width and uses stride 2.
        self.all_groups.append(
            self._build_group(self.in_filters * (2**i),
                              num_blocks[i],
                              strides=2))
    if not self.feature_mode:
        self.tail = ResNetTail(self.preact, self.last_norm)
def __init__(self, init_filters=32, depth=4):
    """Build a U-Net: input stem, down path, up path and a 1-kernel output.

    Args:
        init_filters: stored as ``self.init_filters`` (read by
            ``_input_blocks``).
        depth: stored as ``self.depth``. Presumably consumed by
            ``_get_down_blocks`` / ``_get_up_blocks`` (TODO confirm).
    """
    super(UNet, self).__init__()
    self.out_dim = FLAGS.out_dim
    # Set before the builder methods below, which read these attributes.
    self.init_filters, self.depth = init_filters, depth
    self.head = self._input_blocks()
    self.down_group = self._get_down_blocks()
    self.up_group = self._get_up_blocks()
    # Output projection: biased, with no norm and no activation.
    self.out = blocks.BasicBlock(
        self.out_dim, 1, use_norm=False, use_bias=True, use_act=False)
def __init__(self, preact=False, last_norm=False):
    """Output head of HResResNet: one 1-kernel ``BasicBlock`` with bias.

    By default, preact=False, last_norm=False means vanilla resnet.

    Args:
        preact: also turns the tail block's norm and activation on or off.
        last_norm: stored as ``self.last_norm``.
    """
    super(HResResNetTail, self).__init__()
    self.out_dim = FLAGS.out_dim
    self.preact = preact
    self.last_norm = last_norm
    pre = self.preact
    self.tail = blocks.BasicBlock(self.out_dim, 1, strides=1, preact=pre,
                                  use_norm=pre, use_act=pre, use_bias=True)
def __init__(self, num_blocks, feature_mode=False, preact=False,
             last_norm=False):
    """Build a ResNet variant in which every group uses stride 1.

    By default, preact=False, last_norm=False means vanilla resnet.

    Args:
        num_blocks: number of blocks per group. Group ``i`` uses
            ``in_filters * 2**i`` filters.
        feature_mode: when truthy, no ``tail`` is created.
        preact: forwarded to ``HResResNetTail``.
        last_norm: forwarded to ``HResResNetTail``.

    Raises:
        ValueError: if ``FLAGS.filters_mode`` is neither 'small' nor 'large'.
    """
    super(HResResNet, self).__init__()
    self.out_dim = FLAGS.out_dim
    self.depth = FLAGS.depth
    self.preact = preact
    self.last_norm = last_norm
    self.feature_mode = feature_mode
    self.block = (blocks.ResBottleneck
                  if FLAGS.bottleneck else blocks.ResBlock)
    self.filters_mode = FLAGS.filters_mode
    mode = self.filters_mode
    if mode not in ('small', 'large'):
        raise ValueError
    # Kernel size 3 for cifar ('small'), 7 for imagenet ('large').
    self.in_ks = 3 if mode == 'small' else 7
    self.in_filters = FLAGS.in_filters or 64
    self.head = blocks.BasicBlock(self.in_filters, self.in_ks, strides=1,
                                  preact=False, use_norm=True, use_act=True)
    self.all_groups = [
        self._build_group(self.in_filters, num_blocks[0], strides=1)
    ]
    for idx, group_size in enumerate(num_blocks[1:], start=1):
        width = self.in_filters * (2**idx)
        self.all_groups.append(
            self._build_group(width, group_size, strides=1))
    if not self.feature_mode:
        self.tail = HResResNetTail(self.preact, self.last_norm)
def __init__(self):
    """Build a PyramidNet whose configuration comes from ``FLAGS``.

    Reads ``FLAGS.out_dim``, ``FLAGS.filters_mode``, ``FLAGS.depth``,
    ``FLAGS.model_alpha``, ``FLAGS.bottleneck`` and ``FLAGS.amp``. When the
    stem pooling is enabled it also reads ``FLAGS.dim`` and
    ``FLAGS.padding``.

    ``total_blocks`` is ``int((depth - 2) / 3)`` with bottleneck blocks and
    ``int((depth - 2) / 2)`` otherwise. It is split evenly over
    ``self.groups`` groups, with the last group absorbing the remainder.
    ``self.block_ratio = model_alpha / total_blocks`` (presumably the
    per-block channel increment consumed by ``_build_pyramid_group``;
    TODO confirm).

    Raises:
        ValueError: if ``FLAGS.filters_mode`` is neither 'small' nor 'large',
            or if ``FLAGS.depth`` yields fewer than one block.
    """
    super(PyramidNet, self).__init__()
    self.out_dim = FLAGS.out_dim
    self.filters_mode = FLAGS.filters_mode
    if self.filters_mode == 'small':  # cifar
        self.in_filters = 16
        self.in_ks = 3
        self.in_strides = 1
        self.groups = 3
        self.in_act_pool = False
    elif self.filters_mode == 'large':  # imagenet
        self.in_filters = 64
        self.in_ks = 7
        self.in_strides = 2
        self.groups = 4
        self.in_act_pool = True
    else:
        raise ValueError(
            "unknown filters_mode {!r}; expected 'small' or 'large'".format(
                self.filters_mode))
    self.depth = FLAGS.depth
    self.model_alpha = FLAGS.model_alpha
    if FLAGS.bottleneck:
        self.block = blocks.ResBottleneck
        total_blocks = int((self.depth - 2) / 3)
    else:
        self.block = blocks.ResBlock
        total_blocks = int((self.depth - 2) / 2)
    # Without this guard a too-small depth fails with a ZeroDivisionError in
    # block_ratio (or produces negative group sizes).
    if total_blocks < 1:
        raise ValueError(
            'FLAGS.depth={} gives total_blocks={}; need at least 1'.format(
                self.depth, total_blocks))
    avg_blocks = int(total_blocks / self.groups)
    # All groups get avg_blocks; the last one also takes the remainder.
    group_blocks = [avg_blocks for _ in range(self.groups - 1)] + [
        total_blocks - avg_blocks * (self.groups - 1),
    ]
    self.block_ratio = self.model_alpha / total_blocks
    if FLAGS.amp:
        # ``warn`` is a deprecated alias of ``warning``.
        LOGGER.warning(
            'layers_per_block = 3 if FLAGS.bottleneck else 2, '
            'total_blocks=(FLAGS.depth-2)/layers_per_block, '
            'please set FLAGS.model_alpha=total_blocks*8 to make sure '
            'channels are equal to multiple of 8.')
    self.block_counter = 0
    self.head = blocks.BasicBlock(self.in_filters,
                                  self.in_ks,
                                  strides=self.in_strides,
                                  preact=False,
                                  use_norm=True,
                                  use_act=self.in_act_pool)
    if self.in_act_pool:
        # tf.keras has no generic ``AveragePooling`` layer; choose
        # AveragePooling1D/2D/3D according to FLAGS.dim.
        pool_cls = getattr(tf.keras.layers,
                           'AveragePooling{}D'.format(FLAGS.dim))
        self.in_pool = pool_cls(pool_size=(3, ) * FLAGS.dim,
                                strides=2,
                                padding=FLAGS.padding)
    self.all_groups = [
        self._build_pyramid_group(group_blocks[0], strides=1)
    ]
    for b in range(1, self.groups):
        self.all_groups.append(
            self._build_pyramid_group(group_blocks[b], strides=2))
    self.bn = layers.Norm()
    self.act = layers.Act()
    self.gpool = layers.GlobalPooling()
    self.flatten = tf.keras.layers.Flatten()
    self.dense = layers.Dense(self.out_dim, activation=None, use_bias=True)