def __init__( self, experts, activity_regularizer=None, gate_use_bias=False, gate_kernel_initializer="glorot_uniform", gate_bias_initializer="zeros", gate_kernel_regularizer=None, gate_bias_regularizer=None, gate_kernel_constraint=None, gate_bias_constraint=None, **kwargs, ): super(ContextualMixture, self).__init__( activity_regularizer=regularizers.get(activity_regularizer), **kwargs, ) # Sanity check. self.experts = tuple(experts) for expert in self.experts: if not isinstance(expert, tf.Module): raise ValueError( "Please initialize `{name}` expert with a " "`tf.Module` instance. You passed: {input}".format( name=self.__class__.__name__, input=expert)) # Regularizers and constraints for the weight generator. self.gate_use_bias = gate_use_bias self.gate_kernel_initializer = initializers.get( gate_kernel_initializer) self.gate_bias_initializer = initializers.get(gate_bias_initializer) self.gate_kernel_regularizer = regularizers.get( gate_kernel_regularizer) self.gate_bias_regularizer = regularizers.get(gate_bias_regularizer) self.gate_kernel_constraint = constraints.get(gate_kernel_constraint) self.gate_bias_constraint = constraints.get(gate_bias_constraint) self.supports_masking = True self.input_spec = [ InputSpec(min_ndim=2), # Context input spec. InputSpec(min_ndim=2), # Features input spec. ] # Instantiate contextual attention for gating. self.gating_attention = Dense(len(self.experts), activation=tf.nn.softmax, name="attention") # Internals. self.context_shape = None self.feature_shape = None
def build(self, input_shape):
    """Create the (possibly tied) kernel and optional bias.

    When ``self.tied_to`` is set, the kernel is the transpose of the tied
    layer's first weight and no new kernel variable is created; otherwise
    a fresh kernel of shape ``[last_dim, units]`` is allocated.

    Raises:
        TypeError: if the layer dtype is not floating point or complex.
        ValueError: if the last input dimension is undefined.
    """
    dtype = dtypes.as_dtype(self.dtype or K.floatx())
    if not (dtype.is_floating or dtype.is_complex):
        raise TypeError('Unable to build `Dense` layer with non-floating point '
                        'dtype %s' % (dtype,))
    input_shape = tensor_shape.TensorShape(input_shape)
    if tensor_shape.dimension_value(input_shape[-1]) is None:
        raise ValueError('The last dimension of the inputs to `Dense` '
                         'should be defined. Found `None`.')
    last_dim = tensor_shape.dimension_value(input_shape[-1])
    self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})
    if self.tied_to is not None:
        # Weight tying: reuse the tied layer's kernel, transposed.
        # NOTE(review): this is a computed tensor, not a variable tracked by
        # this layer, so it will not appear in this layer's weights — confirm
        # that is intended.
        self.kernel = K.transpose(self.tied_to.weights[0])
    else:
        self.kernel = self.add_weight(
            'kernel',
            shape=[last_dim, self.units],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True)
    if self.use_bias:
        self.bias = self.add_weight(
            'bias',
            shape=[self.units,],
            initializer=self.bias_initializer,
            regularizer=self.bias_regularizer,
            constraint=self.bias_constraint,
            dtype=self.dtype,
            trainable=True)
    else:
        self.bias = None
    self.built = True
def __init__(self, mask, **kwargs):
    """Layer whose output units are selected by a boolean/integer `mask`.

    Args:
        mask: array-like; its sum determines the number of units.
        **kwargs: forwarded to the base layer. `input_dim` is accepted as a
            legacy alias for `input_shape=(input_dim,)`.
    """
    if "input_shape" not in kwargs and "input_dim" in kwargs:
        # BUG FIX: `(x)` is just `x`, not a tuple — Keras expects
        # `input_shape` to be a sequence, so wrap in a 1-tuple.
        kwargs["input_shape"] = (kwargs.pop("input_dim"),)
    super().__init__(**kwargs)
    self.mask = mask
    self.input_spec = InputSpec(ndim=2)
    # Number of output units = number of selected (truthy) mask entries.
    self.units = np.sum(self.mask)
