def __init__(self,
             return_sequences=False,
             return_state=False,
             go_backwards=False,
             stateful=False,
             time_major=False,
             **kwargs):
    """Initialize the layer without instantiating an RNN cell.

    The boolean flags are stored on the layer unchanged. One `InputSpec` of
    shape `(None, dim)` is created per entry of `self.cell.state_size`.

    Args:
      return_sequences: Stored as `self.return_sequences`.
      return_state: Stored as `self.return_state`.
      go_backwards: Stored as `self.go_backwards`.
      stateful: Stored as `self.stateful`.
      time_major: Stored as `self.time_major`.
      **kwargs: Forwarded to the base layer initializer.
    """
    # Deliberately skip RNN.__init__ (it would create an RNN cell instance)
    # and go straight to the base layer's initializer.
    super(RNN, self).__init__(**kwargs)  # pylint: disable=bad-super-call
    self.return_sequences = return_sequences
    self.return_state = return_state
    self.go_backwards = go_backwards
    self.stateful = stateful
    self.time_major = time_major
    self.supports_masking = False
    self.input_spec = [InputSpec(ndim=3)]

    # `state_size` may be a single size or a sequence of sizes.
    cell_state_size = self.cell.state_size
    if hasattr(cell_state_size, '__len__'):
        sizes = cell_state_size
    else:
        sizes = [cell_state_size]
    self.state_spec = [InputSpec(shape=(None, dim)) for dim in sizes]

    self.constants_spec = None
    self._states = None
    self._num_constants = 0
    self._vector_shape = tf.constant([-1])
def build(self, input_shape):
    """Build the wrapped cell and set or validate the input/state specs.

    Args:
      input_shape: Shape of the sequence input. When initial states and/or
        constants were passed to `__call__`, this is a list whose first
        entry is the sequence input shape and whose last
        `self._num_constants` entries are the constants' shapes.

    Raises:
      ValueError: If an `initial_state` was passed whose channel dimension
        does not match `cell.state_size`.
    """
    # Note input_shape will be list of shapes of initial states and
    # constants if these are passed in __call__.
    # A count of 0 must mean "no constants": `input_shape[-0:]` would return
    # the *entire* list (including the sequence input shape).
    if self._num_constants:
        constants_shape = input_shape[
            -self._num_constants :
        ]  # pylint: disable=invalid-unary-operand-type
    else:
        constants_shape = None

    if isinstance(input_shape, list):
        input_shape = input_shape[0]

    batch_size = input_shape[0] if self.stateful else None
    # Batch (only fixed when stateful) and time are left unconstrained; the
    # remaining `rank + 1` dims are pinned.
    self.input_spec[0] = InputSpec(
        shape=(batch_size, None) + input_shape[2 : self.rank + 3]
    )

    # allow cell (if layer) to build before we set or validate state_spec
    if isinstance(self.cell, base_layer.Layer):
        # Per-step input shape: drop the time axis.
        step_input_shape = (input_shape[0],) + input_shape[2:]
        if constants_shape is not None:
            self.cell.build([step_input_shape] + constants_shape)
        else:
            self.cell.build(step_input_shape)

    # set or validate state_spec
    if hasattr(self.cell.state_size, "__len__"):
        state_size = list(self.cell.state_size)
    else:
        state_size = [self.cell.state_size]

    if self.state_spec is not None:
        # initial_state was passed in call, check compatibility
        if self.cell.data_format == "channels_first":
            ch_dim = 1
        elif self.cell.data_format == "channels_last":
            ch_dim = self.rank + 1
        if [spec.shape[ch_dim] for spec in self.state_spec] != state_size:
            raise ValueError(
                "An `initial_state` was passed that is not compatible with "
                "`cell.state_size`. Received state shapes "
                f"{[spec.shape for spec in self.state_spec]}. "
                f"However `cell.state_size` is {self.cell.state_size}"
            )
    else:
        img_dims = tuple((None for _ in range(self.rank)))
        if self.cell.data_format == "channels_first":
            self.state_spec = [
                InputSpec(shape=(None, dim) + img_dims) for dim in state_size
            ]
        elif self.cell.data_format == "channels_last":
            self.state_spec = [
                InputSpec(shape=(None,) + img_dims + (dim,))
                for dim in state_size
            ]
    if self.stateful:
        self.reset_states()
    self.built = True
