예제 #1
0
class Conv2DSchema(BaseLayerSchema):
    """Ordered schema whose loaded data is turned into a `Conv2DConfig`."""
    filters = fields.Int()
    kernel_size = ObjectOrListObject(fields.Int, min=2, max=2)
    strides = ObjectOrListObject(fields.Int, min=2, max=2, default=(1, 1), missing=(1, 1))
    padding = fields.Str(default='valid', missing='valid',
                         validate=validate.OneOf(['same', 'valid']))
    # OneOf expects the allowed values as ONE collection. Passing two positional
    # strings made 'channels_last' the `labels` argument and reduced the check to
    # a substring test against 'channels_first'.
    data_format = fields.Str(default=None, missing=None,
                             validate=validate.OneOf(['channels_first', 'channels_last']))
    dilation_rate = ObjectOrListObject(fields.Int, min=2, max=2, default=(1, 1), missing=(1, 1))
    activation = StrOrFct(allow_none=True, validate=validate.OneOf(ACTIVATION_VALUES))
    use_bias = fields.Bool(default=True, missing=True)
    kernel_initializer = fields.Nested(InitializerSchema, allow_none=True)
    bias_initializer = fields.Nested(InitializerSchema, allow_none=True)
    kernel_regularizer = fields.Nested(RegularizerSchema, allow_none=True)
    bias_regularizer = fields.Nested(RegularizerSchema, allow_none=True)
    activity_regularizer = fields.Nested(RegularizerSchema, allow_none=True)
    kernel_constraint = fields.Nested(ConstraintSchema, allow_none=True)
    bias_constraint = fields.Nested(ConstraintSchema, allow_none=True)

    class Meta:
        ordered = True

    @post_load
    def make_load(self, data):
        """Build a `Conv2DConfig` from the loaded field values."""
        return Conv2DConfig(**data)
예제 #2
0
class ConvRecurrent2DSchema(RecurrentSchema):
    """Ordered schema whose loaded data is turned into a `ConvRecurrent2DConfig`."""
    filters = fields.Int()
    kernel_size = ObjectOrListObject(fields.Int, min=2, max=2)
    strides = ObjectOrListObject(fields.Int,
                                 min=2,
                                 max=2,
                                 default=(1, 1),
                                 missing=(1, 1))
    padding = fields.Str(default='valid',
                         missing='valid',
                         validate=validate.OneOf(['same', 'valid']))
    # The allowed values must be passed to OneOf as a single collection; as two
    # positional strings the second one was taken as `labels` and membership was
    # checked against the string 'channels_first'.
    data_format = fields.Str(allow_none=True,
                             validate=validate.OneOf(['channels_first',
                                                      'channels_last']))
    dilation_rate = ObjectOrListObject(fields.Int,
                                       min=2,
                                       max=2,
                                       default=(1, 1),
                                       missing=(1, 1))
    return_sequences = fields.Bool(default=False, missing=False)
    go_backwards = fields.Bool(default=False, missing=False)
    stateful = fields.Bool(default=False, missing=False)

    class Meta:
        ordered = True

    @post_load
    def make_load(self, data):
        """Build a `ConvRecurrent2DConfig` from the loaded field values."""
        return ConvRecurrent2DConfig(**data)
예제 #3
0
class AveragePooling3DSchema(BaseLayerSchema):
    """Ordered schema that loads into and dumps from `AveragePooling3DConfig`."""
    pool_size = ObjectOrListObject(fields.Int,
                                   min=3,
                                   max=3,
                                   default=(2, 2, 2),
                                   missing=(2, 2, 2))
    strides = ObjectOrListObject(fields.Int,
                                 min=3,
                                 max=3,
                                 default=None,
                                 missing=None)
    padding = fields.Str(default='valid',
                         missing='valid',
                         validate=validate.OneOf(['same', 'valid']))
    # Choices go to OneOf as one list; two positional strings would make
    # 'channels_last' the `labels` argument and break validation.
    data_format = fields.Str(default=None,
                             missing=None,
                             validate=validate.OneOf(['channels_first',
                                                      'channels_last']))

    class Meta:
        ordered = True

    @post_load
    def make(self, data):
        """Build an `AveragePooling3DConfig` from the loaded field values."""
        return AveragePooling3DConfig(**data)

    @post_dump
    def unmake(self, data):
        """Post-process dumped data via `AveragePooling3DConfig.remove_reduced_attrs`."""
        return AveragePooling3DConfig.remove_reduced_attrs(data)