def build(self, input_shape):
    """Create the `gamma` (scale) and `beta` (shift) normalization weights.

    If ``self.axis`` is None, a single scalar gamma/beta pair is shared
    across all features; otherwise one value per entry of the given axis.

    Raises:
        ValueError: if axis is 0 (the batch axis) or if an axis is given
            for a rank-1 (per-sample) input.
    """
    ndim = len(input_shape)
    if self.axis == 0:
        # Axis 0 is the batch dimension; normalizing over it is unsupported.
        raise ValueError('Axis cannot be zero')
    if (self.axis is not None) and (ndim == 2):
        raise ValueError('Cannot specify axis for rank 1 tensor')
    self.input_spec = InputSpec(ndim=ndim)
    if self.axis is None:
        shape = (1, )
    else:
        shape = (input_shape[self.axis], )
    if self.scale:
        self.gamma = self.add_weight(shape=shape,
                                     name='gamma',
                                     initializer=self.gamma_initializer,
                                     regularizer=self.gamma_regularizer,
                                     constraint=self.gamma_constraint)
    else:
        self.gamma = None
    if self.center:
        self.beta = self.add_weight(shape=shape,
                                    name='beta',
                                    initializer=self.beta_initializer,
                                    regularizer=self.beta_regularizer,
                                    constraint=self.beta_constraint)
    else:
        self.beta = None
    self.built = True
def _dense_layer_input_spec(input_shape):
    """Return ``(last_dim, InputSpec)`` for a Dense-style layer.

    Validates that the final input dimension is statically known and pins
    it in the returned spec.
    """
    shape = tensor_shape.TensorShape(input_shape)
    last_dim = _dimension_value(shape[-1])
    if last_dim is None:
        raise ValueError('The last dimension of the inputs to `Dense` '
                         'should be defined. Found `None`.')
    return last_dim, InputSpec(min_ndim=2, axes={-1: last_dim})
def __init__(self, units, activation=None, activity_regularizer=None,
             init_mu_current='spread', init_sigma_current=0.1,
             init_mu_prev='spread', init_sigma_prev=0.1,
             verbose=True, si_regularizer=None,
             train_mu=True, train_sigma=True, train_weights=True,
             normed=2, init_w=None, **kwargs):
    """Minimal RNN cell with focused (Gaussian-windowed) connectivity.

    Args:
        units: number of output units / state size.
        activation: activation function (name or callable).
        activity_regularizer: regularizer on the cell output.
        init_mu_current / init_mu_prev: initial focus centers ('spread' or
            numeric scheme understood downstream).
        init_sigma_current / init_sigma_prev: initial focus widths.
        verbose: enable diagnostic output.
        si_regularizer: regularizer applied to the sigma parameters.
        train_mu / train_sigma / train_weights: trainability toggles.
        normed: normalization mode flag.
        init_w: weight initializer; defaults to a fresh Glorot-uniform
            initializer per instance.
        **kwargs: forwarded to the base cell.
    """
    self.units = units
    self.state_size = units
    self.verbose = verbose
    self.si_regularizer = regularizers.get(si_regularizer)
    self.init_mu_current = init_mu_current
    self.init_mu_prev = init_mu_prev
    self.init_sigma_current = init_sigma_current
    self.init_sigma_prev = init_sigma_prev
    self.train_mu = train_mu
    self.train_sigma = train_sigma
    self.train_weights = train_weights
    self.normed = normed
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.input_spec = InputSpec(min_ndim=2)
    # BUG FIX: the original default `init_w=initializers.glorot_uniform()`
    # was evaluated once at definition time and shared across every
    # instance; create a fresh initializer per instance instead.
    self.init_weights = init_w if init_w is not None \
        else initializers.glorot_uniform()
    self.activation = activations.get(activation)
    super(MinimalRNNFocusedCell, self).__init__(**kwargs)
def build(self, input_shape):
    """Create per-channel `gamma` (scale) and `beta` (shift) variables.

    Uses raw ``K.variable`` and a direct ``trainable_weights`` assignment
    (legacy Keras 1.x style) rather than ``add_weight``.
    """
    print("--Scale--build", input_shape)  # debug trace left by the author
    # One InputSpec per input tensor; pin the full static shape here.
    self.input_spec = [InputSpec(shape=input_shape)]
    shape = (int(input_shape[self.axis]), )
    # Compatibility with TensorFlow >= 1.0.0
    self.gamma = K.variable(self.gamma_init(shape),
                            name='{}_gamma'.format(self.name))
    self.beta = K.variable(self.beta_init(shape),
                           name='{}_beta'.format(self.name))
    # NOTE(review): assigning `trainable_weights` directly is a legacy
    # pattern; TF2 Keras rejects it — confirm the targeted Keras version.
    self.trainable_weights = [self.gamma, self.beta]
    if self.initial_weights is not None:
        # Restore weights captured before build (e.g. from a saved model).
        self.set_weights(self.initial_weights)
        del self.initial_weights
    super(Scale, self).build(input_shape)
