Example #1
0
    def __init__(self,
                 units,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        """Record the configuration of a Dense layer.

        Args:
            units: Output size; coerced with ``int()`` before being stored.
            activation: Activation identifier, resolved through
                ``activations.get``.
            use_bias: Stored unchanged on the instance.
            kernel_initializer: Kernel initializer identifier, resolved
                through ``initializers.get``.
            bias_initializer: Bias initializer identifier, resolved through
                ``initializers.get``.
            kernel_regularizer: Not implemented; handed to
                ``default_args_check``.
            bias_regularizer: Not implemented; handed to
                ``default_args_check``.
            activity_regularizer: Not implemented; handed to
                ``default_args_check``.
            kernel_constraint: Not implemented; handed to
                ``default_args_check``.
            bias_constraint: Not implemented; handed to
                ``default_args_check``.
            **kwargs: Forwarded untouched to the parent constructor.
        """
        super(Dense, self).__init__(**kwargs)

        self.units = int(units)
        self.activation = activations.get(activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        # Arguments kept for signature compatibility but not implemented.
        # NOTE(review): presumably default_args_check rejects any value other
        # than the default -- confirm against its definition.
        not_implemented = (
            (kernel_regularizer, "kernel_regularizer"),
            (bias_regularizer, "bias_regularizer"),
            (activity_regularizer, "activity_regularizer"),
            (kernel_constraint, "kernel_constraint"),
            (bias_constraint, "bias_constraint"),
        )
        for given_value, arg_name in not_implemented:
            default_args_check(given_value, arg_name, "Dense")
Example #2
0
  def __init__(self,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               depth_multiplier=1,
               data_format=None,
               activation=None,
               use_bias=True,
               depthwise_initializer='glorot_uniform',
               bias_initializer='zeros',
               depthwise_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               depthwise_constraint=None,
               bias_constraint=None,
               **kwargs):
    """Record the configuration of a DepthwiseConv2D layer.

    Args:
      kernel_size: Kernel size, normalized to a 2-tuple through
        `conv_utils.normalize_tuple`. Height and width must be equal.
      strides: Strides, normalized to a 2-tuple the same way.
      padding: Padding mode, normalized through
        `conv_utils.normalize_padding` and stored upper-cased.
      depth_multiplier: Stored unchanged on the instance.
      data_format: Normalized through `conv_utils.normalize_data_format`.
      activation: Activation identifier, resolved through `activations.get`.
        An info message is logged whenever it is not None.
      use_bias: Stored unchanged on the instance.
      depthwise_initializer: Resolved through `initializers.get`.
      bias_initializer: Resolved through `initializers.get`.
      depthwise_regularizer: Not implemented; handed to `default_args_check`.
      bias_regularizer: Not implemented; handed to `default_args_check`.
      activity_regularizer: Not implemented; handed to `default_args_check`.
      depthwise_constraint: Not implemented; handed to `default_args_check`.
      bias_constraint: Not implemented; handed to `default_args_check`.
      **kwargs: Forwarded untouched to the parent constructor.

    Raises:
      NotImplementedError: If the normalized kernel height and width differ.
    """

    super(DepthwiseConv2D, self).__init__(**kwargs)

    self.rank = 2
    self.kernel_size = conv_utils.normalize_tuple(
        kernel_size, self.rank, 'kernel_size')
    if self.kernel_size[0] != self.kernel_size[1]:
      # The check is on the kernel size, so the message must name the kernel
      # size (it previously said "stride" and ran two sentences together).
      raise NotImplementedError("TF Encrypted currently only supports same "
                                "kernel size along the height and the width. "
                                "You gave: {}".format(self.kernel_size))
    self.strides = conv_utils.normalize_tuple(strides, self.rank, 'strides')
    self.padding = conv_utils.normalize_padding(padding).upper()
    self.depth_multiplier = depth_multiplier
    self.data_format = conv_utils.normalize_data_format(data_format)
    if activation is not None:
      # Informational only: the activation is still resolved and stored below.
      logger.info("Performing an activation before a pooling layer can result "
                  "in unnecessary performance loss. Check model definition in "
                  "case of missed optimization.")
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.depthwise_initializer = initializers.get(depthwise_initializer)
    self.bias_initializer = initializers.get(bias_initializer)

    # Not implemented arguments
    # NOTE(review): presumably default_args_check rejects any value other
    # than the default -- confirm against its definition.
    default_args_check(depthwise_regularizer,
                       "depthwise_regularizer",
                       "DepthwiseConv2D")
    default_args_check(bias_regularizer,
                       "bias_regularizer",
                       "DepthwiseConv2D")
    default_args_check(activity_regularizer,
                       "activity_regularizer",
                       "DepthwiseConv2D")
    default_args_check(depthwise_constraint,
                       "depthwise_constraint",
                       "DepthwiseConv2D")
    default_args_check(bias_constraint,
                       "bias_constraint",
                       "DepthwiseConv2D")
Example #3
0
    def __init__(self,
                 max_value=None,
                 negative_slope=0,
                 threshold=0,
                 **kwargs):
        """Construct a ReLU layer.

        Args:
            max_value: Not implemented; handed to ``default_args_check``.
            negative_slope: Not implemented; handed to
                ``default_args_check``.
            threshold: Not implemented; handed to ``default_args_check``.
            **kwargs: Forwarded untouched to the parent constructor.
        """
        super(ReLU, self).__init__(**kwargs)

        # Arguments kept for signature compatibility but not implemented.
        # NOTE(review): negative_slope and threshold default to 0, not None;
        # confirm default_args_check accepts 0 as a default value.
        not_implemented = (
            (max_value, "max_value"),
            (negative_slope, "negative_slope"),
            (threshold, "threshold"),
        )
        for given_value, arg_name in not_implemented:
            default_args_check(given_value, arg_name, "ReLU")
    def __init__(
            self,
            axis=3,
            momentum=0.99,
            epsilon=1e-3,
            center=True,
            scale=True,
            beta_initializer='zeros',
            gamma_initializer='ones',
            moving_mean_initializer='zeros',
            moving_variance_initializer='ones',
            beta_regularizer=None,
            gamma_regularizer=None,
            beta_constraint=None,
            gamma_constraint=None,
            renorm=False,
            renorm_clipping=None,
            renorm_momentum=0.99,
            fused=None,  # pylint: disable=unused-argument
            trainable=False,
            virtual_batch_size=None,
            adjustment=None,
            name=None,
            **kwargs):
        """Record the configuration of a BatchNormalization layer.

        Args:
            axis: Axis to normalize. Must be 1 or 3, or the equivalent
                negative spelling -3 or -1. A list is accepted and its first
                element is used.
            momentum: Stored unchanged on the instance.
            epsilon: Stored unchanged on the instance.
            center: Stored unchanged on the instance.
            scale: Stored unchanged on the instance.
            beta_initializer: Resolved through ``initializers.get``.
            gamma_initializer: Resolved through ``initializers.get``.
            moving_mean_initializer: Resolved through ``initializers.get``.
            moving_variance_initializer: Resolved through
                ``initializers.get``.
            beta_regularizer: Not implemented; handed to
                ``default_args_check``.
            gamma_regularizer: Not implemented; handed to
                ``default_args_check``.
            beta_constraint: Not implemented; handed to
                ``default_args_check``.
            gamma_constraint: Not implemented; handed to
                ``default_args_check``.
            renorm: Not implemented; handed to ``default_args_check``.
            renorm_clipping: Not implemented; handed to
                ``default_args_check``.
            renorm_momentum: Stored unchanged on the instance.
            fused: Accepted but unused.
            trainable: Forwarded to the parent constructor.
            virtual_batch_size: Not implemented; handed to
                ``default_args_check``.
            adjustment: Not implemented; handed to ``default_args_check``.
            name: Forwarded to the parent constructor.
            **kwargs: Forwarded untouched to the parent constructor.

        Raises:
            ValueError: If the axis is not 1 or 3 after normalization.
        """
        super(BatchNormalization, self).__init__(name=name,
                                                 trainable=trainable,
                                                 **kwargs)

        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.moving_mean_initializer = initializers.get(
            moving_mean_initializer)
        self.moving_variance_initializer = initializers.get(
            moving_variance_initializer)

        # Not implemented arguments
        # NOTE(review): renorm defaults to False, not None; confirm
        # default_args_check accepts False as a default value.
        default_args_check(beta_regularizer, "beta_regularizer",
                           "BatchNormalization")
        default_args_check(gamma_regularizer, "gamma_regularizer",
                           "BatchNormalization")
        default_args_check(beta_constraint, "beta_constraint",
                           "BatchNormalization")
        default_args_check(gamma_constraint, "gamma_constraint",
                           "BatchNormalization")
        default_args_check(renorm, "renorm", "BatchNormalization")
        default_args_check(renorm_clipping, "renorm_clipping",
                           "BatchNormalization")
        default_args_check(virtual_batch_size, "virtual_batch_size",
                           "BatchNormalization")
        default_args_check(adjustment, "adjustment", "BatchNormalization")

        # Axis from get_config can be in ListWrapper format even if
        # the layer is expecting an integer for the axis
        if isinstance(axis, list):
            axis = axis[0]

        # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
        # input rank is required to be 4 (which is checked later). Map the
        # negative spellings onto the positive ones so they are accepted
        # instead of being rejected by the check below.
        if axis == -3:
            axis = 1
        elif axis == -1:
            axis = 3

        if axis not in (1, 3):
            raise ValueError("Axis of 1 or 3 is currently only supported")

        self.axis = axis
        self.scale = scale
        self.center = center
        self.epsilon = epsilon
        self.momentum = momentum
        self.renorm_momentum = renorm_momentum