예제 #4
0
class MaxPooling2DSchema(BaseLayerSchema):
    """Ordered schema whose loaded data is turned into a `MaxPooling2DConfig`."""
    pool_size = ObjectOrListObject(fields.Int,
                                   min=2,
                                   max=2,
                                   default=(2, 2),
                                   missing=(2, 2))
    strides = ObjectOrListObject(fields.Int,
                                 min=2,
                                 max=2,
                                 default=None,
                                 missing=None)
    padding = fields.Str(default='valid',
                         missing='valid',
                         validate=validate.OneOf(['same', 'valid']))
    # Choices go to OneOf as one list; two positional strings would make
    # 'channels_last' the `labels` argument and break validation.
    data_format = fields.Str(default=None,
                             missing=None,
                             validate=validate.OneOf(['channels_first',
                                                      'channels_last']))

    class Meta:
        ordered = True

    @post_load
    def make_load(self, data):
        """Build a `MaxPooling2DConfig` from the loaded field values."""
        return MaxPooling2DConfig(**data)
예제 #5
0
class GraphSchema(BaseSchema):
    """Schema holding input/output tensors, nested layers and a nullable name.

    Unknown keys are excluded on load; `schema_config` returns `GraphConfig`.
    """

    class Meta:
        unknown = EXCLUDE

    input_layers = ObjectOrListObject(Tensor)
    output_layers = ObjectOrListObject(Tensor)
    layers = fields.Nested(LayerSchema, many=True)
    name = fields.Str(allow_none=True)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = GraphConfig
        return config_cls
예제 #6
0
class Cropping2DSchema(BaseLayerSchema):
    """Ordered schema whose loaded data is turned into a `Cropping2DConfig`."""
    # Nested pair of pairs, defaulting to ((0, 0), (0, 0)).
    cropping = ObjectOrListObject(ObjectOrListObject(fields.Int, min=2, max=2), min=2, max=2,
                                  default=((0, 0), (0, 0)), missing=((0, 0), (0, 0)))
    # Choices go to OneOf as one list; two positional strings would make
    # 'channels_last' the `labels` argument and break validation.
    data_format = fields.Str(default=None, missing=None,
                             validate=validate.OneOf(['channels_first', 'channels_last']))

    class Meta:
        ordered = True

    @post_load
    def make_load(self, data):
        """Build a `Cropping2DConfig` from the loaded field values."""
        return Cropping2DConfig(**data)
예제 #7
0
class GraphSchema(Schema):
    """Ordered schema for graph inputs/outputs, layers and name; loads into `GraphConfig`."""

    class Meta:
        ordered = True

    input_layers = ObjectOrListObject(Tensor)
    output_layers = ObjectOrListObject(Tensor)
    layers = fields.Nested(LayerSchema, many=True)
    name = fields.Str(allow_none=True)

    @post_load
    def make(self, data):
        """Turn the loaded mapping into a `GraphConfig`."""
        graph_config = GraphConfig(**data)
        return graph_config