def build(self, input_shape):
    """Create mu/sigma weight pairs for a noisy dense layer.

    The two initialization ranges correspond to the factorised and
    independent Gaussian-noise schemes (NoisyNet-style); sigma weights
    parameterize the learned noise scale.
    """
    assert len(input_shape) >= 2
    self.input_dim = input_shape[-1]
    # NOTE(review): `.value` implies `input_dim` is a TF1 `Dimension`;
    # under TF2, shape entries are plain ints and this would fail — confirm
    # the targeted TensorFlow version.
    if self.factorised:
        init_mu = 1. / np.sqrt(self.input_dim.value)
        init_sig = 0.5 / np.sqrt(self.input_dim.value)
    else:
        init_mu = np.sqrt(3. / self.input_dim.value)
        init_sig = 0.017
    self.kernel_mu = self.add_weight(shape=(self.input_dim, self.units),
                                     initializer=RandomUniform(
                                         -init_mu, init_mu),
                                     name='kernel_mu')
    self.kernel_sigma = self.add_weight(shape=(self.input_dim, self.units),
                                        initializer=RandomUniform(
                                            -init_sig, init_sig),
                                        name='kernel_sigma')
    if self.use_bias:
        self.bias_mu = self.add_weight(shape=(self.units, ),
                                       initializer=RandomUniform(
                                           -init_mu, init_mu),
                                       name='bias_mu')
        self.bias_sigma = self.add_weight(shape=(self.units, ),
                                          initializer=RandomUniform(
                                              -init_sig, init_sig),
                                          name='bias_sigma')
    self.input_spec = InputSpec(min_ndim=2, axes={-1: self.input_dim})
    self.built = True
def __init__(self, padding=(1, 1, 1), data_format=None, **kwargs):
    """3D reflection-padding layer.

    Args:
        padding: an int (symmetric padding on all three dims), a tuple of
            3 ints (symmetric per-dim padding), or a tuple of 3 tuples of
            2 ints (explicit left/right padding per dim).
        data_format: 'channels_first' or 'channels_last' (None = default).
        **kwargs: forwarded to the base layer.

    Raises:
        ValueError: if `padding` is malformed.
    """
    super(ReflectionPadding3D, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    if isinstance(padding, int):
        self.padding = ((padding, padding), (padding, padding),
                        (padding, padding))
    elif hasattr(padding, '__len__'):
        if len(padding) != 3:
            raise ValueError('`padding` should have 3 elements. '
                             'Found: ' + str(padding))
        # normalize_tuple accepts either an int or a 2-tuple per entry.
        dim1_padding = conv_utils.normalize_tuple(padding[0], 2,
                                                  '1st entry of padding')
        dim2_padding = conv_utils.normalize_tuple(padding[1], 2,
                                                  '2nd entry of padding')
        dim3_padding = conv_utils.normalize_tuple(padding[2], 2,
                                                  '3rd entry of padding')
        self.padding = (dim1_padding, dim2_padding, dim3_padding)
    else:
        # BUG FIX: the message previously read "right_dim2_pad" for the
        # third dimension; corrected to "right_dim3_pad".
        raise ValueError(
            '`padding` should be either an int, '
            'a tuple of 3 ints '
            '(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad), '
            'or a tuple of 3 tuples of 2 ints '
            '((left_dim1_pad, right_dim1_pad),'
            ' (left_dim2_pad, right_dim2_pad),'
            ' (left_dim3_pad, right_dim3_pad)). '
            'Found: ' + str(padding))
    self.input_spec = InputSpec(ndim=5)
def __init__(self,
             output_dim,
             data_format=None,
             activation=None,
             use_bias=True,
             kernel_initializer='glorot_uniform',
             bias_initializer='zeros',
             kernel_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             **kwargs):
    """Tensor-product (channel-wise dense) layer.

    Args:
        output_dim: number of output channels.
        data_format: 'channels_first' or 'channels_last' (None = default).
        activation: activation function (name or callable).
        use_bias: whether to add a bias vector.
        kernel_initializer / bias_initializer: weight initializers.
        kernel_regularizer / bias_regularizer / activity_regularizer:
            weight and output regularizers.
        kernel_constraint / bias_constraint: weight constraints.
        **kwargs: forwarded to the base layer; `input_dim` is accepted as
            a legacy alias for `input_shape=(input_dim,)`.
    """
    if 'input_shape' not in kwargs and 'input_dim' in kwargs:
        kwargs['input_shape'] = (kwargs.pop('input_dim'), )
    super(TensorProduct, self).__init__(
        activity_regularizer=regularizers.get(activity_regularizer),
        **kwargs)
    self.output_dim = int(output_dim)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    # Also stored on the instance (in addition to being passed to super).
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)
    self.supports_masking = True
    self.input_spec = InputSpec(min_ndim=2)