def build(self, input_shape):
    """Build the wrapped cell and set or validate specs for a 2D conv RNN.

    Args:
      input_shape: Shape of the sequence input. When initial states and/or
        constants were passed to `__call__`, this is a list whose first
        entry is the sequence input shape and whose last
        `self._num_constants` entries are the constants' shapes.

    Raises:
      ValueError: If an `initial_state` was passed whose channel dimension
        does not match `cell.state_size`.
    """
    # Note input_shape will be list of shapes of initial states and
    # constants if these are passed in __call__.
    # A count of 0 must mean "no constants": `input_shape[-0:]` would return
    # the *entire* list (including the sequence input shape).
    if self._num_constants:
        constants_shape = input_shape[-self._num_constants:]  # pylint: disable=E1130
    else:
        constants_shape = None

    if isinstance(input_shape, list):
        input_shape = input_shape[0]

    batch_size = input_shape[0] if self.stateful else None
    # Batch (only fixed when stateful) and time are left unconstrained;
    # dims 2..4 are pinned.
    self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:5])

    # allow cell (if layer) to build before we set or validate state_spec
    if isinstance(self.cell, Layer):
        # Per-step input shape: drop the time axis.
        step_input_shape = (input_shape[0], ) + input_shape[2:]
        if constants_shape is not None:
            self.cell.build([step_input_shape] + constants_shape)
        else:
            self.cell.build(step_input_shape)

    # set or validate state_spec
    if hasattr(self.cell.state_size, '__len__'):
        state_size = list(self.cell.state_size)
    else:
        state_size = [self.cell.state_size]

    if self.state_spec is not None:
        # initial_state was passed in call, check compatibility
        if self.cell.data_format == 'channels_first':
            ch_dim = 1
        elif self.cell.data_format == 'channels_last':
            ch_dim = 3
        if [spec.shape[ch_dim] for spec in self.state_spec] != state_size:
            raise ValueError(
                'An initial_state was passed that is not compatible with '
                '`cell.state_size`. Received `state_spec`={}; '
                'However `cell.state_size` is '
                '{}'.format([spec.shape for spec in self.state_spec],
                            self.cell.state_size))
    else:
        if self.cell.data_format == 'channels_first':
            self.state_spec = [
                InputSpec(shape=(None, dim, None, None)) for dim in state_size
            ]
        elif self.cell.data_format == 'channels_last':
            self.state_spec = [
                InputSpec(shape=(None, None, None, dim)) for dim in state_size
            ]
    if self.stateful:
        self.reset_states()
    self.built = True
def __init__(self, cropping=(1, 1), **kwargs):
    """Create a 1D cropping layer.

    Args:
      cropping: Normalized to a length-2 tuple via
        `conv_utils.normalize_tuple` (zero values allowed).
      **kwargs: Forwarded to the base layer.
    """
    super().__init__(**kwargs)
    normalized_cropping = conv_utils.normalize_tuple(
        cropping, 2, "cropping", allow_zero=True
    )
    self.cropping = normalized_cropping
    # Inputs are required to be rank 3.
    self.input_spec = InputSpec(ndim=3)
def __init__(self, cropping=((0, 0), (0, 0)), data_format=None, **kwargs):
    """Create a 2D cropping layer.

    Args:
      cropping: Either an int `c`, expanded to `((c, c), (c, c))`, or a
        length-2 sized value whose two entries are each normalized to a
        pair via `conv_utils.normalize_tuple` (zero values allowed).
      data_format: Normalized via `conv_utils.normalize_data_format`.
      **kwargs: Forwarded to the base layer.

    Raises:
      ValueError: If `cropping` is sized but does not have exactly two
        elements, or is neither an int nor sized.
    """
    super().__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)

    if isinstance(cropping, int):
        self.cropping = ((cropping, cropping), (cropping, cropping))
    elif not hasattr(cropping, "__len__"):
        raise ValueError(
            "`cropping` should be either an int, "
            "a tuple of 2 ints "
            "(symmetric_height_crop, symmetric_width_crop), "
            "or a tuple of 2 tuples of 2 ints "
            "((top_crop, bottom_crop), (left_crop, right_crop)). "
            f"Received: {cropping}."
        )
    else:
        if len(cropping) != 2:
            raise ValueError(
                "`cropping` should have two elements. "
                f"Received: {cropping}."
            )
        entry_labels = ("1st entry of cropping", "2nd entry of cropping")
        # (height_cropping, width_cropping), normalized in that order.
        self.cropping = tuple(
            conv_utils.normalize_tuple(
                cropping[index], 2, label, allow_zero=True
            )
            for index, label in enumerate(entry_labels)
        )

    # Inputs are required to be rank 4.
    self.input_spec = InputSpec(ndim=4)
def build(self, input_shape):
    """Create the `kernel` (and, if `use_bias`, the `bias`) weights.

    Args:
      input_shape: Shape of the layer input; its last dimension must be
        known.

    Raises:
      TypeError: If the layer dtype is neither floating-point nor complex.
      ValueError: If the last input dimension is undefined.
    """
    layer_dtype = tf.as_dtype(self.dtype or backend.floatx())
    if not layer_dtype.is_floating and not layer_dtype.is_complex:
        raise TypeError(
            'A Dense layer can only be built with a floating-point '
            f'dtype. Received: dtype={layer_dtype}')

    shape = tf.TensorShape(input_shape)
    last_dim = tf.compat.dimension_value(shape[-1])
    if last_dim is None:
        raise ValueError(
            'The last dimension of the inputs to a Dense layer '
            'should be defined. Found None. '
            f'Full input shape received: {shape}')

    # Pin the last axis so later calls must agree with the built weights.
    self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})

    self.kernel = self.add_weight(
        'kernel',
        shape=[last_dim, self.units],
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True)

    self.bias = self.add_weight(
        'bias',
        shape=[self.units],
        initializer=self.bias_initializer,
        regularizer=self.bias_regularizer,
        constraint=self.bias_constraint,
        dtype=self.dtype,
        trainable=True) if self.use_bias else None

    self.built = True
