Example #1
 def __init__(self, n_clusters, weights=None, alpha=1.0, **kwargs):
     if 'input_shape' not in kwargs and 'input_dim' in kwargs:
         kwargs['input_shape'] = (kwargs.pop('input_dim'), )
     super(ClusteringLayer, self).__init__(**kwargs)
     self.n_clusters = n_clusters
     self.alpha = alpha
     self.initial_weights = weights
     self.input_spec = InputSpec(ndim=2)
     # ndim: integer, expected rank of the input. InputSpec declares what the layer
     # expects of its input (e.g. the ndim, dtype and shape of every input), so Keras
     # can check incoming tensors against it.
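A note on the InputSpec(ndim=2) line above: InputSpec tells Keras what the layer expects of its inputs, and every call is checked against that spec. A minimal sketch, assuming standalone tf.keras (RankCheckedLayer is a hypothetical layer used only for illustration):

    import tensorflow as tf
    from tensorflow.keras.layers import Layer, InputSpec

    class RankCheckedLayer(Layer):
        """Accepts only rank-2 (batch, features) inputs."""

        def __init__(self, **kwargs):
            super(RankCheckedLayer, self).__init__(**kwargs)
            self.input_spec = InputSpec(ndim=2)

        def call(self, inputs):
            return inputs  # identity; only the input check matters here

    layer = RankCheckedLayer()
    layer(tf.zeros((4, 10)))       # rank 2: accepted
    # layer(tf.zeros((4, 10, 3)))  # rank 3: raises an incompatibility error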
Example #2
 def build(self, input_shape):
     assert len(input_shape) == 2
     input_dim = input_shape[1]
     self.input_spec = InputSpec(dtype=K.floatx(), shape=(None, input_dim))
     self.clusters = self.add_weight(shape=(self.n_clusters, input_dim),
                                     initializer='glorot_uniform', name='clusters')
     if self.initial_weights is not None:
         self.set_weights(self.initial_weights)
         del self.initial_weights
     self.built = True
Example #3
    def build(self, input_shape):
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]

        self.bias = self.add_weight(shape=(1, ),
                                    initializer=self.bias_initializer,
                                    name='bias',
                                    trainable=True)
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True
Example #4
File: util.py Project: sdfs1231/P4
 def build(self, input_shape):
     dim = input_shape[self.axis]
     if dim is None:
         raise ValueError('Axis ' + str(self.axis) + ' of '
                          'input tensor should have a defined dimension '
                          'but the layer received an input with shape ' +
                          str(input_shape) + '.')
     self.input_spec = InputSpec(ndim=len(input_shape),
                                 axes={self.axis: dim})
     self.built = True
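The build() above not only fixes the rank but also pins the size of self.axis via axes={self.axis: dim}, so later calls with a different size on that axis are rejected. A sketch of the same pattern, assuming tf.keras (AxisLockedLayer is a hypothetical name used only for illustration):

    import tensorflow as tf
    from tensorflow.keras.layers import Layer, InputSpec

    class AxisLockedLayer(Layer):
        """After the first call, the size of `axis` is frozen by input_spec."""

        def __init__(self, axis=-1, **kwargs):
            super(AxisLockedLayer, self).__init__(**kwargs)
            self.axis = axis

        def build(self, input_shape):
            dim = input_shape[self.axis]
            if dim is None:
                raise ValueError('Axis %d of the input must have a defined dimension.' % self.axis)
            self.input_spec = InputSpec(ndim=len(input_shape), axes={self.axis: dim})
            self.built = True

        def call(self, inputs):
            return inputs

    layer = AxisLockedLayer(axis=-1)
    layer(tf.zeros((2, 8)))      # builds; last axis locked to 8
    # layer(tf.zeros((2, 9)))    # now incompatible: axis -1 is expected to be 8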
Example #5
    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]

        # gamma is the learnable per-channel scale, initialized to self.scale
        self.gamma = self.add_weight(name='{}'.format(self.name),
                                     shape=(input_shape[self.channel_axis], ),
                                     initializer=Constant(self.scale),
                                     trainable=True)

        super(L2Norm, self).build(input_shape)
Example #6
 def build(self, input_shape):  # weight shape = (n_clusters, input_dim) = (20, 10)
     assert len(input_shape) == 2  # input_dim = 10 (encoder output dim, set in the autoencoder)
     input_dim = input_shape[1]
     self.input_spec = InputSpec(dtype=K.floatx(), shape=(None, input_dim))
     self.clusters = self.add_weight((self.n_clusters, input_dim),
                                     initializer='glorot_uniform',
                                     name='clusters')
     self.built = True
Example #7
    def build(self, input_shape):
        """
        Method for creating the layer weights.

        :param input_shape: Keras tensor (future input to layer)
                            or list/tuple of Keras tensors to reference
                            for weight shape computations
        """
        assert input_shape is not None and len(input_shape) >= 2

        input_dimension = input_shape[-1]

        # Initialize expert weights (number of input features * number of units per expert * number of experts)
        self.expert_kernels = self.add_weight(
            name='expert_kernel',
            shape=(input_dimension, self.units, self.num_experts),
            initializer=self.expert_kernel_initializer,
            regularizer=self.expert_kernel_regularizer,
            constraint=self.expert_kernel_constraint,
        )

        # Initialize expert bias (number of units per expert * number of experts)
        if self.use_expert_bias:
            self.expert_bias = self.add_weight(
                name='expert_bias',
                shape=(self.units, self.num_experts),
                initializer=self.expert_bias_initializer,
                regularizer=self.expert_bias_regularizer,
                constraint=self.expert_bias_constraint,
            )

        # Initialize gate weights (number of input features * number of experts * number of tasks)
        self.gate_kernels = [
            self.add_weight(name='gate_kernel_task_{}'.format(i),
                            shape=(input_dimension, self.num_experts),
                            initializer=self.gate_kernel_initializer,
                            regularizer=self.gate_kernel_regularizer,
                            constraint=self.gate_kernel_constraint)
            for i in range(self.num_tasks)
        ]

        # Initialize gate bias (number of experts * number of tasks)
        if self.use_gate_bias:
            self.gate_bias = [
                self.add_weight(name='gate_bias_task_{}'.format(i),
                                shape=(self.num_experts, ),
                                initializer=self.gate_bias_initializer,
                                regularizer=self.gate_bias_regularizer,
                                constraint=self.gate_bias_constraint)
                for i in range(self.num_tasks)
            ]

        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dimension})

        super(MMoE, self).build(input_shape)
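To make the weight shapes described in the comments above concrete, here is a worked example with hypothetical sizes (the numbers are illustrative, not taken from the source):

    # Suppose input_dimension = 16, units = 4, num_experts = 8, num_tasks = 2. Then:
    #   expert_kernels : shape (16, 4, 8)  -- one (input_dim x units) kernel per expert
    #   expert_bias    : shape (4, 8)      -- one bias vector per expert
    #   gate_kernels   : 2 tensors of shape (16, 8), one gate per task
    #   gate_bias      : 2 vectors of shape (8,), one per task
    #   input_spec     : InputSpec(min_ndim=2, axes={-1: 16})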
Example #8
    def build(self, input_shape):
        bs, input_length, input_dim = input_shape

        self.controller_input_dim, self.controller_output_dim = controller_input_output_shape(
            input_dim, self.units, self.m_depth, self.n_slots,
            self.shift_range, self.read_heads, self.write_heads)

        # Now that we've calculated the shape of the controller, we have to add it to the layer/model.
        if self.controller is None:
            self.controller = Dense(name="controller",
                                    activation='linear',
                                    bias_initializer='zeros',
                                    units=self.controller_output_dim,
                                    input_shape=(bs, input_length,
                                                 self.controller_input_dim))
            self.controller.build(input_shape=(self.batch_size, input_length,
                                               self.controller_input_dim))
            self.controller_with_state = False

        # This is a fixed shift matrix
        self.C = _circulant(self.n_slots, self.shift_range)

        self.trainable_weights = self.controller.trainable_weights

        # We need to declare the number of states we want to carry around.
        # In our case the dimension seems to be 6 (LSTM) or 5 (GRU) or 4 (FF),
        # see self.get_initial_states; these correspond to:
        # [old_ntm_output] + [init_M, init_wr, init_ww] + [init_h] (LSTM and GRU) + [init_c] (LSTM only)
        # old_ntm_output does not make sense in our world, but is required by the
        # definition of the step function we intend to use.
        # WARNING: What self.state_spec does is only poorly understood,
        # I only copied it from keras/recurrent.py.
        self.states = [None, None, None, None]
        self.state_spec = [
            InputSpec(shape=(None, self.output_dim)),  # old_ntm_output
            InputSpec(shape=(None, self.n_slots, self.m_depth)),  # Memory
            InputSpec(shape=(None, self.read_heads,
                             self.n_slots)),  # weights_read
            InputSpec(shape=(None, self.write_heads, self.n_slots))
        ]  # weights_write

        super(NeuralTuringMachine, self).build(input_shape)
Example #9
    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = (int(input_shape[self.axis]),)

        self.gamma = K.variable(self.gamma_init(shape), name='%s_gamma' % self.name)
        self.beta = K.variable(self.beta_init(shape), name='%s_beta' % self.name)
        self.trainable_weights = [self.gamma, self.beta]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
Example #10
    def _build(self, input_shape):
        if not self.built:
            if input_shape[0][1] != input_shape[1][1]:
                raise ValueError("The number of capsules must be the same in diss and signals. You provide "
                                 + str(input_shape[0][1]) + "!=" + str(input_shape[1][1]) + ". Maybe you forgot the "
                                 "calling of a routing/ competition module.")

            if input_shape[0][1] != self.capsule_number:
                raise ValueError("The defined number of capsules is not equal the number of capsules in signals. " +
                                 "You provide: " + str(input_shape[0][1]) + "!=" + str(self.capsule_number) +
                                 ". Maybe you forgot the calling of a competition module.")

            self.beta = self.add_weight(shape=(1,),
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint,
                                        name='beta')

            self.input_spec = [InputSpec(shape=(None,) + tuple(input_shape[0][1:])),
                               InputSpec(shape=(None,) + tuple(input_shape[1][1:]))]
Example #11
 def build(self, input_shape):
     assert (len(input_shape) == 2)
     input_dim = input_shape[1]
     self.input_spec = InputSpec(dtype=tf.float32, shape=(None, input_dim))
     self.prototypes = self.add_weight(shape=(self.n_prototypes, input_dim),
                                       initializer='glorot_uniform',
                                       name='prototypes')
     if self.initial_prototypes is not None:
         self.set_weights(self.initial_prototypes)
         del self.initial_prototypes
     self.built = True
Example #12
    def _build(self, input_shape):
        if not self.built:
            if input_shape[0][1] != input_shape[1][1]:
                raise ValueError(
                    "The number of capsules must be equal to the number of prototypes. Necessary "
                    "assumption for Gibbs Routing. You provided " +
                    str(input_shape[0][1]) + "!=" + str(input_shape[1][1]) +
                    ". Maybe you forgot to call a measuring module.")

            # add additional dimension to use broadcasting
            self.beta = self.add_weight(shape=(input_shape[0][1], 1),
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint,
                                        name='beta')

            self.input_spec = [
                InputSpec(shape=(None, ) + tuple(input_shape[0][1:])),
                InputSpec(shape=(None, ) + tuple(input_shape[1][1:]))
            ]
Example #13
    def __init__(self, output_dim, input_dim=None, weights=None, alpha=1.0, **kwargs):
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.alpha = alpha
        # kmeans cluster centre locations
        self.initial_weights = weights
        self.input_spec = [InputSpec(ndim=2)]

        if self.input_dim:
            kwargs['input_shape'] = (self.input_dim,)
        super(ClusteringLayer, self).__init__(**kwargs)
Example #14
    def build(self, input_shape):
        assert len(input_shape) >= 3
        input_dim = input_shape[-1]

        self.kernel = self.add_weight(shape=(1, input_dim),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True
Example #15
 def __init__(self, n_clusters, weights=None, alpha=1.0, dist_metric='eucl', **kwargs):
     if 'input_shape' not in kwargs and 'input_dim' in kwargs:
         kwargs['input_shape'] = (kwargs.pop('input_dim'),)
     super(TSClusteringLayer, self).__init__(**kwargs)
     self.n_clusters = n_clusters
     self.alpha = alpha
     self.dist_metric = dist_metric
     self.initial_weights = weights
     self.input_spec = InputSpec(ndim=3)
     self.clusters = None
     self.built = False
Example #16
 def __init__(self,
              n_filters,
              n_experts_per_filter,
              kernel_size,
              strides=(1, 1, 1),
              padding='valid',
              data_format='channels_last',
              dilation_rate=(1, 1, 1),
              expert_activation=None,
              gating_activation=None,
              use_expert_bias=True,
              use_gating_bias=True,
              expert_kernel_initializer_scale=1.0,
              gating_kernel_initializer_scale=1.0,
              expert_bias_initializer='zeros',
              gating_bias_initializer='zeros',
              expert_kernel_regularizer=None,
              gating_kernel_regularizer=None,
              expert_bias_regularizer=None,
              gating_bias_regularizer=None,
              expert_kernel_constraint=None,
              gating_kernel_constraint=None,
              expert_bias_constraint=None,
              gating_bias_constraint=None,
              activity_regularizer=None,
              **kwargs):
     super(Conv3DMoE, self).__init__(
         rank=3,
         n_filters=n_filters,
         n_experts_per_filter=n_experts_per_filter,
         kernel_size=kernel_size,
         strides=strides,
         padding=padding,
         data_format=data_format,
         dilation_rate=dilation_rate,
         expert_activation=expert_activation,
         gating_activation=gating_activation,
         use_expert_bias=use_expert_bias,
         use_gating_bias=use_gating_bias,
         expert_kernel_initializer_scale=expert_kernel_initializer_scale,
         gating_kernel_initializer_scale=gating_kernel_initializer_scale,
         expert_bias_initializer=expert_bias_initializer,
         gating_bias_initializer=gating_bias_initializer,
         expert_kernel_regularizer=expert_kernel_regularizer,
         gating_kernel_regularizer=gating_kernel_regularizer,
         expert_bias_regularizer=expert_bias_regularizer,
         gating_bias_regularizer=gating_bias_regularizer,
         expert_kernel_constraint=expert_kernel_constraint,
         gating_kernel_constraint=gating_kernel_constraint,
         expert_bias_constraint=expert_bias_constraint,
         gating_bias_constraint=gating_bias_constraint,
         activity_regularizer=activity_regularizer,
         **kwargs)
     self.input_spec = InputSpec(ndim=5)
Example #17
 def __init__(self, size=(1, 1), target_size=None, data_format='default', **kwargs):
     if data_format == 'default':
         data_format = K.image_data_format()
     self.size = tuple(size)
     if target_size is not None:
         self.target_size = tuple(target_size)
     else:
         self.target_size = None
     assert data_format in {'channels_last', 'channels_first'}, 'data_format must be one of {channels_last, channels_first}'
     self.data_format = data_format
     self.input_spec = [InputSpec(ndim=4)]
     super(BilinearUpSampling2D, self).__init__(**kwargs)
Example #18
 def __init__(self, type=2, n=None, axis=-2, norm=None, rank=1, data_format='channels_last',**kwargs):
     super(DCT1D, self).__init__(**kwargs)
     self.rank = rank
     self.type = type
     self.n = n
     self.axis = axis
     self.norm = norm
     self.data_format = conv_utils.normalize_data_format(data_format)
     self.input_spec = InputSpec(ndim=self.rank + 2)
     if norm is not None:
         if norm != 'ortho':
             raise ValueError('Normalization should be `ortho` or `None`')
Example #19
    def __init__(self, config, weights=None, **kwargs):
        self.output_dim = config.trainer.n_clusters
        self.input_dim = config.model.input_shape
        self.alpha = config.model.alpha
        self.initial_weights = weights
        self.input_spec = [InputSpec()]

        if self.input_dim:
            kwargs['input_shape'] = (self.input_dim, )
        super(ClusteringLayer_temporal,
              self).__init__(batch_size=config.data_loader.batch_size.train,
                             input_shape=config.model.input_shape)
Example #20
 def build(self, input_shape):
     # the input is 2-D, so the shape has length 2; assert raises an error if the condition is false
     assert len(input_shape) == 2
     input_dim = input_shape[1]  # first dim is the number of samples, second is the feature dimension
     self.input_spec = InputSpec(dtype=K.floatx(), shape=(None, input_dim))
     self.clusters = self.add_weight((self.n_clusters, input_dim),
                                     initializer='glorot_uniform',
                                     name='clusters')  # first dim is the number of clusters, second the feature dimension
     if self.initial_weights is not None:
         self.set_weights(self.initial_weights)  # if initial weights were passed in, use them
         del self.initial_weights
     self.built = True
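The 'clusters' weight above holds the cluster centres. In DEC-style clustering layers like this one, the companion call() (not shown in this example) typically computes soft assignments with a Student's t kernel; a sketch following the widely used DEC-Keras formulation, assuming `K` is the Keras backend as in the snippet above:

 def call(self, inputs, **kwargs):
     # q_ij: soft assignment of sample i to cluster j (Student's t kernel, as in DEC)
     q = 1.0 / (1.0 + (K.sum(K.square(K.expand_dims(inputs, axis=1) - self.clusters),
                             axis=2) / self.alpha))
     q **= (self.alpha + 1.0) / 2.0
     q = K.transpose(K.transpose(q) / K.sum(q, axis=1))  # normalize each row to sum to 1
     return q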
Example #21
    def build(self, input_shape):
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        self.kernel_shape = self.kernel_size + (input_dim, self.filters)
        a, b, c, d = self.kernel_shape
        if a != b:
            raise ValueError('kernel height and width are not equal')

        # self.kernel_initializer = RandomUniform(minval=-2, maxval=2, seed=None)
        # self.kernel_initializer = RandomNormal(0.0, 1.0)

        # handle the source-feature and target-feature kernels separately
        self.kernel_shape = list(self.kernel_shape)
        self.kernel_shape_addition = self.kernel_shape[:]
        self.kernel_shape[2] = self.kernel_shape[2] - self.num_source
        self.kernel_shape_addition[2] = self.num_source
        self.kernel_addition = self.add_weight(shape=self.kernel_shape_addition,
                                               initializer=initializers.get('zero'),
                                               name='kernel_source',
                                               regularizer=self.kernel_regularizer,
                                               constraint=self.kernel_constraint)
        self.kernel = self.add_weight(shape=self.kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Set input spec.
        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: input_dim})
        self.built = True
Example #22
 def build(self, input_shape):  # which parameters of this layer are trainable?
     assert len(input_shape) == 2
     input_dim = input_shape[1]
     self.input_spec = InputSpec(dtype=K.floatx(), shape=(None, input_dim))
     # adds a trainable weight named 'clusters' with shape (self.n_clusters, input_dim)
     self.clusters = self.add_weight(
         (self.n_clusters, input_dim),
         initializer='glorot_uniform',
         name='clusters'
     )
     if self.initial_weights is not None:
         self.set_weights(self.initial_weights)
         del self.initial_weights
Example #23
    def __init__(
            self,
            width,
            activation=None,
            use_bias=True,
            kernel_initializer="glorot_uniform",
            bias_initializer="zeros",
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            # "single": only one weight applied to all neighbor sums
            # "all": a different weight for each property
            conv_wts="single",
            **kwargs):

        if "input_shape" not in kwargs and "input_dim" in kwargs:
            kwargs["input_shape"] = (kwargs.pop("input_dim"), )

        allowed_conv_wts = ("all", "single")
        if conv_wts not in allowed_conv_wts:
            raise ValueError("conv_wt should be one of %r" % allowed_conv_wts)

        super(GraphConv, self).__init__(**kwargs)

        self.width = width
        self.conv_wts = conv_wts

        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = [InputSpec(ndim=3), InputSpec(ndim=3)]
Example #24
 def __init__(self, left, input_dim, query_dim=None, output_dim=None,
              right=0, merge="concatenate",  use_bias=False, **kwargs):
     if 'input_shape' not in kwargs:
         kwargs['input_shape'] = (None, None, input_dim)
     super(AttentionCell, self).__init__(**kwargs)
     self.left = left
     self.right = right
     self.merge = merge
     self.use_bias = use_bias
     self.input_dim = input_dim
     self.query_dim = query_dim or self.input_dim
     self.output_dim = output_dim or self.input_dim
     self.input_spec = InputSpec(shape=(None, None, input_dim))
Example #25
 def __init__(self,filters=1,
              kernel_size=(3,3),
              strides=(1,1),
              data_format='channels_last',
              operation = 'm',
              **kwargs):
     super(erode, self).__init__(**kwargs)
     self.filters = filters
     self.kernel_size = kernel_size
     self.strides = strides
     self.data_format = data_format
     self.operation = operation
     self.input_spec = InputSpec(ndim=4)
Example #26
    def build(self, input_shape):
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]

        self.gate_kernel = self.add_weight(
            shape=(input_dim, input_dim), initializer='uniform', name='gate_kernel')
        self.gate_bias = self.add_weight(
            shape=(input_dim,), initializer=self.bias_initializer, name='gate_bias')
        self.dense_kernel = self.add_weight(
            shape=(input_dim, input_dim), initializer='uniform', name='dense_kernel')
        self.dense_bias = self.add_weight(
            shape=(input_dim,), initializer=self.bias_initializer, name='dense_bias')
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True
Example #27
 def build(self, input_shape):
     if len(input_shape) > 2:
         raise ValueError("Input to attention layer hasn't been flattened")
     self.input_dim = input_shape[-1]
     self.kernel = self.add_weight(
         shape=(self.input_dim, ),
         initializer=initializers.Ones(),
         name='kernel',
         constraint=constraints.NonNeg()
         #constraint=constraints.min_max_norm(min_value=0.0, max_value=1.0)
         #constraint=constraints.UnitNorm(axis=self.axis)
     )
     self.input_spec = InputSpec(min_ndim=2, axes={-1: self.input_dim})
     self.built = True
Example #28
 def __init__(self,
              size=(2, 2),
              num_pixels=(0, 0),
              data_format='channels_last',
              method_name='FgSegNet_M',
              **kwargs):
     super(MyUpSampling2D, self).__init__(**kwargs)
     self.data_format = conv_utils.normalize_data_format(data_format)
     self.size = conv_utils.normalize_tuple(size, 2, 'size')
     self.input_spec = InputSpec(ndim=4)
     self.num_pixels = num_pixels
     self.method_name = method_name
     assert method_name in ['FgSegNet_M', 'FgSegNet_S', 'FgSegNet_v2'
                            ], 'Provided method_name is incorrect.'
Example #29
    def __init__(self,
                 filters=1,
                 kernel_size=80,
                 rank=1,
                 strides=1,
                 padding='valid',
                 data_format='channels_last',
                 dilation_rate=1,
                 activation=None,
                 use_bias=True,
                 fsHz=1000.,
                 fc_initializer=initializers.RandomUniform(minval=10,
                                                           maxval=4000),
                 n_order_initializer=initializers.constant(4.),
                 amp_initializer=initializers.constant(10**5),
                 beta_initializer=initializers.RandomNormal(mean=30, stddev=6),
                 bias_initializer='zeros',
                 **kwargs):
        super(Conv1D_gammatone_coeff, self).__init__(**kwargs)
        self.rank = rank
        self.filters = filters
        self.kernel_size_ = kernel_size
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank,
                                                      'kernel_size')
        self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = normalize_data_format(data_format)
        self.dilation_rate = conv_utils.normalize_tuple(
            dilation_rate, rank, 'dilation_rate')
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.bias_initializer = initializers.get(bias_initializer)
        self.fc_initializer = initializers.get(fc_initializer)
        self.n_order_initializer = initializers.get(n_order_initializer)
        self.amp_initializer = initializers.get(amp_initializer)
        self.beta_initializer = initializers.get(beta_initializer)
        self.input_spec = InputSpec(ndim=self.rank + 2)

        self.fc = self.fc_initializer.__call__((self.filters, 1))
        self.n_order = self.n_order_initializer((1, 1))
        self.amp = self.amp_initializer((self.filters, 1))
        self.beta = self.beta_initializer((self.filters, 1))

        self.fsHz = fsHz
        self.t = tf.range(start=0,
                          limit=kernel_size / float(fsHz),
                          delta=1 / float(fsHz),
                          dtype=K.floatx())
        self.t = tf.expand_dims(input=self.t, axis=-1)
Example #30
 def __init__(self,
              first_threshold=None,
              second_threshold=None,
              use_dimension_bias=False,
              use_intermediate_layer=False,
              intermediate_dim=64,
              intermediate_activation=None,
              from_logits=False,
              return_logits=False,
              bias_initializer=1.0,
              **kwargs):
     # if 'input_shape' not in kwargs:
     #     kwargs['input_shape'] = [(None, input_dim,), (None, input_dim)]
     super(WeightedCombinationLayer, self).__init__(**kwargs)
     self.first_threshold = first_threshold if first_threshold is not None else INFTY
     self.second_threshold = second_threshold if second_threshold is not None else INFTY
     self.use_dimension_bias = use_dimension_bias
     self.use_intermediate_layer = use_intermediate_layer
     self.intermediate_dim = intermediate_dim
     self.intermediate_activation = kact.get(intermediate_activation)
     self.from_logits = from_logits
     self.return_logits = return_logits
     self.bias_initializer = bias_initializer
     self.input_spec = [InputSpec(), InputSpec(), InputSpec()]