def __init__(self,
             input_dim,
             output_dim,
             data_format=None,
             activation=None,
             use_bias=True,
             kernel_initializer='glorot_uniform',
             bias_initializer='zeros',
             kernel_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             **kwargs):
    """2D tensor-product layer with explicit input/output channel counts.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        data_format: 'channels_first' or 'channels_last' (None = default).
        activation: activation function (name or callable).
        use_bias: whether to add a bias vector.
        kernel_initializer / bias_initializer: weight initializers.
        kernel_regularizer / bias_regularizer / activity_regularizer:
            weight and output regularizers.
        kernel_constraint / bias_constraint: weight constraints.
        **kwargs: forwarded to the base layer.
    """
    super(TensorProd2D, self).__init__(**kwargs)
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)
    self.input_spec = InputSpec(min_ndim=2)
def build(self, input_shape):
    """Build the residual block: conv stack, optional projection, biases.

    Builds each conv layer against the running intermediate shape, then
    creates scalar multiplier/bias weights used by the residual path.
    """
    self.init_layers(input_shape)
    intermediate_shape = input_shape
    with tf.name_scope("residual_basic_block_weights"):
        # Build the conv stack, threading the output shape of each layer
        # into the next one's build().
        for i in range(self.depth):
            self.conv_layers[i].build(intermediate_shape)
            intermediate_shape = self.conv_layers[i].compute_output_shape(
                intermediate_shape)
        if self.projection_layer is not None:
            # Projection shortcut maps the raw input to the residual shape.
            self.projection_layer.build(input_shape)
        # Scalar (shape=[]) learnable multiplier on the residual branch.
        self.residual_multiplier = self.add_weight(
            name="residual_multiplier", shape=[], dtype=backend.floatx(),
            initializer=tf.ones_initializer)
        if self.use_bias:
            # One scalar bias per conv layer, plus one per inter-layer
            # activation (depth - 1 of those), plus one residual bias.
            for i in range(self.depth):
                conv_bias = self.add_weight(
                    name="conv_bias", shape=[], dtype=backend.floatx(),
                    initializer=tf.zeros_initializer)
                self.conv_biases.append(conv_bias)
                if i < (self.depth - 1):
                    activation_bias = self.add_weight(
                        name="activation_bias", shape=[],
                        dtype=backend.floatx(),
                        initializer=tf.zeros_initializer)
                    self.activation_biases.append(activation_bias)
            self.residual_bias = self.add_weight(
                name="residual_bias", shape=[], dtype=backend.floatx(),
                initializer=tf.zeros_initializer)
    self.input_spec = InputSpec(
        ndim=self.rank + 2,
        axes={self.channel_axis: input_shape[self.channel_axis]})
    super(ResBasicBlockND, self).build(input_shape)
def build(self, input_shape):
    """Instantiate the block's sub-layers and pin the channel dimension."""
    self.init_layers()
    channels = input_shape[self.channel_axis]
    self.input_spec = InputSpec(ndim=self.rank + 2,
                                axes={self.channel_axis: channels})
    super(DenseBlockND, self).build(input_shape)
def __init__(self, numin, **kwargs):
    """Single-unit layer over `numin` inputs.

    Args:
        numin: number of input features.
        **kwargs: forwarded to the base layer. `input_dim` is accepted as a
            legacy alias for `input_shape=(input_dim,)`.
    """
    if "input_shape" not in kwargs and "input_dim" in kwargs:
        # BUG FIX: `(x)` is just `x`, not a tuple — Keras expects
        # `input_shape` to be a sequence, so wrap in a 1-tuple.
        kwargs["input_shape"] = (kwargs.pop("input_dim"),)
    super().__init__(**kwargs)
    self.numin = numin
    self.input_spec = InputSpec(ndim=2)
    self.units = 1
def build(self, input_shape):
    """Pin the size of the configured axis in the layer's input spec."""
    assert len(input_shape) >= 2
    axis_dim = input_shape[self.axis]
    self.input_spec = InputSpec(min_ndim=self.rank + 2,
                                axes={self.axis: axis_dim})
    self.built = True