def __init__(self,
             factor,
             fill_mode='reflect',
             interpolation='bilinear',
             seed=None,
             fill_value=0.0,
             **kwargs):
    """Create a random-rotation preprocessing layer.

    Args:
      factor: Either a `(lower, upper)` tuple/list, or a single number `f`
        interpreted as `(-f, f)`.
      fill_mode: Validated by `check_fill_mode_and_interpolation`.
      interpolation: Validated by `check_fill_mode_and_interpolation`.
      seed: Seed passed to `make_generator`.
      fill_value: Stored as `self.fill_value`.
      **kwargs: Forwarded to the base layer.

    Raises:
      ValueError: If the resulting upper bound is less than the lower bound
        (which includes a negative scalar `factor`).
    """
    self.factor = factor
    if isinstance(factor, (tuple, list)):
        self.lower = factor[0]
        self.upper = factor[1]
    else:
        self.lower = -factor
        self.upper = factor
    if self.upper < self.lower:
        # Previously reported as "Factor cannot have negative values", which
        # is misleading for tuple inputs such as (0.3, 0.1).
        raise ValueError('`factor` cannot have upper bound less than '
                         'lower bound (a scalar `factor` must be '
                         'non-negative), got {}'.format(factor))
    check_fill_mode_and_interpolation(fill_mode, interpolation)
    self.fill_mode = fill_mode
    self.fill_value = fill_value
    self.interpolation = interpolation
    self.seed = seed
    self._rng = make_generator(self.seed)
    self.input_spec = InputSpec(ndim=4)
    super(RandomRotation, self).__init__(**kwargs)
    base_preprocessing_layer.keras_kpl_gauge.get_cell('RandomRotation').set(
        True)
def __init__(self,
             units,
             activation=None,
             use_bias=True,
             kernel_initializer='glorot_uniform',
             bias_initializer='zeros',
             kernel_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             **kwargs):
    """Configure a densely-connected layer; weights are created in `build`.

    Args:
      units: Converted with `int()` unless already an int. Must not be
        negative (note: 0 passes this check despite the error wording).
      activation: Resolved via `activations.get`.
      use_bias: Whether `build` creates a bias weight.
      kernel_initializer, bias_initializer: Resolved via `initializers.get`.
      kernel_regularizer, bias_regularizer: Resolved via `regularizers.get`.
      activity_regularizer: Forwarded to the base layer.
      kernel_constraint, bias_constraint: Resolved via `constraints.get`.
      **kwargs: Forwarded to the base layer.

    Raises:
      ValueError: If `units` is negative.
    """
    super(Dense, self).__init__(
        activity_regularizer=activity_regularizer, **kwargs)

    self.units = units if isinstance(units, int) else int(units)
    if self.units < 0:
        raise ValueError(
            'Received an invalid value for `units`, expected '
            f'a positive integer. Received: units={units}')

    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

    # Only the rank is constrained here; `build` pins the last axis.
    self.input_spec = InputSpec(min_ndim=2)
    self.supports_masking = True