예제 #8
0
class LocallyConnected1DSchema(BaseLayerSchema):
    """Ordered schema that loads into and dumps from `LocallyConnected1DConfig`."""
    filters = fields.Int()
    kernel_size = ObjectOrListObject(fields.Int, min=1, max=1)
    strides = ObjectOrListObject(fields.Int,
                                 min=1,
                                 max=1,
                                 default=1,
                                 missing=1)
    padding = fields.Str(default='valid',
                         missing='valid',
                         validate=validate.OneOf(['same', 'valid']))
    # Choices go to OneOf as one list; two positional strings would make
    # 'channels_last' the `labels` argument and break validation.
    data_format = fields.Str(default=None,
                             missing=None,
                             validate=validate.OneOf(['channels_first',
                                                      'channels_last']))
    activation = StrOrFct(allow_none=True,
                          validate=validate.OneOf(ACTIVATION_VALUES))
    use_bias = fields.Bool(default=True, missing=True)
    kernel_initializer = fields.Nested(InitializerSchema,
                                       default=None,
                                       missing=None)
    bias_initializer = fields.Nested(InitializerSchema,
                                     default=None,
                                     missing=None)
    kernel_regularizer = fields.Nested(RegularizerSchema,
                                       default=None,
                                       missing=None)
    bias_regularizer = fields.Nested(RegularizerSchema,
                                     default=None,
                                     missing=None)
    activity_regularizer = fields.Nested(RegularizerSchema,
                                         default=None,
                                         missing=None)
    # Constraints are validated with ConstraintSchema, as in the other layer
    # schemas, not with RegularizerSchema.
    kernel_constraint = fields.Nested(ConstraintSchema,
                                      default=None,
                                      missing=None)
    bias_constraint = fields.Nested(ConstraintSchema,
                                    default=None,
                                    missing=None)

    class Meta:
        ordered = True

    @post_load
    def make(self, data):
        """Build a `LocallyConnected1DConfig` from the loaded field values."""
        return LocallyConnected1DConfig(**data)

    @post_dump
    def unmake(self, data):
        """Post-process dumped data via `LocallyConnected1DConfig.remove_reduced_attrs`."""
        return LocallyConnected1DConfig.remove_reduced_attrs(data)
예제 #9
0
class DotSchema(BaseLayerSchema):
    """Layer schema with `axes` and a nullable `normalize` flag; config class is `DotConfig`."""
    axes = ObjectOrListObject(fields.Int)
    normalize = fields.Bool(allow_none=True)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = DotConfig
        return config_cls
예제 #10
0
class EpisodeLoggingTensorHookSchema(BaseSchema):
    """Schema with tensor names and an episode interval; config class is `EpisodeLoggingTensorHookConfig`."""
    tensors = ObjectOrListObject(fields.Str)
    every_n_episodes = fields.Int()

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = EpisodeLoggingTensorHookConfig
        return config_cls
예제 #11
0
class BaseLayerSchema(Schema):
    """Fields shared by the layer schemas that subclass this one.

    Attribute access on dumped objects is delegated to `get_value`.
    """
    name = fields.Str(allow_none=True)
    trainable = fields.Bool(missing=True, default=True)
    dtype = DType(allow_none=True)
    inbound_nodes = ObjectOrListObject(Tensor, allow_none=True)

    def get_attribute(self, attr, obj, default):
        """Look up ``attr`` on ``obj`` through ``get_value``, passing ``default`` along."""
        value = get_value(attr, obj, default)
        return value
예제 #12
0
class StepLoggingTensorHookSchema(BaseSchema):
    """Schema with tensor names and nullable step/second intervals.

    The associated config class is `StepLoggingTensorHookConfig`.
    """
    tensors = ObjectOrListObject(fields.Str)
    every_n_iter = fields.Int(allow_none=True)
    every_n_secs = fields.Int(allow_none=True)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = StepLoggingTensorHookConfig
        return config_cls
예제 #13
0
class UpSampling1DSchema(BaseLayerSchema):
    """Ordered schema whose loaded data is turned into an `UpSampling1DConfig`."""
    # A 1D layer takes a single size value (default is the scalar 2). Bound a
    # list form to one element, like the other 1D schemas (ZeroPadding1DSchema,
    # LocallyConnected1DSchema); the previous min=2/max=2 accepted 2-element lists.
    # NOTE(review): assumes ObjectOrListObject's min/max bound list length — confirm.
    size = ObjectOrListObject(fields.Int, min=1, max=1, default=2, missing=2)

    class Meta:
        ordered = True

    @post_load
    def make_load(self, data):
        """Build an `UpSampling1DConfig` from the loaded field values."""
        return UpSampling1DConfig(**data)