def __init__(self, padding=(1, 1), data_format=None, **kwargs):
    """2D reflection-padding layer.

    Args:
        padding: an int (same symmetric padding on height and width), a
            tuple of 2 ints (symmetric per-dim padding), or a tuple of 2
            tuples of 2 ints (explicit top/bottom and left/right padding).
        data_format: 'channels_first' or 'channels_last' (None = default).
        **kwargs: forwarded to the base layer.

    Raises:
        ValueError: if `padding` is malformed.
    """
    super(ReflectionPadding2D, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    if isinstance(padding, int):
        self.padding = ((padding, padding), (padding, padding))
    elif hasattr(padding, '__len__'):
        if len(padding) != 2:
            raise ValueError('`padding` should have two elements. '
                             'Found: ' + str(padding))
        # Each entry may itself be an int or a 2-tuple.
        pad_h = conv_utils.normalize_tuple(padding[0], 2,
                                           '1st entry of padding')
        pad_w = conv_utils.normalize_tuple(padding[1], 2,
                                           '2nd entry of padding')
        self.padding = (pad_h, pad_w)
    else:
        raise ValueError('`padding` should be either an int, '
                         'a tuple of 2 ints '
                         '(symmetric_height_pad, symmetric_width_pad), '
                         'or a tuple of 2 tuples of 2 ints '
                         '((top_pad, bottom_pad), (left_pad, right_pad)). '
                         'Found: ' + str(padding))
    self.input_spec = InputSpec(ndim=4)
def build(self, input_shape=None):
    """Validate the (context, features) shapes and build the gate.

    Pins the last dimension of each input in the input spec, then builds
    the contextual weight generator.
    """
    context_shape, feature_shape = input_shape
    # Sanity checks on the pair of input shapes.
    self._build_sanity_check(context_shape, feature_shape)
    # Record static feature sizes and tighten the input specs.
    self.context_dim = tensor_shape.dimension_value(context_shape[-1])
    self.feature_dim = tensor_shape.dimension_value(feature_shape[-1])
    self.input_spec[0] = InputSpec(min_ndim=2,
                                   axes={-1: self.context_dim})
    self.input_spec[1] = InputSpec(min_ndim=2,
                                   axes={-1: self.feature_dim})
    # Build contextual weight generator.
    self.build_weight_generator(context_shape, feature_shape)
    self.built = True
def build(self, input_shape: tf.TensorShape):
    """Infer the spatial rank from the input and freeze the input spec.

    Accepts inputs of total rank 3, 4 or 5 (batch + 1/2/3 spatial dims +
    channels) and derives ``self.rank`` as the spatial rank when it was
    not set at construction time.

    Raises:
        ValueError: if the input's total rank is not 3, 4 or 5.
    """
    if self.rank is None:
        if input_shape.rank not in (3, 4, 5):
            # BUG FIX: the message previously claimed "Expected 1, 2 or 3"
            # while printing the *total* rank; make the message agree with
            # the actual check on the full input rank.
            raise ValueError(
                "Inputs' rank is invalid. Expected 3, 4 or 5. "
                "Found {}.".format(input_shape.rank))
        # Spatial rank excludes the batch and channel axes.
        self.rank = input_shape.rank - 2
    self.input_spec = InputSpec(shape=input_shape)
def build(self, input_shape):
    """Create the kernel and optional bias, pinning the channel axis.

    Raises:
        ValueError: if the channel dimension is not statically known.
    """
    channel_axis = 1 if self.data_format == 'channels_first' else -1
    if input_shape[channel_axis] is None:
        raise ValueError(
            'The channel dimension of the inputs should be defined. Found None'
        )
    input_dim = input_shape[channel_axis]
    # Kernel maps input channels to output channels.
    self.kernel = self.add_weight(shape=(input_dim, self.output_dim),
                                  initializer=self.kernel_initializer,
                                  name='kernel',
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
    if not self.use_bias:
        self.bias = None
    else:
        self.bias = self.add_weight(shape=(self.output_dim, ),
                                    initializer=self.bias_initializer,
                                    name='bias',
                                    regularizer=self.bias_regularizer,
                                    constraint=self.bias_constraint)
    # Set input spec.
    self.input_spec = InputSpec(min_ndim=2, axes={channel_axis: input_dim})
    self.built = True
def build(self, input_shape):
    """Create kernel/bias for a rank-4 tensor-product layer.

    Raises:
        ValueError: if the input is not rank 4 or the channel dimension
            is undefined.
    """
    input_shape = tensor_shape.TensorShape(input_shape)
    if len(input_shape) != 4:
        raise ValueError(
            'Inputs should have rank 4. Received input shape: ' +
            str(input_shape))
    if self.data_format == 'channels_first':
        channel_axis = 1
    else:
        channel_axis = -1
    # NOTE(review): `.value` implies TF1-style `Dimension` entries; under
    # TF2 shape entries are ints with no `.value` — confirm TF version.
    if input_shape[channel_axis].value is None:
        raise ValueError('The channel dimension of the inputs '
                         'should be defined. Found `None`.')
    input_dim = int(input_shape[channel_axis])
    kernel_shape = (input_dim, self.output_dim)
    self.kernel = self.add_weight(name='kernel',
                                  shape=kernel_shape,
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
    if self.use_bias:
        self.bias = self.add_weight(name='bias',
                                    shape=(self.output_dim, ),
                                    initializer=self.bias_initializer,
                                    regularizer=self.bias_regularizer,
                                    constraint=self.bias_constraint)
    else:
        self.bias = None
    # Set input spec.
    self.input_spec = InputSpec(min_ndim=2, axes={channel_axis: input_dim})
    self.built = True
def build(self, input_shape):
    """Create kernel/bias for `TensorProduct`, pinning the channel axis.

    Raises:
        ValueError: if the channel dimension is undefined.
    """
    if self.data_format == 'channels_first':
        channel_axis = 1
    else:
        channel_axis = -1
    input_shape = tensor_shape.TensorShape(input_shape)
    # NOTE(review): `.value` implies TF1-style `Dimension` entries; under
    # TF2 shape entries are ints with no `.value` — confirm TF version.
    if input_shape[channel_axis].value is None:
        raise ValueError('The channel dimension of the inputs to '
                         '`TensorProduct` should be defined. '
                         'Found `None`.')
    input_dim = int(input_shape[channel_axis])
    kernel_shape = (input_dim, self.output_dim)
    self.kernel = self.add_weight('kernel',
                                  shape=kernel_shape,
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint,
                                  dtype=self.dtype,
                                  trainable=True)
    if self.use_bias:
        self.bias = self.add_weight('bias',
                                    shape=(self.output_dim, ),
                                    initializer=self.bias_initializer,
                                    regularizer=self.bias_regularizer,
                                    constraint=self.bias_constraint,
                                    dtype=self.dtype,
                                    trainable=True)
    else:
        self.bias = None
    # Set input spec.
    self.input_spec = InputSpec(min_ndim=2, axes={channel_axis: input_dim})
    self.built = True
def __init__(self,
             rank,
             kernel_size,
             growth_rate,
             depth,
             output_filters=None,
             use_bottleneck=True,
             bottleneck_filters_multiplier=4,
             use_batch_normalization=True,
             data_format=None,
             activation="relu",
             use_bias=True,
             kernel_initializer="he_normal",
             bias_initializer="zeros",
             kernel_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             **kwargs):
    """N-dimensional DenseNet-style dense block.

    Args:
        rank: spatial rank, one of 1, 2 or 3.
        kernel_size: conv kernel size (int or per-dim tuple).
        growth_rate: channels added by each composite function.
        depth: number of layers; must be even when `use_bottleneck` is on,
            since each logical layer then uses two sub-layers.
        output_filters: optional output channel count for the transition.
        use_bottleneck: insert 1x1 bottleneck convolutions.
        bottleneck_filters_multiplier: bottleneck width = multiplier *
            growth_rate.
        use_batch_normalization: insert batch-norm in composite functions.
        data_format: 'channels_first' or 'channels_last' (None = default).
        activation: activation function (name or callable).
        use_bias: whether convs use a bias.
        kernel_initializer / bias_initializer: weight initializers.
        kernel_regularizer / bias_regularizer / activity_regularizer:
            weight and output regularizers.
        kernel_constraint / bias_constraint: weight constraints.
        **kwargs: forwarded to the base layer.
    """
    assert rank in [1, 2, 3]
    super(DenseBlockND, self).__init__(**kwargs)
    self.rank = rank
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank,
                                                  "kernel_size")
    self.output_filters = output_filters
    self.growth_rate = growth_rate
    if use_bottleneck:
        assert (
            depth % 2
        ) == 0, "Depth must be a multiple of 2 when using bottlenecks."
    # With bottlenecks, each logical layer consumes two sub-layers.
    self._depth = depth // 2 if use_bottleneck else depth
    self.use_bottleneck = use_bottleneck
    self.bottleneck_filters_multiplier = bottleneck_filters_multiplier
    self.use_batch_normalization = use_batch_normalization
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.channel_axis = -1 if self.data_format == "channels_last" else 1
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)
    # Sub-layers are created later (in build / init_layers).
    self.composite_function_blocks: Optional[
        List[CompositeFunctionBlock]] = None
    self.transition_layer = None
    self.input_spec = InputSpec(ndim=self.rank + 2)
def build(self, input_shape):
    """Add a projection shortcut if needed, then finish building."""
    if self.use_projection(input_shape):
        # Input and residual shapes differ: project the shortcut path.
        self.init_projection_layer()
    channels = input_shape[self.channel_axis]
    self.input_spec = InputSpec(ndim=self.rank + 2,
                                axes={self.channel_axis: channels})
    super(ResBasicBlockND, self).build(input_shape)
def __init__(self, n_clusters, weights=None, alpha=1.0, **kwargs):
    """Clustering layer holding `n_clusters` learnable centroids.

    Args:
        n_clusters: number of cluster centroids.
        weights: optional initial centroid weights, applied at build time.
        alpha: degrees-of-freedom parameter of the Student's t kernel.
        **kwargs: forwarded to the base layer; `input_dim` is accepted as
            a legacy alias for `input_shape=(input_dim,)`.
    """
    if 'input_shape' not in kwargs and 'input_dim' in kwargs:
        kwargs['input_shape'] = (kwargs.pop('input_dim'),)
    super(ClusteringLayer, self).__init__(**kwargs)
    self.n_clusters = n_clusters
    self.alpha = alpha
    # Stored until build(); centroids are created there.
    self.initial_weights = weights
    self.input_spec = InputSpec(ndim=2)
def __init__(self, size=(2, 2), data_format=None, **kwargs):
    """2D bilinear upsampling layer.

    Args:
        size: upsampling factors for (height, width).
        data_format: 'channels_first' or 'channels_last'; None selects the
            Keras backend default.
        **kwargs: forwarded to the base layer.
    """
    super(BilinearUpSampling2D, self).__init__(**kwargs)
    self.data_format = (K.image_data_format() if data_format is None
                        else conv_utils.normalize_data_format(data_format))
    self.size = conv_utils.normalize_tuple(size, 2, 'size')
    self.input_spec = InputSpec(ndim=4)
def __init__(self,
             units,
             chunk_size,
             activation='tanh',
             recurrent_activation='hard_sigmoid',
             use_bias=True,
             kernel_initializer='glorot_uniform',
             recurrent_initializer='orthogonal',
             bias_initializer='zeros',
             unit_forget_bias=True,
             kernel_regularizer=None,
             recurrent_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             recurrent_constraint=None,
             bias_constraint=None,
             dropout=0.,
             recurrent_dropout=0.,
             implementation=1,
             return_sequences=False,
             return_state=False,
             go_backwards=False,
             stateful=False,
             unroll=False,
             **kwargs):
    """Ordered-neuron LSTM (ON-LSTM) RNN wrapper.

    Builds an `OrderedNeuronLSTMCell` from the cell-level arguments and
    wires it into the base RNN with the sequence-level arguments.

    Args:
        units: dimensionality of the output/state.
        chunk_size: ON-LSTM chunk size for the master gates.
        (remaining arguments follow the standard Keras LSTM signature.)
    """
    cell = OrderedNeuronLSTMCell(
        units=units,
        chunk_size=chunk_size,
        activation=activation,
        recurrent_activation=recurrent_activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        recurrent_initializer=recurrent_initializer,
        unit_forget_bias=unit_forget_bias,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        recurrent_regularizer=recurrent_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        recurrent_constraint=recurrent_constraint,
        bias_constraint=bias_constraint,
        dropout=dropout,
        recurrent_dropout=recurrent_dropout,
        implementation=implementation,
        # dtype/trainable are shared between the cell and the wrapper.
        dtype=kwargs.get('dtype'),
        trainable=kwargs.get('trainable', True))
    super(OrderedNeuronLSTM, self).__init__(cell,
                                            return_sequences=return_sequences,
                                            return_state=return_state,
                                            go_backwards=go_backwards,
                                            stateful=stateful,
                                            unroll=unroll,
                                            **kwargs)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    # Inputs are (batch, time, features).
    self.input_spec = [InputSpec(ndim=3)]
def __init__(self,
             rank: int,
             filters: int,
             depth: int,
             kernel_size: Union[int, Tuple, List],
             strides: Union[int, Tuple, List],
             data_format: Union[None, AnyStr],
             dilation_rate: Union[int, Tuple, List],
             activation: Union[None, AnyStr, Callable],
             use_residual_bias: bool,
             use_conv_bias: bool,
             use_batch_norm: bool,
             kernel_initializer: Union[Dict, AnyStr, Callable],
             bias_initializer: Union[Dict, AnyStr, Callable],
             kernel_regularizer: Union[None, Dict, AnyStr, Callable],
             bias_regularizer: Union[None, Dict, AnyStr, Callable],
             activity_regularizer: Union[None, Dict, AnyStr, Callable],
             kernel_constraint: Union[None, Dict, AnyStr, Callable],
             bias_constraint: Union[None, Dict, AnyStr, Callable],
             **kwargs):
    """N-dimensional residual basic block.

    Args:
        rank: spatial rank, one of 1, 2 or 3.
        filters: number of output channels of each conv layer.
        depth: number of conv layers in the block (> 0).
        kernel_size / strides / dilation_rate: per-dim conv parameters.
        data_format: 'channels_first' or 'channels_last' (None = default).
        activation: activation function between conv layers.
        use_residual_bias: add a scalar bias on the residual path.
        use_conv_bias: add scalar biases on the conv path.
        use_batch_norm: insert batch normalization layers.
        kernel_initializer / bias_initializer: weight initializers.
        kernel_regularizer / bias_regularizer / activity_regularizer:
            weight and output regularizers.
        kernel_constraint / bias_constraint: weight constraints.
        **kwargs: forwarded to the base layer.
    """
    assert rank in [1, 2, 3]
    assert depth > 0
    super(ResBasicBlockND, self).__init__(**kwargs)
    self.rank = rank
    self.filters = filters
    self.depth = depth
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank,
                                                  "kernel_size")
    self.strides = conv_utils.normalize_tuple(strides, rank, "strides")
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank,
                                                    "dilation_rate")
    self.activation = activations.get(activation)
    self.use_residual_bias = use_residual_bias
    # BUG FIX: `self.use_conv_bias = use_conv_bias` appeared twice in the
    # original; the duplicate assignment was removed.
    self.use_conv_bias = use_conv_bias
    self.use_batch_norm = use_batch_norm
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)
    # Sub-layer containers, populated by init_layers()/build().
    self.conv_layers: List[Layer] = []
    self.projection_layer: Optional[Layer] = None
    self.batch_norm_layers: Optional[List[BatchNormalization]] = None
    self.residual_multiplier = None
    self.conv_biases = []
    self.activation_biases = []
    self.residual_bias = None
    self.input_spec = InputSpec(ndim=self.rank + 2)
    self.init_layers()