def __init__(self,
             size=(2, 2),
             data_format=None,
             interpolation='nearest',
             **kwargs):
    """Create a 2D upsampling layer.

    Args:
      size: Normalized to a length-2 tuple via `conv_utils.normalize_tuple`.
      data_format: Normalized via `conv_utils.normalize_data_format`.
      interpolation: Name of a resize method; must be one of the keys of
        the `tf.image.ResizeMethod` mapping below.
      **kwargs: Forwarded to the base layer.

    Raises:
      ValueError: If `interpolation` is not a supported method name.
    """
    super(UpSampling2D, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.size = conv_utils.normalize_tuple(size, 2, 'size')

    resize_methods = {
        'area': tf.image.ResizeMethod.AREA,
        'bicubic': tf.image.ResizeMethod.BICUBIC,
        'bilinear': tf.image.ResizeMethod.BILINEAR,
        'gaussian': tf.image.ResizeMethod.GAUSSIAN,
        'lanczos3': tf.image.ResizeMethod.LANCZOS3,
        'lanczos5': tf.image.ResizeMethod.LANCZOS5,
        'mitchellcubic': tf.image.ResizeMethod.MITCHELLCUBIC,
        'nearest': tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    }
    if interpolation not in resize_methods:
        # Quoted, comma-separated list of accepted names for the message.
        supported = ', '.join(f'"{name}"' for name in resize_methods)
        raise ValueError(
            '`interpolation` argument should be one of: '
            f'{supported}. Received: "{interpolation}".')

    self.interpolation = interpolation
    self.input_spec = InputSpec(ndim=4)
def build(self, input_shape):
    """Create the transposed-convolution kernel and optional bias.

    The kernel shape is `kernel_size + (filters, input_channels)`.

    Args:
      input_shape: Rank-4 input shape whose channel dimension (as given by
        `self._get_channel_axis()`) must be known.

    Raises:
      ValueError: If the input is not rank 4 or the channel dim is `None`.
    """
    shape = tf.TensorShape(input_shape)
    if len(shape) != 4:
        raise ValueError('Inputs should have rank 4. '
                         f'Received input_shape={shape}.')

    channel_axis = self._get_channel_axis()
    if shape.dims[channel_axis].value is None:
        raise ValueError(
            'The channel dimension of the inputs '
            'to `Conv2DTranspose` should be defined. '
            f'The input_shape received is {shape}, '
            f'where axis {channel_axis} (0-based) '
            'is the channel dimension, which found to be `None`.')
    in_channels = int(shape[channel_axis])

    # Pin rank and channel count for subsequent calls.
    self.input_spec = InputSpec(ndim=4, axes={channel_axis: in_channels})

    self.kernel = self.add_weight(
        name='kernel',
        shape=self.kernel_size + (self.filters, in_channels),
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        trainable=True,
        dtype=self.dtype)

    self.bias = self.add_weight(
        name='bias',
        shape=(self.filters,),
        initializer=self.bias_initializer,
        regularizer=self.bias_regularizer,
        constraint=self.bias_constraint,
        trainable=True,
        dtype=self.dtype) if self.use_bias else None

    self.built = True
def __init__(self,
             size=(2, 2),
             data_format=None,
             interpolation="nearest",
             **kwargs):
    """Create a 2D upsampling layer.

    Args:
      size: Normalized to a length-2 tuple via `conv_utils.normalize_tuple`.
      data_format: Normalized via `conv_utils.normalize_data_format`.
      interpolation: Name of a resize method; must be one of the keys of
        the `tf.image.ResizeMethod` mapping below.
      **kwargs: Forwarded to the base layer.

    Raises:
      ValueError: If `interpolation` is not a supported method name.
    """
    super().__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.size = conv_utils.normalize_tuple(size, 2, "size")

    resize_methods = {
        "area": tf.image.ResizeMethod.AREA,
        "bicubic": tf.image.ResizeMethod.BICUBIC,
        "bilinear": tf.image.ResizeMethod.BILINEAR,
        "gaussian": tf.image.ResizeMethod.GAUSSIAN,
        "lanczos3": tf.image.ResizeMethod.LANCZOS3,
        "lanczos5": tf.image.ResizeMethod.LANCZOS5,
        "mitchellcubic": tf.image.ResizeMethod.MITCHELLCUBIC,
        "nearest": tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    }
    if interpolation not in resize_methods:
        # Quoted, comma-separated list of accepted names for the message.
        supported = ", ".join(f'"{name}"' for name in resize_methods)
        raise ValueError(
            "`interpolation` argument should be one of: "
            f'{supported}. Received: "{interpolation}".')

    self.interpolation = interpolation
    self.input_spec = InputSpec(ndim=4)
def __init__(self, padding=1, **kwargs):
    """Create a 1D zero-padding layer.

    Args:
      padding: Normalized to a length-2 tuple via
        `conv_utils.normalize_tuple` (zero values allowed).
      **kwargs: Forwarded to the base layer.
    """
    super().__init__(**kwargs)
    normalized_padding = conv_utils.normalize_tuple(
        padding, 2, 'padding', allow_zero=True)
    self.padding = normalized_padding
    # Inputs are required to be rank 3.
    self.input_spec = InputSpec(ndim=3)
def __init__(self, n, **kwargs):
    """Store the repeat count `n` and require rank-2 inputs.

    Args:
      n: Must be an `int` (checked with `isinstance`).
      **kwargs: Forwarded to the base layer.

    Raises:
      TypeError: If `n` is not an int.
    """
    super(RepeatVector, self).__init__(**kwargs)
    self.n = n
    if isinstance(n, int):
        self.input_spec = InputSpec(ndim=2)
        return
    raise TypeError(f'Expected an integer value for `n`, got {type(n)}.')
def __init__(self,
             cropping=((1, 1), (1, 1), (1, 1)),
             data_format=None,
             **kwargs):
    """Create a 3D cropping layer.

    Args:
      cropping: Either an int `c`, expanded to `((c, c), (c, c), (c, c))`,
        or a length-3 sized value whose entries are each normalized to a
        pair via `conv_utils.normalize_tuple` (zero values allowed).
      data_format: Normalized via `conv_utils.normalize_data_format`.
      **kwargs: Forwarded to the base layer.

    Raises:
      ValueError: If `cropping` is sized but does not have exactly 3
        elements, or is neither an int nor sized.
    """
    super(Cropping3D, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    if isinstance(cropping, int):
        self.cropping = ((cropping, cropping), (cropping, cropping),
                         (cropping, cropping))
    elif hasattr(cropping, '__len__'):
        if len(cropping) != 3:
            raise ValueError('`cropping` should have 3 elements. '
                             f'Received: {cropping}.')
        dim1_cropping = conv_utils.normalize_tuple(
            cropping[0], 2, '1st entry of cropping', allow_zero=True)
        dim2_cropping = conv_utils.normalize_tuple(
            cropping[1], 2, '2nd entry of cropping', allow_zero=True)
        dim3_cropping = conv_utils.normalize_tuple(
            cropping[2], 2, '3rd entry of cropping', allow_zero=True)
        self.cropping = (dim1_cropping, dim2_cropping, dim3_cropping)
    else:
        # Fixed typo: the third pair previously read `right_dim2_crop`.
        raise ValueError(
            '`cropping` should be either an int, '
            'a tuple of 3 ints '
            '(symmetric_dim1_crop, symmetric_dim2_crop, symmetric_dim3_crop), '
            'or a tuple of 3 tuples of 2 ints '
            '((left_dim1_crop, right_dim1_crop),'
            ' (left_dim2_crop, right_dim2_crop),'
            ' (left_dim3_crop, right_dim3_crop)). '
            f'Received: {cropping}.')
    # Inputs are required to be rank 5.
    self.input_spec = InputSpec(ndim=5)
def __init__(self, padding=(1, 1), data_format=None, **kwargs):
    """Create a 2D zero-padding layer.

    Args:
      padding: Either an int `p`, expanded to `((p, p), (p, p))`, or a
        length-2 sized value whose entries are each normalized to a pair
        via `conv_utils.normalize_tuple` (zero values allowed).
      data_format: Normalized via `conv_utils.normalize_data_format`.
      **kwargs: Forwarded to the base layer.

    Raises:
      ValueError: If `padding` is sized but does not have exactly two
        elements, or is neither an int nor sized.
    """
    super().__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)

    if isinstance(padding, int):
        self.padding = ((padding, padding), (padding, padding))
    elif not hasattr(padding, '__len__'):
        raise ValueError('`padding` should be either an int, '
                         'a tuple of 2 ints '
                         '(symmetric_height_pad, symmetric_width_pad), '
                         'or a tuple of 2 tuples of 2 ints '
                         '((top_pad, bottom_pad), (left_pad, right_pad)). '
                         f'Received: {padding}.')
    else:
        if len(padding) != 2:
            raise ValueError('`padding` should have two elements. '
                             f'Received: {padding}.')
        entry_labels = ('1st entry of padding', '2nd entry of padding')
        # (height_padding, width_padding), normalized in that order.
        self.padding = tuple(
            conv_utils.normalize_tuple(padding[index], 2, label,
                                       allow_zero=True)
            for index, label in enumerate(entry_labels))

    # Inputs are required to be rank 4.
    self.input_spec = InputSpec(ndim=4)
def __init__(self, factor, interpolation='bilinear', seed=None, **kwargs):
    """Create a random-width preprocessing layer.

    Args:
      factor: Either a `(lower, upper)` tuple/list, or a single number `f`
        interpreted as `(-f, f)`. Both bounds must be >= -1 and
        `upper >= lower`.
      interpolation: Resolved via `get_interpolation`.
      seed: Seed passed to `make_generator`.
      **kwargs: Forwarded to the base layer.

    Raises:
      ValueError: If the bounds are out of order or below -1.
    """
    self.factor = factor
    if isinstance(factor, (tuple, list)):
        lower, upper = factor[0], factor[1]
    else:
        lower, upper = -factor, factor
    self.width_lower = lower
    self.width_upper = upper

    if upper < lower:
        raise ValueError('`factor` cannot have upper bound less than '
                         'lower bound, got {}'.format(factor))
    if lower < -1. or upper < -1.:
        raise ValueError('`factor` must have values larger than -1, '
                         'got {}'.format(factor))

    self.interpolation = interpolation
    self._interpolation_method = get_interpolation(interpolation)
    self.input_spec = InputSpec(ndim=4)
    self.seed = seed
    self._rng = make_generator(self.seed)
    super(RandomWidth, self).__init__(**kwargs)
    base_preprocessing_layer.keras_kpl_gauge.get_cell('RandomWidth').set(True)
def __init__(self,
             rank,
             cell,
             return_sequences=False,
             return_state=False,
             go_backwards=False,
             stateful=False,
             unroll=False,
             **kwargs):
    """Create a convolutional RNN layer with `rank` spatial dimensions.

    Args:
      rank: Number of spatial dims; inputs are required to have
        `rank + 3` dims.
      cell: A single cell instance (lists/tuples of cells are rejected).
      return_sequences, return_state, go_backwards, stateful: Forwarded
        positionally to the parent initializer.
      unroll: Must be falsy; unrolling is not supported.
      **kwargs: Forwarded to the parent initializer.

    Raises:
      TypeError: If `unroll` is truthy or `cell` is a list/tuple.
    """
    if unroll:
        raise TypeError(
            'Unrolling is not possible with convolutional RNNs. '
            f'Received: unroll={unroll}')
    if isinstance(cell, (list, tuple)):
        # The StackedConvRNN3DCells isn't implemented yet.
        # (Fixed missing space: the message used to read "...tostack...".)
        raise TypeError('It is not possible at the moment to '
                        'stack convolutional cells. Only pass a single cell '
                        'instance as the `cell` argument. Received: '
                        f'cell={cell}')
    super().__init__(cell, return_sequences, return_state, go_backwards,
                     stateful, unroll, **kwargs)
    self.rank = rank
    # Presumably batch, time, `rank` spatial dims and channels — confirm.
    self.input_spec = [InputSpec(ndim=rank + 3)]
    self.states = None
    self._num_constants = None
def __init__(self,
             units,
             activation='tanh',
             use_bias=True,
             kernel_initializer='glorot_uniform',
             recurrent_initializer='orthogonal',
             bias_initializer='zeros',
             kernel_regularizer=None,
             recurrent_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             recurrent_constraint=None,
             bias_constraint=None,
             dropout=0.,
             recurrent_dropout=0.,
             return_sequences=False,
             return_state=False,
             go_backwards=False,
             stateful=False,
             unroll=False,
             **kwargs):
    """Create a SimpleRNN layer by building a `SimpleRNNCell` and wrapping it.

    Cell-related arguments are forwarded to `SimpleRNNCell`; sequence
    arguments go to the parent RNN initializer. A deprecated
    `implementation` kwarg is dropped with a warning, and
    `enable_caching_device` (if present) is routed to the cell only.
    """
    if 'implementation' in kwargs:
        kwargs.pop('implementation')
        logging.warning('The `implementation` argument '
                        'in `SimpleRNN` has been deprecated. '
                        'Please remove it from your layer call.')

    # `enable_caching_device` belongs to the cell, not the wrapping layer.
    cell_kwargs = ({'enable_caching_device': kwargs.pop('enable_caching_device')}
                   if 'enable_caching_device' in kwargs else {})

    cell = SimpleRNNCell(
        units,
        activation=activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        recurrent_initializer=recurrent_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        recurrent_regularizer=recurrent_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        recurrent_constraint=recurrent_constraint,
        bias_constraint=bias_constraint,
        dropout=dropout,
        recurrent_dropout=recurrent_dropout,
        dtype=kwargs.get('dtype'),
        trainable=kwargs.get('trainable', True),
        **cell_kwargs)

    super().__init__(
        cell,
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        unroll=unroll,
        **kwargs)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.input_spec = [InputSpec(ndim=3)]
def __init__(self, height, width, seed=None, **kwargs):
    """Create a random-crop preprocessing layer.

    Args:
      height: Stored as `self.height`.
      width: Stored as `self.width`.
      seed: Seed passed to `make_generator`.
      **kwargs: Forwarded to the base layer.
    """
    self.height, self.width = height, width
    self.seed = seed
    self._rng = make_generator(self.seed)
    # Inputs are required to be rank 4.
    self.input_spec = InputSpec(ndim=4)
    super(RandomCrop, self).__init__(**kwargs)
    gauge_cell = base_preprocessing_layer.keras_kpl_gauge.get_cell('RandomCrop')
    gauge_cell.set(True)
def __init__(self, dims, **kwargs):
    """Create a layer that reorders the non-batch axes of its input.

    Args:
      dims: Iterable of axis indices; once materialized it must be a
        permutation of `1..len(dims)` (axis 0 is never included).
      **kwargs: Forwarded to the base layer.

    Raises:
      ValueError: If `dims` is not a permutation of consecutive indices
        starting at 1.
    """
    super(Permute, self).__init__(**kwargs)
    self.dims = tuple(dims)
    # Validate the materialized tuple, not the raw argument: a one-shot
    # iterable (e.g. a generator) is exhausted by `tuple()` above and has no
    # `len()`.
    if sorted(self.dims) != list(range(1, len(self.dims) + 1)):
        raise ValueError(
            'Invalid permutation argument `dims` for Permute Layer. '
            'The set of indices in `dims` must be consecutive and start from 1. '
            f'Received dims={dims}')
    # One extra axis for the batch dimension.
    self.input_spec = InputSpec(ndim=len(self.dims) + 1)
def __init__(self, rate, data_format=None, **kwargs):
    """Create a 3D spatial dropout layer.

    Args:
      rate: Forwarded to the parent dropout initializer.
      data_format: `'channels_last'` or `'channels_first'`; `None` falls
        back to `backend.image_data_format()`.
      **kwargs: Forwarded to the parent initializer.

    Raises:
      ValueError: If the resolved `data_format` is not one of the two
        accepted values.
    """
    super(SpatialDropout3D, self).__init__(rate, **kwargs)
    resolved_format = (backend.image_data_format()
                       if data_format is None else data_format)
    if resolved_format not in {'channels_last', 'channels_first'}:
        raise ValueError(
            '`data_format` must be "channels_last" or "channels_first". '
            f'Received: data_format={resolved_format}.')
    self.data_format = resolved_format
    # Inputs are required to be rank 5.
    self.input_spec = InputSpec(ndim=5)
def __init__(self, rate, data_format=None, **kwargs):
    """Create a 3D spatial dropout layer.

    Args:
      rate: Forwarded to the parent dropout initializer.
      data_format: "channels_last" or "channels_first"; `None` falls back
        to `backend.image_data_format()`.
      **kwargs: Forwarded to the parent initializer.

    Raises:
      ValueError: If the resolved `data_format` is not one of the two
        accepted values.
    """
    super().__init__(rate, **kwargs)
    resolved_format = (
        backend.image_data_format() if data_format is None else data_format
    )
    if resolved_format not in {"channels_last", "channels_first"}:
        raise ValueError(
            '`data_format` must be "channels_last" or "channels_first". '
            f"Received: data_format={resolved_format}."
        )
    self.data_format = resolved_format
    # Inputs are required to be rank 5.
    self.input_spec = InputSpec(ndim=5)
def build(self, input_shape):
    """Create depthwise/pointwise kernels and the optional bias.

    Depthwise kernel: `kernel_size + (in_channels, depth_multiplier)`.
    Pointwise kernel: `(1,) * rank + (depth_multiplier * in_channels,
    filters)`.

    Args:
      input_shape: Input shape whose channel dim (per
        `self._get_channel_axis()`) must be known.

    Raises:
      ValueError: If the channel dimension is `None`.
    """
    shape = tf.TensorShape(input_shape)
    channel_axis = self._get_channel_axis()
    if shape.dims[channel_axis].value is None:
        raise ValueError(
            "The channel dimension of the inputs should be defined. "
            f"The input_shape received is {shape}, "
            f"where axis {channel_axis} (0-based) "
            "is the channel dimension, which found to be `None`."
        )
    in_channels = int(shape[channel_axis])
    self.input_spec = InputSpec(
        ndim=self.rank + 2, axes={channel_axis: in_channels}
    )

    depthwise_shape = self.kernel_size + (in_channels, self.depth_multiplier)
    pointwise_shape = (1,) * self.rank + (
        self.depth_multiplier * in_channels,
        self.filters,
    )

    self.depthwise_kernel = self.add_weight(
        name="depthwise_kernel",
        shape=depthwise_shape,
        initializer=self.depthwise_initializer,
        regularizer=self.depthwise_regularizer,
        constraint=self.depthwise_constraint,
        trainable=True,
        dtype=self.dtype,
    )
    self.pointwise_kernel = self.add_weight(
        name="pointwise_kernel",
        shape=pointwise_shape,
        initializer=self.pointwise_initializer,
        regularizer=self.pointwise_regularizer,
        constraint=self.pointwise_constraint,
        trainable=True,
        dtype=self.dtype,
    )
    self.bias = (
        self.add_weight(
            name="bias",
            shape=(self.filters,),
            initializer=self.bias_initializer,
            regularizer=self.bias_regularizer,
            constraint=self.bias_constraint,
            trainable=True,
            dtype=self.dtype,
        )
        if self.use_bias
        else None
    )
    self.built = True
def __init__(self, height, width, interpolation='bilinear', **kwargs):
    """Create a resizing preprocessing layer.

    Args:
      height: Stored as `self.target_height`.
      width: Stored as `self.target_width`.
      interpolation: Stored as given and resolved via `get_interpolation`.
      **kwargs: Forwarded to the base layer.
    """
    self.target_height, self.target_width = height, width
    self.interpolation = interpolation
    self._interpolation_method = get_interpolation(interpolation)
    # Inputs are required to be rank 4.
    self.input_spec = InputSpec(ndim=4)
    super(Resizing, self).__init__(**kwargs)
    gauge_cell = base_preprocessing_layer.keras_kpl_gauge.get_cell('Resizing')
    gauge_cell.set(True)
def get_input_spec(shape):
    """Convert input shape to InputSpec.

    Uses `self` from the enclosing scope. The time axis is always relaxed
    to `None`; the batch axis is relaxed only when the layer is not
    stateful. Axis positions follow `self.time_major`.
    """
    if isinstance(shape, tf.TensorShape):
        dims = shape.as_list()
    else:
        dims = list(shape)
    if self.time_major:
        batch_axis, time_axis = 1, 0
    else:
        batch_axis, time_axis = 0, 1
    if not self.stateful:
        dims[batch_axis] = None
    dims[time_axis] = None
    return InputSpec(shape=tuple(dims))
def __init__(self,
             size=(2, 2),
             data_format=None,
             interpolation='nearest',
             **kwargs):
    """Create a 2D upsampling layer.

    Args:
      size: Normalized to a length-2 tuple via `conv_utils.normalize_tuple`.
      data_format: Normalized via `conv_utils.normalize_data_format`.
      interpolation: Either `'nearest'` or `'bilinear'`.
      **kwargs: Forwarded to the base layer.

    Raises:
      ValueError: If `interpolation` is not one of the two accepted names.
    """
    super(UpSampling2D, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.size = conv_utils.normalize_tuple(size, 2, 'size')
    supported_methods = {'nearest', 'bilinear'}
    if interpolation in supported_methods:
        self.interpolation = interpolation
    else:
        raise ValueError(
            '`interpolation` argument should be one of `"nearest"` '
            f'or `"bilinear"`. Received: "{interpolation}".')
    # Inputs are required to be rank 4.
    self.input_spec = InputSpec(ndim=4)
def __init__(self,
             max_tokens=None,
             output_mode=BINARY,
             sparse=False,
             **kwargs):
    """Create a CategoryEncoding preprocessing layer.

    Args:
      max_tokens: `None` or an integer >= 1.
      output_mode: One of `COUNT`, `BINARY`, `TFIDF`.
      sparse: Stored as `self.sparse`.
      **kwargs: Forwarded to the base preprocessing layer.

    Raises:
      ValueError: If `max_tokens` is set and less than 1, or (via
        `layer_utils.validate_string_arg`) if `output_mode` is not allowed.
    """
    # 'output_mode' must be one of (COUNT, BINARY, TFIDF)
    layer_utils.validate_string_arg(
        output_mode,
        allowable_strings=(COUNT, BINARY, TFIDF),
        layer_name="CategoryEncoding",
        arg_name="output_mode")

    # If max_tokens is set, the value must be at least 1 - otherwise we are
    # creating a 0-element vocab, which doesn't make sense. (The old message
    # said "> 1", but 1 has always been accepted.)
    if max_tokens is not None and max_tokens < 1:
        raise ValueError(
            "`max_tokens` must be >= 1 (or None). "
            "Received: max_tokens={}.".format(max_tokens))

    # We need to call super() before we call _add_state_variable().
    combiner = _CategoryEncodingCombiner(
        max_tokens=max_tokens, compute_idf=output_mode == TFIDF)
    super(CategoryEncoding, self).__init__(combiner=combiner, **kwargs)
    base_preprocessing_layer.keras_kpl_gauge.get_cell(
        "CategoryEncoding").set(True)

    self.max_tokens = max_tokens
    self.output_mode = output_mode
    self.sparse = sparse
    self._called = False

    if self.output_mode == TFIDF:
        # The TF-IDF weight may have a (None,) tensorshape. This creates a
        # 1D variable with arbitrary shape, which we can assign any weight
        # to so long as it has 1 dimension. In order to properly initialize
        # this weight in Keras, we need to provide a custom callable
        # initializer which does not depend on the shape of the weight (as
        # all other initializers do) since the weight is not known. Hence
        # the lambda shape, dtype: [0].
        if max_tokens is None:
            initializer = lambda shape, dtype: [0]
        else:
            initializer = tf.compat.v1.zeros_initializer

        # We are adding these here instead of in build() since they do not
        # depend on the input shape at all.
        self.tf_idf_weights = self._add_state_variable(
            name=_IDF_NAME,
            shape=tf.TensorShape((max_tokens,)),
            dtype=K.floatx(),
            initializer=initializer)

    self.input_spec = InputSpec(ndim=2)
def __init__(self,
             pool_function,
             pool_size,
             strides,
             padding='valid',
             data_format='channels_last',
             name=None,
             **kwargs):
    """Configure a 1D pooling layer.

    Args:
      pool_function: Stored as `self.pool_function`.
      pool_size: Normalized to a 1-tuple via `conv_utils.normalize_tuple`.
      strides: Normalized to a 1-tuple (zero allowed); `None` means "use
        `pool_size`".
      padding: Normalized via `conv_utils.normalize_padding`.
      data_format: `None` falls back to `backend.image_data_format()`, then
        normalized via `conv_utils.normalize_data_format`.
      name: Forwarded to the base layer.
      **kwargs: Forwarded to the base layer.
    """
    super(Pooling1D, self).__init__(name=name, **kwargs)
    resolved_format = (backend.image_data_format()
                       if data_format is None else data_format)
    resolved_strides = pool_size if strides is None else strides

    self.pool_function = pool_function
    self.pool_size = conv_utils.normalize_tuple(pool_size, 1, 'pool_size')
    self.strides = conv_utils.normalize_tuple(
        resolved_strides, 1, 'strides', allow_zero=True)
    self.padding = conv_utils.normalize_padding(padding)
    self.data_format = conv_utils.normalize_data_format(resolved_format)
    # Inputs are required to be rank 3.
    self.input_spec = InputSpec(ndim=3)
def __init__(self, factor, seed=None, **kwargs):
    """Create a random-contrast preprocessing layer.

    Args:
      factor: Either a `(lower, upper)` tuple/list, or a single number used
        for both bounds. Both must be non-negative and `lower` must be at
        most 1.0. `upper` is not bounded above by this check, even though
        the error message mentions 1.0.
      seed: Stored as `self.seed`.
      **kwargs: Forwarded to the base layer.

    Raises:
      ValueError: If a bound is negative or `lower` exceeds 1.0.
    """
    self.factor = factor
    if isinstance(factor, (tuple, list)):
        self.lower, self.upper = factor[0], factor[1]
    else:
        self.lower = self.upper = factor

    has_negative = self.lower < 0. or self.upper < 0.
    if has_negative or self.lower > 1.:
        raise ValueError('Factor cannot have negative values or greater than 1.0,'
                         ' got {}'.format(factor))

    self.seed = seed
    # Inputs are required to be rank 4.
    self.input_spec = InputSpec(ndim=4)
    super(RandomContrast, self).__init__(**kwargs)
    base_preprocessing_layer.keras_kpl_gauge.get_cell('RandomContrast').set(
        True)
def build(self, input_shape):
    """Build the wrapped layer on per-timestep shapes.

    Every (possibly nested) input shape must have known rank >= 3. The
    layer's `input_spec` leaves the first two dims unconstrained. The
    wrapped layer is built on shapes with timesteps removed via
    `self._remove_timesteps`.

    Args:
      input_shape: A shape or nested structure of shapes.

    Raises:
      ValueError: If any input shape has unknown rank or fewer than 3 dims.
    """
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
    input_dims = tf.nest.flatten(
        tf.nest.map_structure(lambda x: x.ndims, input_shape))
    # `ndims` is None for unknown rank; `None < 3` would raise an opaque
    # TypeError, so report it with the intended ValueError instead.
    if any(dim is None or dim < 3 for dim in input_dims):
        raise ValueError(
            '`TimeDistributed` Layer should be passed an `input_shape` '
            'with at least 3 dimensions, received: ' + str(input_shape))
    # Don't enforce the batch or time dimension.
    self.input_spec = tf.nest.map_structure(
        lambda x: InputSpec(shape=[None, None] + x.as_list()[2:]),
        input_shape)
    child_input_shape = tf.nest.map_structure(self._remove_timesteps,
                                              input_shape)
    child_input_shape = tf_utils.convert_shapes(child_input_shape)
    super(TimeDistributed, self).build(tuple(child_input_shape))
    self.built = True