예제 #1
0
class MaxPooling2DSchema(BaseLayerSchema):
    """Marshmallow schema for a 2D max-pooling layer config.

    ``schema_config`` returns ``MaxPooling2DConfig`` as the associated config class.
    """

    # Presumably a pair of ints (min=2/max=2 items) — TODO confirm ObjectOrListObject semantics.
    pool_size = ObjectOrListObject(fields.Int, min=2, max=2, default=(2, 2), missing=(2, 2))
    strides = ObjectOrListObject(fields.Int, min=2, max=2, default=None, missing=None)
    padding = fields.Str(default='valid', missing='valid',
                         validate=validate.OneOf(['same', 'valid']))
    # `OneOf` takes ONE collection of choices. Passing the two names as separate
    # positional arguments would make 'channels_first' the choices (a substring
    # test) and 'channels_last' the labels, so 'channels_last' would be rejected.
    data_format = fields.Str(default=None, missing=None,
                             validate=validate.OneOf(['channels_first', 'channels_last']))

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return MaxPooling2DConfig
예제 #2
0
class GraphSchema(BaseSchema):
    """Schema for a graph definition: input/output tensors plus its layers.

    ``schema_config`` returns ``GraphConfig`` as the associated config class.
    """

    input_layers = ObjectOrListObject(Tensor)
    output_layers = ObjectOrListObject(Tensor)
    layers = fields.Nested(
        LayerSchema,
        many=True,
    )
    name = fields.Str(
        allow_none=True,
    )

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return GraphConfig

    class Meta:
        # Keys with no matching field are dropped on load rather than raising.
        unknown = EXCLUDE
class ConvRecurrent2DSchema(RecurrentSchema):
    """Marshmallow schema for a 2D convolutional recurrent layer config.

    ``schema_config`` returns ``ConvRecurrent2DConfig`` as the associated config class.
    """

    filters = fields.Int()
    # Presumably pairs of ints (min=2/max=2 items) — TODO confirm ObjectOrListObject semantics.
    kernel_size = ObjectOrListObject(fields.Int, min=2, max=2)
    strides = ObjectOrListObject(fields.Int, min=2, max=2, default=(1, 1), missing=(1, 1))
    padding = fields.Str(default='valid', missing='valid',
                         validate=validate.OneOf(['same', 'valid']))
    # `OneOf` takes ONE collection of choices; two positional strings would turn
    # the first into a substring-matched choice string and the second into labels.
    data_format = fields.Str(allow_none=True,
                             validate=validate.OneOf(['channels_first', 'channels_last']))
    dilation_rate = ObjectOrListObject(fields.Int, min=2, max=2, default=(1, 1), missing=(1, 1))
    return_sequences = fields.Bool(default=False, missing=False)
    go_backwards = fields.Bool(default=False, missing=False)
    stateful = fields.Bool(default=False, missing=False)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return ConvRecurrent2DConfig
예제 #4
0
class EpisodeLoggingTensorHookSchema(BaseSchema):
    """Schema for the episode-based tensor-logging hook config.

    Fields: ``tensors`` (string or list of strings) and ``every_n_episodes``.
    """

    tensors = ObjectOrListObject(
        fields.Str,
    )
    every_n_episodes = fields.Int()

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return EpisodeLoggingTensorHookConfig
예제 #5
0
class DotSchema(BaseLayerSchema):
    """Schema for a Dot layer config (``DotConfig``)."""

    # An int or a list of ints.
    axes = ObjectOrListObject(
        fields.Int,
    )
    normalize = fields.Bool(
        allow_none=True,
    )

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return DotConfig
예제 #6
0
class StepLoggingTensorHookSchema(BaseSchema):
    """Schema for the step-based tensor-logging hook config.

    Both ``every_n_iter`` and ``every_n_secs`` accept ``None``.
    """

    tensors = ObjectOrListObject(
        fields.Str,
    )
    every_n_iter = fields.Int(
        allow_none=True,
    )
    every_n_secs = fields.Int(
        allow_none=True,
    )

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return StepLoggingTensorHookConfig
예제 #7
0
class BaseLayerSchema(BaseSchema):
    """Fields shared by every layer schema.

    Unknown keys are excluded on load and fields keep declaration order
    (``Meta.ordered``), so the field order below is significant.
    """

    name = fields.Str(allow_none=True)
    trainable = fields.Bool(missing=True, default=True)
    dtype = DType(allow_none=True)
    inbound_nodes = ObjectOrListObject(Tensor, allow_none=True)

    def get_attribute(self, obj, attr, default):
        """marshmallow hook that reads ``attr`` from ``obj`` via ``get_value``."""
        # NOTE(review): forwarded as (attr, obj, default). marshmallow 3's
        # `marshmallow.utils.get_value` takes (obj, key, default); confirm which
        # `get_value` is imported in this module and that this order matches it.
        return get_value(attr, obj, default)

    class Meta:
        unknown = EXCLUDE
        ordered = True