예제 #14
0
class ZeroPadding1DSchema(BaseLayerSchema):
    """Ordered schema for a single padding value; loads into `ZeroPadding1DConfig`."""

    class Meta:
        ordered = True

    padding = ObjectOrListObject(fields.Int,
                                 min=1,
                                 max=1,
                                 missing=1,
                                 default=1)

    @post_load
    def make_load(self, data):
        """Turn the loaded mapping into a `ZeroPadding1DConfig`."""
        config = ZeroPadding1DConfig(**data)
        return config
예제 #15
0
class Cropping1DSchema(BaseLayerSchema):
    """Ordered schema for a two-value cropping; loads into `Cropping1DConfig`."""

    class Meta:
        ordered = True

    cropping = ObjectOrListObject(fields.Int,
                                  min=2,
                                  max=2,
                                  missing=(1, 1),
                                  default=(1, 1))

    @post_load
    def make_load(self, data):
        """Turn the loaded mapping into a `Cropping1DConfig`."""
        config = Cropping1DConfig(**data)
        return config
예제 #16
0
class FinalOpsHookSchema(Schema):
    """Ordered schema listing final op names; loads into `FinalOpsHookConfig`."""

    class Meta:
        ordered = True

    final_ops = ObjectOrListObject(fields.Str)

    @post_load
    def make_load(self, data):
        """Turn the loaded mapping into a `FinalOpsHookConfig`."""
        config = FinalOpsHookConfig(**data)
        return config
예제 #17
0
class LocallyConnected2DSchema(BaseLayerSchema):
    """Layer schema whose associated config class is `LocallyConnected2DConfig`."""
    filters = fields.Int()
    kernel_size = ObjectOrListObject(fields.Int, min=2, max=2)
    strides = ObjectOrListObject(fields.Int,
                                 min=2,
                                 max=2,
                                 default=(1, 1),
                                 missing=(1, 1))
    padding = fields.Str(default='valid',
                         missing='valid',
                         validate=validate.OneOf(['same', 'valid']))
    # Choices go to OneOf as one list; two positional strings would make
    # 'channels_last' the `labels` argument and break validation.
    data_format = fields.Str(default=None,
                             missing=None,
                             validate=validate.OneOf(['channels_first',
                                                      'channels_last']))
    activation = StrOrFct(allow_none=True,
                          validate=validate.OneOf(ACTIVATION_VALUES))
    use_bias = fields.Bool(default=True, missing=True)
    kernel_initializer = fields.Nested(InitializerSchema,
                                       default=None,
                                       missing=None)
    bias_initializer = fields.Nested(InitializerSchema,
                                     default=None,
                                     missing=None)
    kernel_regularizer = fields.Nested(RegularizerSchema,
                                       default=None,
                                       missing=None)
    bias_regularizer = fields.Nested(RegularizerSchema,
                                     default=None,
                                     missing=None)
    activity_regularizer = fields.Nested(RegularizerSchema,
                                         default=None,
                                         missing=None)
    # Constraints are validated with ConstraintSchema, as in the other layer
    # schemas, not with RegularizerSchema.
    kernel_constraint = fields.Nested(ConstraintSchema,
                                      default=None,
                                      missing=None)
    bias_constraint = fields.Nested(ConstraintSchema,
                                    default=None,
                                    missing=None)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return LocallyConnected2DConfig
예제 #18
0
class EpisodeLoggingTensorHookSchema(Schema):
    """Ordered schema with tensor names and an episode interval.

    Loaded data becomes an `EpisodeLoggingTensorHookConfig`.
    """

    class Meta:
        ordered = True

    tensors = ObjectOrListObject(fields.Str)
    every_n_episodes = fields.Int()

    @post_load
    def make_load(self, data):
        """Turn the loaded mapping into an `EpisodeLoggingTensorHookConfig`."""
        config = EpisodeLoggingTensorHookConfig(**data)
        return config