def build(self, input_shape):
    """Create the attention weight vector `w` over the feature axis.

    Expects rank-3 input (batch, time, features); `w` has shape
    (features, 1).
    """
    self.input_spec = [InputSpec(ndim=3)]
    assert len(input_shape) == 3
    self.w = self.add_weight(shape=(input_shape[2], 1),
                             name='{}_w'.format(self.name),
                             initializer=self.init)
    # NOTE(review): `add_weight` already registers `w`; directly assigning
    # `trainable_weights` is a legacy Keras 1.x pattern that TF2 rejects —
    # confirm the targeted Keras version.
    self.trainable_weights = [self.w]
    super().build(input_shape)
def build(self, input_shape):
    """Create per-class `gamma`/`beta` parameters.

    Weights have shape (nb_classes, 1, 1, channels) so one scale/shift
    pair exists per class, broadcast over spatial dims.
    """
    # This currently only works for 4D inputs: assuming (B, H, W, C)
    self.input_spec = [InputSpec(shape=input_shape)]
    shape = (self.nb_classes, 1, 1, input_shape[-1])
    # NOTE(review): weights are created by calling the initializers
    # directly and registered via `trainable_weights` (legacy Keras 1.x
    # style) — TF2 Keras forbids this; confirm the targeted version.
    self.gamma = self.gamma_init(shape, name='{}_gamma'.format(self.name))
    self.beta = self.beta_init(shape, name='{}_beta'.format(self.name))
    self.trainable_weights = [self.gamma, self.beta]
    self.built = True