예제 #8
0
class LocallyConnected1DSchema(BaseLayerSchema):
    """Marshmallow schema for a 1D locally-connected layer config.

    ``schema_config`` returns ``LocallyConnected1DConfig`` as the associated config class.
    """

    filters = fields.Int()
    # Presumably a single int (min=1/max=1 items) — TODO confirm ObjectOrListObject semantics.
    kernel_size = ObjectOrListObject(fields.Int, min=1, max=1)
    strides = ObjectOrListObject(fields.Int, min=1, max=1, default=1, missing=1)
    padding = fields.Str(default='valid', missing='valid',
                         validate=validate.OneOf(['same', 'valid']))
    # `OneOf` takes ONE collection of choices; two positional strings would turn
    # the first into a substring-matched choice string and the second into labels.
    data_format = fields.Str(default=None, missing=None,
                             validate=validate.OneOf(['channels_first', 'channels_last']))
    activation = StrOrFct(allow_none=True, validate=validate.OneOf(ACTIVATION_VALUES))
    use_bias = fields.Bool(default=True, missing=True)
    kernel_initializer = fields.Nested(InitializerSchema, default=None, missing=None)
    bias_initializer = fields.Nested(InitializerSchema, default=None, missing=None)
    kernel_regularizer = fields.Nested(RegularizerSchema, default=None, missing=None)
    bias_regularizer = fields.Nested(RegularizerSchema, default=None, missing=None)
    activity_regularizer = fields.Nested(RegularizerSchema, default=None, missing=None)
    # Constraints are validated with the constraint schema, not the regularizer schema.
    kernel_constraint = fields.Nested(ConstraintSchema, default=None, missing=None)
    bias_constraint = fields.Nested(ConstraintSchema, default=None, missing=None)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return LocallyConnected1DConfig
예제 #9
0
class BaseModelSchema(BaseSchema):
    """Schema for a model config: graph plus optional loss/optimizer/metrics.

    Every field except ``graph`` explicitly accepts ``None``.
    """

    graph = fields.Nested(GraphSchema)
    loss = fields.Nested(
        LossSchema,
        allow_none=True,
    )
    optimizer = fields.Nested(
        OptimizerSchema,
        allow_none=True,
    )
    metrics = fields.Nested(
        MetricSchema,
        many=True,
        allow_none=True,
    )
    summaries = ObjectOrListObject(
        fields.Str,
        allow_none=True,
    )
    clip_gradients = fields.Float(allow_none=True)
    clip_embed_gradients = fields.Float(allow_none=True)
    name = fields.Str(allow_none=True)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return BaseModelConfig
class PReLUSchema(BaseLayerSchema):
    """Schema for a PReLU layer config (``PReLUConfig``).

    All fields default to ``None`` when absent.
    """

    alpha_initializer = fields.Nested(InitializerSchema, default=None, missing=None)
    alpha_regularizer = fields.Nested(RegularizerSchema, default=None, missing=None)
    alpha_constraint = fields.Nested(ConstraintSchema, default=None, missing=None)
    # An int or a list of ints; presumably the axes alpha is shared along — confirm.
    shared_axes = ObjectOrListObject(fields.Int, default=None, missing=None)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return PReLUConfig
예제 #11
0
class RunSchema(BaseSchema):
    """Schema for a run config holding ``cmd`` (a string or list of strings)."""

    cmd = ObjectOrListObject(
        fields.Str,
    )

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return RunConfig
예제 #12
0
class FinalOpsHookSchema(BaseSchema):
    """Schema for the final-ops hook config (``FinalOpsHookConfig``)."""

    # A single op name or a list of op names.
    final_ops = ObjectOrListObject(
        fields.Str,
    )

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        return FinalOpsHookConfig
예제 #13
0
class BaseBridgeSchema(BaseSchema):
    state_size = ObjectOrListObject(fields.Int, allow_none=True)
    name = fields.Str(allow_none=True)