def _parse_normalization_kwargs(self, use_batch_norm, batch_norm_config,
                                normalization_ctor, normalization_kwargs):
  """Sets up normalization, checking old and new flags."""
  if use_batch_norm is None:
    # Deprecated kwarg untouched; this is the only path that will survive
    # once the deprecation cycle for `use_batch_norm` completes.
    self._check_and_assign_normalization_members(
        normalization_ctor, normalization_kwargs or {})
    return

  # Legacy path below: delete once `use_batch_norm` is removed.
  util.deprecation_warning(
      "`use_batch_norm` kwarg is deprecated. Change your code to instead "
      "specify `normalization_ctor` and `normalization_kwargs`.")

  if use_batch_norm:
    # Explicitly set to True - the new-style kwargs must not also be given.
    if normalization_ctor is not None or normalization_kwargs is not None:
      raise ValueError(
          "if use_batch_norm is specified, normalization_ctor and "
          "normalization_kwargs must not be.")
    self._check_and_assign_normalization_members(
        batch_norm.BatchNorm, batch_norm_config or {})
  else:
    # Explicitly set to False - normalization_{ctor,kwargs} take precedence.
    self._check_and_assign_normalization_members(
        normalization_ctor, normalization_kwargs or {})
def _parse_normalization_kwargs(self, use_batch_norm, batch_norm_config,
                                normalization_ctor, normalization_kwargs):
  """Sets up normalization, checking old and new flags."""
  # Default selection: the new-style kwargs (also the final behaviour once
  # the deprecated `use_batch_norm` kwarg is removed).
  chosen_ctor = normalization_ctor
  chosen_kwargs = normalization_kwargs or {}

  if use_batch_norm is not None:
    # Deprecated path - drop this whole branch when deprecation is done.
    util.deprecation_warning(
        "`use_batch_norm` kwarg is deprecated. Change your code to instead "
        "specify `normalization_ctor` and `normalization_kwargs`.")
    if use_batch_norm:
      # Old kwarg explicitly True: mutually exclusive with the new kwargs.
      if normalization_ctor is not None or normalization_kwargs is not None:
        raise ValueError(
            "if use_batch_norm is specified, normalization_ctor and "
            "normalization_kwargs must not be.")
      chosen_ctor = batch_norm.BatchNorm
      chosen_kwargs = batch_norm_config or {}
    # Explicit False falls through: the new kwargs keep precedence.

  self._check_and_assign_normalization_members(chosen_ctor, chosen_kwargs)