def build(self, input_shape):
    """
    Method for creating the layer weights.

    :param input_shape: Keras tensor (future input to layer)
        or list/tuple of Keras tensors to reference
        for weight shape computations
    """
    assert input_shape is not None and len(input_shape) >= 2
    input_dimension = int(input_shape[-1])
    # Initialize expert weights (number of input features * number of units per expert * number of experts)
    self.expert_kernels = self.add_weight(
        name='expert_kernel',
        shape=(input_dimension, self.units, self.num_experts),
        initializer=self.expert_kernel_initializer,
        regularizer=self.expert_kernel_regularizer,
        constraint=self.expert_kernel_constraint,
    )
    # Initialize expert bias (number of units per expert * number of experts)
    if self.use_expert_bias:
        self.expert_bias = self.add_weight(
            name='expert_bias',
            shape=(self.units, self.num_experts),
            initializer=self.expert_bias_initializer,
            regularizer=self.expert_bias_regularizer,
            constraint=self.expert_bias_constraint,
        )
    # Initialize gate weights (number of input features * number of experts * number of tasks)
    # One gate network per task; each produces a distribution over experts.
    self.gate_kernels = [
        self.add_weight(name='gate_kernel_task_{}'.format(i),
                        shape=(input_dimension, self.num_experts),
                        initializer=self.gate_kernel_initializer,
                        regularizer=self.gate_kernel_regularizer,
                        constraint=self.gate_kernel_constraint)
        for i in range(self.num_tasks)
    ]
    # Initialize gate bias (number of experts * number of tasks)
    if self.use_gate_bias:
        self.gate_bias = [
            self.add_weight(name='gate_bias_task_{}'.format(i),
                            shape=(self.num_experts, ),
                            initializer=self.gate_bias_initializer,
                            regularizer=self.gate_bias_regularizer,
                            constraint=self.gate_bias_constraint)
            for i in range(self.num_tasks)
        ]
    # Pin the feature dimension in the input spec.
    self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dimension})
    super(MMoE, self).build(input_shape)