def __init__(self,
             units_list,
             out_size=None,
             layer_type='dense',
             norm=None,
             activation=None,
             kernel_initializer='glorot_uniform',
             name=None,
             out_dtype=None,
             out_gain=1,
             norm_after_activation=False,
             norm_kwargs={},
             **kwargs):
    super().__init__(name=name)
    layer_cls = layer_registry.get(layer_type)
    Layer = layer_registry.get('layer')
    logger.debug(f'{self.name} gain: {kwargs.get("gain", None)}')
    self._out_dtype = out_dtype
    if activation is None and (len(units_list) > 1 or (units_list and out_size)):
        logger.warning(
            f'MLP({name}) with units_list({units_list}) and '
            f'out_size({out_size}) has no activation.')
    self._layers = [
        Layer(u,
              layer_type=layer_cls,
              norm=norm,
              activation=activation,
              kernel_initializer=kernel_initializer,
              norm_after_activation=norm_after_activation,
              norm_kwargs=norm_kwargs,
              name=f'{name}/{layer_type}_{i}',
              **kwargs)
        for i, u in enumerate(units_list)
    ]
    if out_size:
        kwargs.pop('gain', None)
        logger.debug(f'{self.name} out gain: {out_gain}')
        kernel_initializer = get_initializer(kernel_initializer, gain=out_gain)
        self._layers.append(
            layer_cls(out_size,
                      kernel_initializer=kernel_initializer,
                      name=f'{name}/out',
                      **kwargs))
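# Hedged illustration (assumption: the constructor above belongs to an MLP
# module whose call() applies self._layers in order; the function name below
# is hypothetical). The plain-Keras equivalent of
# MLP([64, 64], out_size=4, activation='relu') would be roughly:
import tensorflow as tf
from tensorflow.keras import layers as _keras_layers

def _mlp_equivalent_sketch():
    # Two hidden Dense layers plus a linear output head, mirroring
    # units_list=[64, 64] and out_size=4; the real module additionally
    # applies a separate gain (out_gain) to the output initializer.
    return tf.keras.Sequential([
        _keras_layers.Dense(64, activation='relu'),
        _keras_layers.Dense(64, activation='relu'),
        _keras_layers.Dense(4),
    ])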
def __init__(self, layer_type, units, name='ds', min_std=.1, **kwargs):
    super().__init__(name=name)
    layer_cls = layer_registry.get(layer_type)
    self._deter_layer = layer_cls(units // 2, **kwargs,
                                  name=f'{self.scope_name}/deter')
    self._stoch_layer = layer_cls(units, **kwargs,
                                  name=f'{self.scope_name}/stoch')
    self._min_std = min_std
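# Hedged sketch (not the original call()): the deter/stoch naming plus
# min_std suggests a Dreamer-style head in which the stochastic projection
# of size `units` is split into a mean and a floored std, each matching the
# units // 2 deterministic path. Function name is hypothetical.
import tensorflow as tf

def _split_mean_std_sketch(stoch_out, min_std=.1):
    # stoch_out: (..., units) -> mean, std each of shape (..., units // 2);
    # softplus keeps std positive and min_std floors it away from zero.
    mean, std = tf.split(stoch_out, 2, axis=-1)
    std = tf.nn.softplus(std) + min_std
    return mean, std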
def __init__(self,
             *args,
             layer_type=layers.Dense,
             activation='sigmoid',
             kernel_initializer='glorot_uniform',
             name=None,
             **kwargs):
    super().__init__(name=name)
    if isinstance(layer_type, str):
        layer_type = layer_registry.get(layer_type)
    gain = kwargs.pop('gain', 1)
    kernel_initializer = get_initializer(kernel_initializer, gain=gain)
    self._layer = layer_type(
        *args, kernel_initializer=kernel_initializer, name=name, **kwargs)
    self.activation = get_activation(activation)
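# Hypothetical sketch of get_initializer, which the constructors here rely
# on: assumed to resolve an initializer name to a Keras initializer and to
# fold the gain into scale-sensitive families. The project's real helper
# may differ.
import tensorflow as tf

def _get_initializer_sketch(name, gain=1.0):
    if name == 'orthogonal':
        # Orthogonal takes the gain directly as its scaling factor.
        return tf.keras.initializers.Orthogonal(gain)
    # glorot_uniform and friends have no gain parameter; a VarianceScaling
    # initializer with scale=gain**2 would be one way to emulate it.
    return tf.keras.initializers.get(name)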
def build(self, input_shape):
    H, W, C = input_shape[1:]
    q_seqlen = kv_seqlen = H * W
    key_size, val_size = self._compute_sizes(C)
    self._key_size, self._val_size = key_size, val_size
    conv_cls = layer_registry.get(self._conv)
    prefix = f'{self.scope_name}/'
    self._q_conv = conv_cls(key_size, 1, **self._kwargs, name=prefix + 'q')
    self._k_conv = conv_cls(key_size, 1, **self._kwargs, name=prefix + 'k')
    self._v_conv = conv_cls(val_size, 1, **self._kwargs, name=prefix + 'v')
    if self._downsample_ratio > 1:
        self._k_downsample = layers.MaxPool2D(
            self._downsample_ratio, self._downsample_ratio,
            padding='same', name=prefix + 'k_pool')
        self._v_downsample = layers.MaxPool2D(
            self._downsample_ratio, self._downsample_ratio,
            padding='same', name=prefix + 'v_pool')
        kv_seqlen //= self._downsample_ratio**2
    self._q_reshape = layers.Reshape(
        (q_seqlen, key_size), name=prefix + 'q_reshape')
    self._k_reshape = layers.Reshape(
        (kv_seqlen, key_size), name=prefix + 'k_reshape')
    self._v_reshape = layers.Reshape(
        (kv_seqlen, val_size), name=prefix + 'v_reshape')
    self._att = Attention(prefix + 'attention')
    self._o_reshape = layers.Reshape(
        (H, W, val_size), name=prefix + 'o_reshape')
    self._o_conv = conv_cls(C, 1, **self._kwargs, name=prefix + 'o')
    norm_cls = get_norm(self._norm)
    self._norm_layer = norm_cls(**self._norm_kwargs,
                                name=prefix + f'{self._norm}')
    if self._use_rezero:
        self._rezero = tf.Variable(0., trainable=True, dtype=tf.float32,
                                   name=prefix + 'rezero')
    super().build(input_shape)
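# Hedged sketch of the scaled dot-product attention that the Attention
# module presumably applies to the reshaped tensors built above, with
# q: (B, q_seqlen, key_size), k: (B, kv_seqlen, key_size),
# v: (B, kv_seqlen, val_size). Not the project's actual Attention code.
import tensorflow as tf

def _scaled_dot_product_attention_sketch(q, k, v):
    scale = tf.cast(tf.shape(q)[-1], q.dtype) ** -0.5
    logits = tf.matmul(q, k, transpose_b=True) * scale  # (B, q_seqlen, kv_seqlen)
    weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(weights, v)                        # (B, q_seqlen, val_size)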
def __init__(self,
             *args,
             layer_type=layers.Dense,
             norm=None,
             activation=None,
             kernel_initializer='glorot_uniform',
             name=None,
             norm_after_activation=False,
             norm_kwargs={},
             **kwargs):
    super().__init__(name=name)
    if isinstance(layer_type, str):
        layer_type = layer_registry.get(layer_type)
    gain = kwargs.pop('gain', calculate_gain(activation))
    kernel_initializer = get_initializer(kernel_initializer, gain=gain)
    self._layer = layer_type(
        *args, kernel_initializer=kernel_initializer, name=name, **kwargs)
    self._norm = norm
    self._norm_cls = get_norm(norm)
    if self._norm:
        self._norm_layer = self._norm_cls(
            **norm_kwargs, name=f'{self.scope_name}/norm')
    self._norm_after_activation = norm_after_activation
    self.activation = get_activation(activation)
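# Hypothetical sketch of calculate_gain, assumed to mirror
# torch.nn.init.calculate_gain so the default gain matches the layer's
# activation; the project's real helper may differ.
import math

def _calculate_gain_sketch(activation):
    if activation in (None, 'linear', 'sigmoid'):
        return 1.0
    if activation == 'tanh':
        return 5.0 / 3.0
    if activation == 'relu':
        return math.sqrt(2.0)
    return 1.0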
def _strides_4d(strides):
    # Hypothetical name for this helper: it converts 2D (stride_h, stride_w)
    # strides to the 4D NHWC format expected by low-level TF conv/pool ops.
    assert isinstance(strides, (list, tuple)) and len(strides) == 2, strides
    return (1,) + tuple(strides) + (1,)


layer_registry.register('global_avgpool2d')(layers.GlobalAvgPool2D)
layer_registry.register('global_maxpool2d')(layers.GlobalMaxPool2D)
layer_registry.register('reshape')(layers.Reshape)
layer_registry.register('flatten')(layers.Flatten)
layer_registry.register('dense')(layers.Dense)
layer_registry.register('conv2d')(layers.Conv2D)
layer_registry.register('dwconv2d')(layers.DepthwiseConv2D)
layer_registry.register('depthwise_conv2d')(layers.DepthwiseConv2D)
layer_registry.register('maxpool2d')(layers.MaxPool2D)
layer_registry.register('avgpool2d')(layers.AvgPool2D)


if __name__ == '__main__':
    tf.random.set_seed(0)
    shape = (1, 2, 3)
    x = tf.random.normal(shape)
    # print('x', x[0, 0, :, 0])
    print(layer_registry.get_all())
    l = layer_registry.get('layer')(2, name='Layer')
    # Calling with the same input repeatedly checks that the layer is
    # deterministic; a fresh input should change the output.
    y = l(x)
    print(y)
    y = l(x)
    print(y)
    y = l(x)
    print(y)
    x = tf.random.normal(shape)
    y = l(x)
    print(y)
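# Minimal sketch of the Registry pattern that layer_registry and the other
# registries above appear to follow; the project's actual implementation
# may differ.
class _RegistrySketch:
    def __init__(self):
        self._mapping = {}

    def register(self, name):
        # Usable as a decorator or, as above, called directly on a class.
        def decorator(cls):
            self._mapping[name] = cls
            return cls
        return decorator

    def get(self, name):
        return self._mapping[name]

    def get_all(self):
        return dict(self._mapping)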
def build(self, input_shape):
    kwargs = self._kwargs.copy()
    out_filters = self._out_filters or input_shape[-1]
    filter_coefs = self._filter_coefs or [1 for _ in self._kernel_sizes]
    filters = [int(out_filters * fc) for fc in filter_coefs]
    if self._out_filters:
        filters[-1] = self._out_filters
    if isinstance(self._strides, int):
        strides = [1 for _ in self._kernel_sizes]
        strides[0] = self._strides  # TODO: strided in the beginning
    else:
        assert isinstance(self._strides, (list, tuple)), self._strides
        strides = self._strides
    am_kwargs = self._am_kwargs.copy()
    am_kwargs.update(kwargs)
    self._layers = []
    conv_cls = layer_registry.get(self._conv)
    self._norm_cls = get_norm(self._norm)
    act_cls = get_activation(self._activation, return_cls=True)
    subsample_cls = subsample_registry.get(self._subsample_type)
    prefix = f'{self.scope_name}/'
    assert len(filters) == len(self._kernel_sizes) == len(strides) <= 3, \
        (filters, self._kernel_sizes, strides)
    self._build_residual_branch(filters, self._kernel_sizes, strides, prefix,
                                subsample_cls, conv_cls, act_cls, kwargs)
    am_cls = am_registry.get(self._am)
    self._layers.append(am_cls(name=prefix + f'{self._am}', **am_kwargs))
    if self._skip:
        # Use the normalized strides list so list/tuple strides do not crash
        # the comparison; for int strides this matches self._strides > 1.
        if any(s > 1 for s in strides):
            self._subsample = [
                subsample_cls(name=prefix + f'identity_{self._subsample_type}',
                              **self._subsample_kwargs),
                conv_cls(filters[-1], 1, name=prefix + f'identity_{self._conv}'),
                self._norm_cls(**self._norm_kwargs,
                               name=prefix + f'identity_{self._norm}')
            ]
    if self._dropout_rate != 0:
        noise_shape = (None, 1, 1, 1)
        # Drop the entire residual branch with a certain probability,
        # https://arxiv.org/pdf/1603.09382.pdf
        # TODO: recalibrate the output at test time
        self._layers.append(
            layers.Dropout(self._dropout_rate, noise_shape,
                           name=prefix + 'dropout'))
    if self._use_rezero:
        self._rezero = tf.Variable(0., trainable=True, dtype=tf.float32,
                                   name=prefix + 'rezero')
    out_act_cls = get_activation(self._out_act, return_cls=True)
    self._out_act = out_act_cls(
        name=(prefix + self._out_act) if self._out_act else '')
    self._training_cls += [subsample_cls, am_cls]
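# Hedged sketch of the ReZero residual merge that the _rezero variable above
# implies (https://arxiv.org/abs/2003.04887); the block's real call() is not
# shown here and the names below are hypothetical.
import tensorflow as tf

def _rezero_merge_sketch(identity, residual, rezero, out_act=tf.identity):
    # rezero starts at 0, so the block is an identity mapping at
    # initialization and learns how much residual signal to admit.
    return out_act(identity + rezero * residual)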