def __init__(self,
             output_channels,
             kernel_shapes,
             strides,
             paddings,
             rates=(1,),
             activation=tf.nn.relu,
             activate_final=False,
             normalization_ctor=None,
             normalization_kwargs=None,
             normalize_final=None,
             initializers=None,
             partitioners=None,
             regularizers=None,
             use_batch_norm=None,  # Deprecated.
             use_bias=True,
             batch_norm_config=None,  # Deprecated.
             data_format=DATA_FORMAT_NHWC,
             custom_getter=None,
             name="conv_net_2d"):
  """Constructs a `ConvNet2D` module.

  By default, neither batch normalization nor activation are applied to the
  output of the final layer.

  Args:
    output_channels: Iterable of output channels, as defined in
      `conv.Conv2D`. Output channels can be defined either as number or via
      a callable. In the latter case, since the function invocation is
      deferred to graph construction time, the user must only ensure that
      entries can be called when build is called. Each entry in the iterable
      defines properties in the corresponding convolutional layer.
    kernel_shapes: Iterable of kernel sizes as defined in `conv.Conv2D`; if
      the list contains one element only, the same kernel shape is used in
      each layer of the network.
    strides: Iterable of kernel strides as defined in `conv.Conv2D`; if the
      list contains one element only, the same stride is used in each layer
      of the network.
    paddings: Iterable of padding options as defined in `conv.Conv2D`. Each
      can be `snt.SAME`, `snt.VALID`, `snt.FULL`, `snt.CAUSAL`,
      `snt.REVERSE_CAUSAL` or a pair of these to use for height and width.
      If the Iterable contains one element only, the same padding is used in
      each layer of the network.
    rates: Iterable of dilation rates as defined in `conv.Conv2D`; if the
      list contains one element only, the same rate is used in each layer of
      the network.
    activation: An activation op.
    activate_final: Boolean determining if the activation and batch
      normalization, if turned on, are applied to the final layer.
    normalization_ctor: Constructor to return a callable which will perform
      normalization at each layer. Defaults to None / no normalization.
      Examples of what could go here: `snt.BatchNormV2`, `snt.LayerNorm`. If
      a string is provided, importlib is used to convert the string to a
      callable, so either `snt.LayerNorm` or `"snt.LayerNorm"` can be
      provided.
    normalization_kwargs: kwargs to be provided to `normalization_ctor` when
      it is called.
    normalize_final: Whether to apply normalization after the final conv
      layer. Default is to take the value of activate_final.
    initializers: Optional dict containing ops to initialize the filters of
      the whole network (with key 'w') or biases (with key 'b').
    partitioners: Optional dict containing partitioners to partition
      weights (with key 'w') or biases (with key 'b'). As a default, no
      partitioners are used.
    regularizers: Optional dict containing regularizers for the filters of
      the whole network (with key 'w') or biases (with key 'b'). As a
      default, no regularizers are used. A regularizer should be a function
      that takes a single `Tensor` as an input and returns a scalar `Tensor`
      output, e.g. the L1 and L2 regularizers in `tf.contrib.layers`.
    use_batch_norm: Boolean determining if batch normalization is applied
      after convolution. Deprecated, use `normalization_ctor` instead.
    use_bias: Boolean or iterable of booleans determining whether to include
      bias parameters in the convolutional layers. Default `True`.
    batch_norm_config: Optional mapping of additional configuration for the
      `snt.BatchNorm` modules. Deprecated, use `normalization_kwargs`
      instead.
    data_format: A string, one of "NCHW" or "NHWC". Specifies whether the
      channel dimension of the input and output is the last dimension
      (default, "NHWC"), or the second dimension ("NCHW").
    custom_getter: Callable or dictionary of callables to use as custom
      getters inside the module. If a dictionary, the keys correspond to
      regexes to match variable names. See the `tf.get_variable`
      documentation for information about the custom_getter API.
    name: Name of the module.

  Raises:
    TypeError: If `output_channels` is not iterable; or if `kernel_shapes`
      is not iterable; or `strides` is not iterable; or `paddings` is not
      iterable; or if `activation` is not callable.
    ValueError: If `output_channels` is empty; or if `kernel_shapes` has not
      length 1 or `len(output_channels)`; or if `strides` has not length 1
      or `len(output_channels)`; or if `paddings` has not length 1 or
      `len(output_channels)`; or if `rates` has not length 1 or
      `len(output_channels)`; or if the given data_format is not a supported
      format ("NHWC" or "NCHW"); or if `normalization_ctor` is provided but
      cannot be mapped to a callable.
    KeyError: If `initializers`, `partitioners` or `regularizers` contain
      any keys other than 'w' or 'b'.
    TypeError: If any of the given initializers, partitioners or
      regularizers are not callable.
  """
  # `collections.Iterable` is a deprecated alias that was removed in
  # Python 3.10; bind the canonical ABC locally so the isinstance checks
  # keep working on every supported interpreter version.
  from collections import abc as _collections_abc

  def _validated_tuple(value, arg_name):
    # Reject non-iterable kwargs early and freeze the rest as tuples.
    if not isinstance(value, _collections_abc.Iterable):
      raise TypeError("{} must be iterable".format(arg_name))
    return tuple(value)

  output_channels = _validated_tuple(output_channels, "output_channels")
  kernel_shapes = _validated_tuple(kernel_shapes, "kernel_shapes")
  strides = _validated_tuple(strides, "strides")
  paddings = _validated_tuple(paddings, "paddings")
  rates = _validated_tuple(rates, "rates")

  if isinstance(use_batch_norm, _collections_abc.Iterable):
    raise TypeError("use_batch_norm must be a boolean. Per-layer use of "
                    "batch normalization is not supported. Previously, a "
                    "test erroneously suggested use_batch_norm can be an "
                    "iterable of booleans.")

  super(ConvNet2D, self).__init__(name=name, custom_getter=custom_getter)

  if not output_channels:
    raise ValueError("output_channels must not be empty")
  self._output_channels = tuple(output_channels)
  self._num_layers = len(self._output_channels)

  # Populated lazily on first build/connect.
  self._input_shape = None

  if data_format not in SUPPORTED_2D_DATA_FORMATS:
    raise ValueError("Invalid data_format {}. Allowed formats "
                     "{}".format(data_format, SUPPORTED_2D_DATA_FORMATS))
  self._data_format = data_format

  self._initializers = util.check_initializers(
      initializers, self.POSSIBLE_INITIALIZER_KEYS)
  self._partitioners = util.check_partitioners(
      partitioners, self.POSSIBLE_INITIALIZER_KEYS)
  self._regularizers = util.check_regularizers(
      regularizers, self.POSSIBLE_INITIALIZER_KEYS)

  if not callable(activation):
    raise TypeError("Input 'activation' must be callable")
  self._activation = activation
  self._activate_final = activate_final

  def _per_layer(values, arg_name):
    # Broadcast a length-1 kwarg across all layers, then verify the length.
    replicated = _replicate_elements(values, self._num_layers)
    if len(replicated) != self._num_layers:
      raise ValueError(
          "{} must be of length 1 or len(output_channels)".format(arg_name))
    return replicated

  self._kernel_shapes = _per_layer(kernel_shapes, "kernel_shapes")
  self._strides = _per_layer(strides, "strides")
  self._paddings = _per_layer(paddings, "paddings")
  self._rates = _per_layer(rates, "rates")

  # Resolves the deprecated use_batch_norm/batch_norm_config kwargs against
  # the new normalization_ctor/normalization_kwargs ones.
  self._parse_normalization_kwargs(
      use_batch_norm, batch_norm_config,
      normalization_ctor, normalization_kwargs)

  if normalize_final is None:
    util.deprecation_warning(
        "normalize_final is not specified, so using the value of "
        "activate_final = {}. Change your code to set this kwarg explicitly. "
        "In the future, normalize_final will default to True.".format(
            activate_final))
    self._normalize_final = activate_final
  else:
    # User has provided an override, so don't link to activate_final.
    self._normalize_final = normalize_final

  if isinstance(use_bias, bool):
    use_bias = (use_bias,)
  else:
    if not isinstance(use_bias, _collections_abc.Iterable):
      raise TypeError("use_bias must be either a bool or an iterable")
    use_bias = tuple(use_bias)
  self._use_bias = _replicate_elements(use_bias, self._num_layers)

  self._instantiate_layers()