예제 #19
0
class StepLoggingTensorHookSchema(Schema):
    """Ordered schema with tensor names and nullable step/second intervals.

    Loaded data becomes a `StepLoggingTensorHookConfig`.
    """

    class Meta:
        ordered = True

    tensors = ObjectOrListObject(fields.Str)
    every_n_iter = fields.Int(allow_none=True)
    every_n_secs = fields.Int(allow_none=True)

    @post_load
    def make_load(self, data):
        """Turn the loaded mapping into a `StepLoggingTensorHookConfig`."""
        config = StepLoggingTensorHookConfig(**data)
        return config
예제 #20
0
class ZeroPadding3DSchema(BaseLayerSchema):
    """Ordered schema whose loaded data is turned into a `ZeroPadding3DConfig`."""
    padding = ObjectOrListObject(fields.Int, min=3, max=3, default=(1, 1, 1), missing=(1, 1, 1))
    # Choices go to OneOf as one list; two positional strings would make
    # 'channels_last' the `labels` argument and break validation.
    data_format = fields.Str(default=None, missing=None,
                             validate=validate.OneOf(['channels_first', 'channels_last']))

    class Meta:
        ordered = True

    @post_load
    def make_load(self, data):
        """Build a `ZeroPadding3DConfig` from the loaded field values."""
        return ZeroPadding3DConfig(**data)
예제 #21
0
class MaxPooling2DSchema(BaseLayerSchema):
    """Layer schema whose associated config class is `MaxPooling2DConfig`."""
    pool_size = ObjectOrListObject(fields.Int,
                                   min=2,
                                   max=2,
                                   default=(2, 2),
                                   missing=(2, 2))
    strides = ObjectOrListObject(fields.Int,
                                 min=2,
                                 max=2,
                                 default=None,
                                 missing=None)
    padding = fields.Str(default='valid',
                         missing='valid',
                         validate=validate.OneOf(['same', 'valid']))
    # Choices go to OneOf as one list; two positional strings would make
    # 'channels_last' the `labels` argument and break validation.
    data_format = fields.Str(default=None,
                             missing=None,
                             validate=validate.OneOf(['channels_first',
                                                      'channels_last']))

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return MaxPooling2DConfig
예제 #22
0
class BaseModelSchema(BaseSchema):
    """Model schema: a graph plus nullable loss, optimizer, metrics, summaries,
    gradient-clipping values and name. The config class is `BaseModelConfig`.
    """
    graph = fields.Nested(GraphSchema)
    loss = fields.Nested(LossSchema, allow_none=True)
    optimizer = fields.Nested(OptimizerSchema, allow_none=True)
    metrics = fields.Nested(MetricSchema, allow_none=True, many=True)
    summaries = ObjectOrListObject(fields.Str, allow_none=True)
    clip_gradients = fields.Float(allow_none=True)
    clip_embed_gradients = fields.Float(allow_none=True)
    name = fields.Str(allow_none=True)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = BaseModelConfig
        return config_cls
예제 #23
0
class FinalOpsHookSchema(Schema):
    """Ordered schema listing final op names; round-trips via `FinalOpsHookConfig`."""

    class Meta:
        ordered = True

    final_ops = ObjectOrListObject(fields.Str)

    @post_dump
    def unmake(self, data):
        """Post-process dumped data through `FinalOpsHookConfig.remove_reduced_attrs`."""
        reduced = FinalOpsHookConfig.remove_reduced_attrs(data)
        return reduced

    @post_load
    def make(self, data):
        """Turn the loaded mapping into a `FinalOpsHookConfig`."""
        config = FinalOpsHookConfig(**data)
        return config
