class ConstantInitializerSchema(BaseSchema):
    """Schema for `ConstantInitializerConfig`.

    Fields (all may be ``None``):
        value: the constant used to fill the initialized values.
        dtype: data type of the initialized values.
    """
    # `Number` rather than `Int`: a non-strict integer field converts through
    # `int()`, which silently truncates fractional constants (0.5 -> 0).
    value = fields.Number(allow_none=True)
    dtype = DType(allow_none=True)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema.

        Presumably consumed by `BaseSchema` to build the config on load.
        """
        return ConstantInitializerConfig
示例#2
0
class BaseLayerSchema(Schema):
    """Fields shared by every layer schema.

    ``trainable`` falls back to ``True`` when it is absent, both when
    dumping an object and when loading input data.
    """
    name = fields.Str(allow_none=True)
    trainable = fields.Bool(missing=True, default=True)
    dtype = DType(allow_none=True)
    inbound_nodes = ObjectOrListObject(Tensor, allow_none=True)

    def get_attribute(self, attr, obj, default):
        """Look up ``attr`` on ``obj`` through ``get_value``, using ``default`` as the fallback."""
        resolved = get_value(attr, obj, default)
        return resolved
示例#3
0
class ConvertImagesDtypeSchema(BaseLayerSchema):
    """Schema for `ConvertImagesDtypeConfig`.

    Redeclares the inherited ``dtype`` without ``allow_none`` and adds an
    optional ``saturate`` flag.
    """
    dtype = DType()
    saturate = fields.Bool(allow_none=True)
    name = fields.Str(allow_none=True)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = ConvertImagesDtypeConfig
        return config_cls
示例#4
0
class ZerosInitializerSchema(Schema):
    """Schema that loads input data into a `ZerosInitializerConfig`."""
    dtype = DType(allow_none=True)

    class Meta:
        # Emit fields in declaration order.
        ordered = True

    @post_load
    def make_load(self, data, **kwargs):
        """Build a `ZerosInitializerConfig` from the deserialized dict.

        ``**kwargs`` absorbs the ``many``/``partial`` keywords that
        marshmallow 3 passes to processing hooks. Marshmallow 2 passes none,
        so the parameter is harmless there.
        """
        return ZerosInitializerConfig(**data)
示例#5
0
class CastSchema(BaseLayerSchema):
    """Schema that loads input data into a `CastConfig`.

    Redeclares the inherited ``dtype`` without ``allow_none``.
    """
    dtype = DType()

    class Meta:
        # Emit fields in declaration order.
        ordered = True

    @post_load
    def make_load(self, data, **kwargs):
        """Build a `CastConfig` from the deserialized dict.

        ``**kwargs`` absorbs the ``many``/``partial`` keywords that
        marshmallow 3 passes to processing hooks. It is unused under
        marshmallow 2.
        """
        return CastConfig(**data)
class VarianceScalingInitializerSchema(BaseSchema):
    """Schema for `VarianceScalingInitializerConfig`.

    Every field accepts ``None``. A non-null ``mode`` must be one of
    ``'fan_in'``, ``'fan_out'`` or ``'fan_avg'``.
    """
    scale = fields.Float(allow_none=True)
    mode = fields.Str(
        validate=validate.OneOf(['fan_in', 'fan_out', 'fan_avg']),
        allow_none=True,
    )
    distribution = fields.Str(allow_none=True)
    dtype = DType(allow_none=True)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = VarianceScalingInitializerConfig
        return config_cls
class NormalInitializerSchema(BaseSchema):
    """Schema for `NormalInitializerConfig`.

    Fields (all may be ``None``): ``mean`` and ``stddev`` (numbers),
    ``dtype``, and an integer ``seed``.
    """
    mean = fields.Number(allow_none=True)
    stddev = fields.Number(allow_none=True)
    dtype = DType(allow_none=True)
    seed = fields.Int(allow_none=True)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = NormalInitializerConfig
        return config_cls
class UniformInitializerSchema(BaseSchema):
    """Schema for `UniformInitializerConfig`.

    Fields (all may be ``None``): ``minval`` and ``maxval`` (numbers),
    ``dtype``, and an integer ``seed``.
    """
    minval = fields.Number(allow_none=True)
    maxval = fields.Number(allow_none=True)
    dtype = DType(allow_none=True)
    seed = fields.Int(allow_none=True)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = UniformInitializerConfig
        return config_cls
示例#9
0
class ConstantInitializerSchema(Schema):
    """Schema that loads input data into a `ConstantInitializerConfig`.

    Fields (all may be ``None``):
        value: the constant used to fill the initialized values.
        dtype: data type of the initialized values.
    """
    # `Number` rather than `Int`: a non-strict integer field converts through
    # `int()`, which silently truncates fractional constants (0.5 -> 0).
    value = fields.Number(allow_none=True)
    dtype = DType(allow_none=True)

    class Meta:
        # Emit fields in declaration order.
        ordered = True

    @post_load
    def make_load(self, data, **kwargs):
        """Build a `ConstantInitializerConfig` from the deserialized dict.

        ``**kwargs`` absorbs the ``many``/``partial`` keywords that
        marshmallow 3 passes to processing hooks. It is unused under
        marshmallow 2.
        """
        return ConstantInitializerConfig(**data)
示例#10
0
class ConvertImagesDtypeSchema(BaseLayerSchema):
    """Schema that loads input data into a `ConvertImagesDtypeConfig`.

    Redeclares the inherited ``dtype`` without ``allow_none`` and adds an
    optional ``saturate`` flag.
    """
    dtype = DType()
    saturate = fields.Bool(allow_none=True)
    name = fields.Str(allow_none=True)

    class Meta:
        # Emit fields in declaration order.
        ordered = True

    @post_load
    def make_load(self, data, **kwargs):
        """Build a `ConvertImagesDtypeConfig` from the deserialized dict.

        ``**kwargs`` absorbs the ``many``/``partial`` keywords that
        marshmallow 3 passes to processing hooks. It is unused under
        marshmallow 2.
        """
        return ConvertImagesDtypeConfig(**data)
示例#11
0
class NormalInitializerSchema(Schema):
    """Schema that loads input data into a `NormalInitializerConfig`.

    Fields (all may be ``None``): ``mean`` and ``stddev`` (numbers),
    ``dtype``, and an integer ``seed``.
    """
    mean = fields.Number(allow_none=True)
    stddev = fields.Number(allow_none=True)
    dtype = DType(allow_none=True)
    seed = fields.Int(allow_none=True)

    class Meta:
        # Emit fields in declaration order.
        ordered = True

    @post_load
    def make_load(self, data, **kwargs):
        """Build a `NormalInitializerConfig` from the deserialized dict.

        ``**kwargs`` absorbs the ``many``/``partial`` keywords that
        marshmallow 3 passes to processing hooks. It is unused under
        marshmallow 2.
        """
        return NormalInitializerConfig(**data)
示例#12
0
class UniformInitializerSchema(Schema):
    """Schema that loads input data into a `UniformInitializerConfig`.

    Fields (all may be ``None``): ``minval`` and ``maxval`` (numbers),
    ``dtype``, and an integer ``seed``.
    """
    minval = fields.Number(allow_none=True)
    maxval = fields.Number(allow_none=True)
    dtype = DType(allow_none=True)
    seed = fields.Int(allow_none=True)

    class Meta:
        # Emit fields in declaration order.
        ordered = True

    @post_load
    def make_load(self, data, **kwargs):
        """Build a `UniformInitializerConfig` from the deserialized dict.

        ``**kwargs`` absorbs the ``many``/``partial`` keywords that
        marshmallow 3 passes to processing hooks. It is unused under
        marshmallow 2.
        """
        return UniformInitializerConfig(**data)
示例#13
0
class CastSchema(BaseLayerSchema):
    """Schema mapping between plain dicts and `CastConfig`.

    Redeclares the inherited ``dtype`` without ``allow_none``.
    """
    dtype = DType()

    class Meta:
        # Emit fields in declaration order.
        ordered = True

    @post_load
    def make(self, data, **kwargs):
        """Build a `CastConfig` from the deserialized dict.

        ``**kwargs`` absorbs the ``many``/``partial`` keywords that
        marshmallow 3 passes to processing hooks. It is unused under
        marshmallow 2.
        """
        return CastConfig(**data)

    @post_dump
    def unmake(self, data, **kwargs):
        """Post-process dumped output with `CastConfig.remove_reduced_attrs`.

        Presumably strips attributes the config treats as reducible; confirm
        against `CastConfig`. ``**kwargs`` absorbs marshmallow 3's ``many``
        keyword.
        """
        return CastConfig.remove_reduced_attrs(data)
class ZerosInitializerSchema(Schema):
    """Schema mapping between plain dicts and `ZerosInitializerConfig`."""
    dtype = DType(allow_none=True)

    class Meta:
        # Emit fields in declaration order.
        ordered = True

    @post_load
    def make(self, data, **kwargs):
        """Build a `ZerosInitializerConfig` from the deserialized dict.

        ``**kwargs`` absorbs the ``many``/``partial`` keywords that
        marshmallow 3 passes to processing hooks. It is unused under
        marshmallow 2.
        """
        return ZerosInitializerConfig(**data)

    @post_dump
    def unmake(self, data, **kwargs):
        """Post-process dumped output with `ZerosInitializerConfig.remove_reduced_attrs`.

        ``**kwargs`` absorbs marshmallow 3's ``many`` keyword.
        """
        return ZerosInitializerConfig.remove_reduced_attrs(data)
class ConstantInitializerSchema(Schema):
    """Schema mapping between plain dicts and `ConstantInitializerConfig`.

    Fields (all may be ``None``):
        value: the constant used to fill the initialized values.
        dtype: data type of the initialized values.
    """
    # `Number` rather than `Int`: a non-strict integer field converts through
    # `int()`, which silently truncates fractional constants (0.5 -> 0).
    value = fields.Number(allow_none=True)
    dtype = DType(allow_none=True)

    class Meta:
        # Emit fields in declaration order.
        ordered = True

    @post_load
    def make(self, data, **kwargs):
        """Build a `ConstantInitializerConfig` from the deserialized dict.

        ``**kwargs`` absorbs the ``many``/``partial`` keywords that
        marshmallow 3 passes to processing hooks. It is unused under
        marshmallow 2.
        """
        return ConstantInitializerConfig(**data)

    @post_dump
    def unmake(self, data, **kwargs):
        """Post-process dumped output with `ConstantInitializerConfig.remove_reduced_attrs`.

        ``**kwargs`` absorbs marshmallow 3's ``many`` keyword.
        """
        return ConstantInitializerConfig.remove_reduced_attrs(data)
示例#16
0
class VarianceScalingInitializerSchema(Schema):
    """Schema that loads input data into a `VarianceScalingInitializerConfig`.

    Every field accepts ``None``. A non-null ``mode`` must be one of
    ``'fan_in'``, ``'fan_out'`` or ``'fan_avg'``.
    """
    scale = fields.Float(allow_none=True)
    mode = fields.Str(allow_none=True,
                      validate=validate.OneOf(['fan_in', 'fan_out',
                                               'fan_avg']))
    distribution = fields.Str(allow_none=True)
    dtype = DType(allow_none=True)

    class Meta:
        # Emit fields in declaration order.
        ordered = True

    @post_load
    def make_load(self, data, **kwargs):
        """Build a `VarianceScalingInitializerConfig` from the deserialized dict.

        ``**kwargs`` absorbs the ``many``/``partial`` keywords that
        marshmallow 3 passes to processing hooks. It is unused under
        marshmallow 2.
        """
        return VarianceScalingInitializerConfig(**data)
示例#17
0
class ConvertImagesDtypeSchema(BaseLayerSchema):
    """Schema mapping between plain dicts and `ConvertImagesDtypeConfig`.

    Redeclares the inherited ``dtype`` without ``allow_none`` and adds an
    optional ``saturate`` flag.
    """
    dtype = DType()
    saturate = fields.Bool(allow_none=True)
    name = fields.Str(allow_none=True)

    class Meta:
        # Emit fields in declaration order.
        ordered = True

    @post_load
    def make(self, data, **kwargs):
        """Build a `ConvertImagesDtypeConfig` from the deserialized dict.

        ``**kwargs`` absorbs the ``many``/``partial`` keywords that
        marshmallow 3 passes to processing hooks. It is unused under
        marshmallow 2.
        """
        return ConvertImagesDtypeConfig(**data)

    @post_dump
    def unmake(self, data, **kwargs):
        """Post-process dumped output with `ConvertImagesDtypeConfig.remove_reduced_attrs`.

        ``**kwargs`` absorbs marshmallow 3's ``many`` keyword.
        """
        return ConvertImagesDtypeConfig.remove_reduced_attrs(data)
class NormalInitializerSchema(Schema):
    """Schema mapping between plain dicts and `NormalInitializerConfig`.

    Fields (all may be ``None``): ``mean`` and ``stddev`` (numbers),
    ``dtype``, and an integer ``seed``.
    """
    mean = fields.Number(allow_none=True)
    stddev = fields.Number(allow_none=True)
    dtype = DType(allow_none=True)
    seed = fields.Int(allow_none=True)

    class Meta:
        # Emit fields in declaration order.
        ordered = True

    @post_load
    def make(self, data, **kwargs):
        """Build a `NormalInitializerConfig` from the deserialized dict.

        ``**kwargs`` absorbs the ``many``/``partial`` keywords that
        marshmallow 3 passes to processing hooks. It is unused under
        marshmallow 2.
        """
        return NormalInitializerConfig(**data)

    @post_dump
    def unmake(self, data, **kwargs):
        """Post-process dumped output with `NormalInitializerConfig.remove_reduced_attrs`.

        ``**kwargs`` absorbs marshmallow 3's ``many`` keyword.
        """
        return NormalInitializerConfig.remove_reduced_attrs(data)
示例#19
0
class CastSchema(BaseLayerSchema):
    """Schema for `CastConfig`.

    Only redeclares ``dtype`` (without ``allow_none``). All other fields
    come from `BaseLayerSchema`.
    """
    dtype = DType()

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = CastConfig
        return config_cls
class OnesInitializerSchema(BaseSchema):
    """Schema for `OnesInitializerConfig`; its single field is an optional ``dtype``."""
    dtype = DType(allow_none=True)

    @staticmethod
    def schema_config():
        """Return the config class associated with this schema."""
        config_cls = OnesInitializerConfig
        return config_cls