def batch_norm_config(self):
  """Deprecated accessor for the normalization kwargs mapping."""
  message = ("The `.batch_norm_config` property is deprecated. Check "
             "`.normalization_kwargs` instead.")
  util.deprecation_warning(message)
  return self._normalization_kwargs
def use_batch_norm(self):
  """Deprecated: whether the stored normalization ctor is `BatchNorm`."""
  message = ("The `.use_batch_norm` property is deprecated. Check "
             "`.normalization_ctor` instead.")
  util.deprecation_warning(message)
  is_legacy_batch_norm = self._normalization_ctor == batch_norm.BatchNorm
  return is_legacy_batch_norm
def __init__(self,
             output_channels,
             kernel_shapes,
             strides,
             paddings,
             rates=(1,),
             activation=tf.nn.relu,
             activate_final=False,
             normalization_ctor=None,
             normalization_kwargs=None,
             normalize_final=None,
             initializers=None,
             partitioners=None,
             regularizers=None,
             use_batch_norm=None,  # Deprecated.
             use_bias=True,
             batch_norm_config=None,  # Deprecated.
             data_format=DATA_FORMAT_NHWC,
             custom_getter=None,
             name="conv_net_2d"):
  """Constructs a `ConvNet2D` module.

  By default, neither batch normalization nor activation are applied to the
  output of the final layer.

  Args:
    output_channels: Iterable of output channels, as defined in
      `conv.Conv2D`. Output channels can be defined either as number or via
      a callable. In the latter case, since the function invocation is
      deferred to graph construction time, the user must only ensure that
      entries can be called when build is called. Each entry in the iterable
      defines properties in the corresponding convolutional layer.
    kernel_shapes: Iterable of kernel sizes as defined in `conv.Conv2D`; if
      the list contains one element only, the same kernel shape is used in
      each layer of the network.
    strides: Iterable of kernel strides as defined in `conv.Conv2D`; if the
      list contains one element only, the same stride is used in each layer
      of the network.
    paddings: Iterable of padding options as defined in `conv.Conv2D`. Each
      can be `snt.SAME`, `snt.VALID`, `snt.FULL`, `snt.CAUSAL`,
      `snt.REVERSE_CAUSAL` or a pair of these to use for height and width.
      If the Iterable contains one element only, the same padding is used in
      each layer of the network.
    rates: Iterable of dilation rates as defined in `conv.Conv2D`; if the
      list contains one element only, the same rate is used in each layer of
      the network.
    activation: An activation op.
    activate_final: Boolean determining if the activation and batch
      normalization, if turned on, are applied to the final layer.
    normalization_ctor: Constructor to return a callable which will perform
      normalization at each layer. Defaults to None / no normalization.
      Examples of what could go here: `snt.BatchNormV2`, `snt.LayerNorm`. If
      a string is provided, importlib is used to convert the string to a
      callable, so either `snt.LayerNorm` or `"snt.LayerNorm"` can be
      provided.
    normalization_kwargs: kwargs to be provided to `normalization_ctor` when
      it is called.
    normalize_final: Whether to apply normalization after the final conv
      layer. Default is to take the value of activate_final.
    initializers: Optional dict containing ops to initialize the filters of
      the whole network (with key 'w') or biases (with key 'b').
    partitioners: Optional dict containing partitioners to partition
      weights (with key 'w') or biases (with key 'b'). As a default, no
      partitioners are used.
    regularizers: Optional dict containing regularizers for the filters of
      the whole network (with key 'w') or biases (with key 'b'). As a
      default, no regularizers are used. A regularizer should be a function
      that takes a single `Tensor` as an input and returns a scalar `Tensor`
      output, e.g. the L1 and L2 regularizers in `tf.contrib.layers`.
    use_batch_norm: Boolean determining if batch normalization is applied
      after convolution. Deprecated, use `normalization_ctor` instead.
    use_bias: Boolean or iterable of booleans determining whether to include
      bias parameters in the convolutional layers. Default `True`.
    batch_norm_config: Optional mapping of additional configuration for the
      `snt.BatchNorm` modules. Deprecated, use `normalization_kwargs`
      instead.
    data_format: A string, one of "NCHW" or "NHWC". Specifies whether the
      channel dimension of the input and output is the last dimension
      (default, "NHWC"), or the second dimension ("NCHW").
    custom_getter: Callable or dictionary of callables to use as custom
      getters inside the module. If a dictionary, the keys correspond to
      regexes to match variable names. See the `tf.get_variable`
      documentation for information about the custom_getter API.
    name: Name of the module.

  Raises:
    TypeError: If `output_channels` is not iterable; or if `kernel_shapes`
      is not iterable; or `strides` is not iterable; or `paddings` is not
      iterable; or if `activation` is not callable.
    ValueError: If `output_channels` is empty; or if `kernel_shapes` has not
      length 1 or `len(output_channels)`; or if `strides` has not length 1
      or `len(output_channels)`; or if `paddings` has not length 1 or
      `len(output_channels)`; or if `rates` has not length 1 or
      `len(output_channels)`; or if the given data_format is not a supported
      format ("NHWC" or "NCHW"); or if `normalization_ctor` is provided but
      cannot be mapped to a callable.
    KeyError: If `initializers`, `partitioners` or `regularizers` contain
      any keys other than 'w' or 'b'.
    TypeError: If any of the given initializers, partitioners or
      regularizers are not callable.
  """
  # Fix: `collections.Iterable` was removed in Python 3.10 (deprecated
  # alias since 3.3). Bind `collections.abc` locally so the checks below
  # work on all supported interpreter versions.
  from collections import abc as _abc

  if not isinstance(output_channels, _abc.Iterable):
    raise TypeError("output_channels must be iterable")
  output_channels = tuple(output_channels)

  if not isinstance(kernel_shapes, _abc.Iterable):
    raise TypeError("kernel_shapes must be iterable")
  kernel_shapes = tuple(kernel_shapes)

  if not isinstance(strides, _abc.Iterable):
    raise TypeError("strides must be iterable")
  strides = tuple(strides)

  if not isinstance(paddings, _abc.Iterable):
    raise TypeError("paddings must be iterable")
  paddings = tuple(paddings)

  if not isinstance(rates, _abc.Iterable):
    raise TypeError("rates must be iterable")
  rates = tuple(rates)

  if isinstance(use_batch_norm, _abc.Iterable):
    raise TypeError("use_batch_norm must be a boolean. Per-layer use of "
                    "batch normalization is not supported. Previously, a "
                    "test erroneously suggested use_batch_norm can be an "
                    "iterable of booleans.")

  super(ConvNet2D, self).__init__(name=name, custom_getter=custom_getter)

  if not output_channels:
    raise ValueError("output_channels must not be empty")
  self._output_channels = tuple(output_channels)
  self._num_layers = len(self._output_channels)

  # Populated lazily on first build/connect.
  self._input_shape = None

  if data_format not in SUPPORTED_2D_DATA_FORMATS:
    raise ValueError("Invalid data_format {}. Allowed formats "
                     "{}".format(data_format, SUPPORTED_2D_DATA_FORMATS))
  self._data_format = data_format

  self._initializers = util.check_initializers(
      initializers, self.POSSIBLE_INITIALIZER_KEYS)
  self._partitioners = util.check_partitioners(
      partitioners, self.POSSIBLE_INITIALIZER_KEYS)
  self._regularizers = util.check_regularizers(
      regularizers, self.POSSIBLE_INITIALIZER_KEYS)

  if not callable(activation):
    raise TypeError("Input 'activation' must be callable")
  self._activation = activation
  self._activate_final = activate_final

  # Length-1 kwargs are broadcast across all layers; anything else must
  # match len(output_channels) exactly.
  self._kernel_shapes = _replicate_elements(kernel_shapes, self._num_layers)
  if len(self._kernel_shapes) != self._num_layers:
    raise ValueError(
        "kernel_shapes must be of length 1 or len(output_channels)")

  self._strides = _replicate_elements(strides, self._num_layers)
  if len(self._strides) != self._num_layers:
    raise ValueError(
        "strides must be of length 1 or len(output_channels)")

  self._paddings = _replicate_elements(paddings, self._num_layers)
  if len(self._paddings) != self._num_layers:
    raise ValueError(
        "paddings must be of length 1 or len(output_channels)")

  self._rates = _replicate_elements(rates, self._num_layers)
  if len(self._rates) != self._num_layers:
    raise ValueError(
        "rates must be of length 1 or len(output_channels)")

  # Resolves the deprecated use_batch_norm/batch_norm_config kwargs against
  # normalization_ctor/normalization_kwargs.
  self._parse_normalization_kwargs(
      use_batch_norm, batch_norm_config,
      normalization_ctor, normalization_kwargs)

  if normalize_final is None:
    util.deprecation_warning(
        "normalize_final is not specified, so using the value of "
        "activate_final = {}. Change your code to set this kwarg explicitly. "
        "In the future, normalize_final will default to True.".format(
            activate_final))
    self._normalize_final = activate_final
  else:
    # User has provided an override, so don't link to activate_final.
    self._normalize_final = normalize_final

  if isinstance(use_bias, bool):
    use_bias = (use_bias,)
  else:
    if not isinstance(use_bias, _abc.Iterable):
      raise TypeError("use_bias must be either a bool or an iterable")
    use_bias = tuple(use_bias)
  self._use_bias = _replicate_elements(use_bias, self._num_layers)

  self._instantiate_layers()
def batch_norm_config(self):
  """Returns the stored normalization kwargs mapping.

  Deprecated: check `.normalization_kwargs` instead.
  """
  util.deprecation_warning(
      "The `.batch_norm_config` property is deprecated. Check "
      "`.normalization_kwargs` instead.")
  # Kept only for backwards compatibility with the old kwarg name.
  return self._normalization_kwargs
def use_batch_norm(self):
  """Returns True iff the stored normalization ctor is `BatchNorm`.

  Deprecated: check `.normalization_ctor` instead.
  """
  util.deprecation_warning(
      "The `.use_batch_norm` property is deprecated. Check "
      "`.normalization_ctor` instead.")
  # Compares the stored ctor against the legacy BatchNorm class.
  return self._normalization_ctor == batch_norm.BatchNorm