예제 #24
0
class DotSchema(BaseLayerSchema):
    """Ordered layer schema with `axes` and nullable `normalize`; round-trips via `DotConfig`."""

    class Meta:
        ordered = True

    axes = ObjectOrListObject(fields.Int)
    normalize = fields.Bool(allow_none=True)

    @post_dump
    def unmake(self, data):
        """Post-process dumped data through `DotConfig.remove_reduced_attrs`."""
        reduced = DotConfig.remove_reduced_attrs(data)
        return reduced

    @post_load
    def make(self, data):
        """Turn the loaded mapping into a `DotConfig`."""
        config = DotConfig(**data)
        return config
예제 #25
0
class EpisodeLoggingTensorHookSchema(Schema):
    """Ordered schema with tensor names and an episode interval.

    Loads into, and post-processes dumps with, `EpisodeLoggingTensorHookConfig`.
    """

    class Meta:
        ordered = True

    tensors = ObjectOrListObject(fields.Str)
    every_n_episodes = fields.Int()

    @post_dump
    def unmake(self, data):
        """Post-process dumped data through `EpisodeLoggingTensorHookConfig.remove_reduced_attrs`."""
        reduced = EpisodeLoggingTensorHookConfig.remove_reduced_attrs(data)
        return reduced

    @post_load
    def make(self, data):
        """Turn the loaded mapping into an `EpisodeLoggingTensorHookConfig`."""
        config = EpisodeLoggingTensorHookConfig(**data)
        return config
class PReLUSchema(BaseLayerSchema):
    """Layer schema for alpha initializer/regularizer/constraint and shared axes.

    All four fields default to None; the config class is `PReLUConfig`.
    """
    alpha_initializer = fields.Nested(InitializerSchema, missing=None, default=None)
    alpha_regularizer = fields.Nested(RegularizerSchema, missing=None, default=None)
    alpha_constraint = fields.Nested(ConstraintSchema, missing=None, default=None)
    shared_axes = ObjectOrListObject(fields.Int, missing=None, default=None)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = PReLUConfig
        return config_cls
예제 #27
0
class BaseModelSchema(Schema):
    """Ordered model schema: a graph plus nullable loss, optimizer, metrics,
    summaries, gradient-clipping values and name. Loads into `BaseModelConfig`.
    """

    class Meta:
        ordered = True

    graph = fields.Nested(GraphSchema)
    loss = fields.Nested(LossSchema, allow_none=True)
    optimizer = fields.Nested(OptimizerSchema, allow_none=True)
    metrics = fields.Nested(MetricSchema, allow_none=True, many=True)
    summaries = ObjectOrListObject(fields.Str, allow_none=True)
    clip_gradients = fields.Float(allow_none=True)
    clip_embed_gradients = fields.Float(allow_none=True)
    name = fields.Str(allow_none=True)

    @post_load
    def make_load(self, data):
        """Turn the loaded mapping into a `BaseModelConfig`."""
        config = BaseModelConfig(**data)
        return config
예제 #28
0
class PReLUSchema(BaseLayerSchema):
    """Ordered layer schema for alpha initializer/regularizer/constraint and
    shared axes (all defaulting to None). Loads into `PReLUConfig`.
    """

    class Meta:
        ordered = True

    alpha_initializer = fields.Nested(InitializerSchema, missing=None, default=None)
    alpha_regularizer = fields.Nested(RegularizerSchema, missing=None, default=None)
    alpha_constraint = fields.Nested(ConstraintSchema, missing=None, default=None)
    shared_axes = ObjectOrListObject(fields.Int, missing=None, default=None)

    @post_load
    def make_load(self, data):
        """Turn the loaded mapping into a `PReLUConfig`."""
        config = PReLUConfig(**data)
        return config
예제 #29
0
class BaseBridgeSchema(Schema):
    """Schema with a nullable integer-or-list `state_size` and a nullable `name`."""
    state_size = ObjectOrListObject(
        fields.Int,
        allow_none=True,
    )
    name = fields.Str(
        allow_none=True,
    )
예제 #30
0
class FinalOpsHookSchema(BaseSchema):
    """Schema listing final op names; the config class is `FinalOpsHookConfig`."""
    final_ops = ObjectOrListObject(fields.Str)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = FinalOpsHookConfig
        return config_cls