# Example no. 1
0
# Register torch distribution classes through config.external_configurable.
# Each class is paired with the dotted path of the submodule it is read
# from, which is forwarded as the `module` keyword argument.
for _dist_cls, _dist_module in (
    (torch.distributions.pareto.Pareto, 'torch.distributions.pareto'),
    (torch.distributions.poisson.Poisson, 'torch.distributions.poisson'),
    (torch.distributions.relaxed_bernoulli.LogitRelaxedBernoulli,
     'torch.distributions.relaxed_bernoulli'),
    (torch.distributions.relaxed_bernoulli.RelaxedBernoulli,
     'torch.distributions.relaxed_bernoulli'),
    (torch.distributions.relaxed_categorical.RelaxedOneHotCategorical,
     'torch.distributions.relaxed_categorical'),
    (torch.distributions.studentT.StudentT, 'torch.distributions.studentT'),
    (torch.distributions.uniform.Uniform, 'torch.distributions.uniform'),
    (torch.distributions.weibull.Weibull, 'torch.distributions.weibull'),
):
    config.external_configurable(_dist_cls, module=_dist_module)
# Drop the loop variables so they do not linger in the module namespace.
del _dist_cls, _dist_module

# Constants

# Register each torch dtype with config.constant under its dotted name.
for _torch_dtype_key, _torch_dtype in (
    ('torch.float16', torch.float16),
    ('torch.float32', torch.float32),
    ('torch.float64', torch.float64),
    ('torch.int8', torch.int8),
    ('torch.uint8', torch.uint8),
    ('torch.int16', torch.int16),
    ('torch.int32', torch.int32),
    ('torch.int64', torch.int64),
):
    config.constant(_torch_dtype_key, _torch_dtype)
# Drop the loop variables so they do not linger in the module namespace.
del _torch_dtype_key, _torch_dtype
# Example no. 2
0
# Stateless random ops. Each function is registered with its fully
# qualified dotted path passed as the second positional argument.
for _random_fn, _random_fn_path in (
    (tf.random.stateless_categorical, 'tf.random.stateless_categorical'),
    (tf.random.stateless_normal, 'tf.random.stateless_normal'),
    (tf.random.stateless_truncated_normal,
     'tf.random.stateless_truncated_normal'),
    (tf.random.stateless_uniform, 'tf.random.stateless_uniform'),
):
    config.external_configurable(_random_fn, _random_fn_path)
# Drop the loop variables so they do not linger in the module namespace.
del _random_fn, _random_fn_path

# Distribution strategies.
config.external_configurable(
    tf.contrib.distribute.MirroredStrategy,
    module='tf.contrib.distribute',
)

# Constants

# Register each TensorFlow dtype with config.constant under its dotted name.
for _tf_dtype_key, _tf_dtype in (
    ('tf.float16', tf.float16),
    ('tf.float32', tf.float32),
    ('tf.float64', tf.float64),
    ('tf.bfloat16', tf.bfloat16),
    ('tf.complex64', tf.complex64),
    ('tf.complex128', tf.complex128),
    ('tf.int8', tf.int8),
    ('tf.uint8', tf.uint8),
    ('tf.uint16', tf.uint16),
    ('tf.int16', tf.int16),
    ('tf.int32', tf.int32),
    ('tf.int64', tf.int64),
    ('tf.bool', tf.bool),
    ('tf.string', tf.string),
    ('tf.qint8', tf.qint8),
    ('tf.quint8', tf.quint8),
):
    config.constant(_tf_dtype_key, _tf_dtype)
# Drop the loop variables so they do not linger in the module namespace.
del _tf_dtype_key, _tf_dtype
# Register `nn` layer classes through config.external_configurable. Each
# class gets a short lowercase alias as the second positional argument and
# is grouped under the 'T.nn' module.
for _layer_cls, _layer_alias in (
    (nn.MultiheadAttention, 'multihead_att'),
    (nn.PReLU, 'prelu'),
    (nn.ReLU, 'relu'),
    (nn.RReLU, 'rrelu'),
    (nn.SELU, 'selu'),
    (nn.CELU, 'celu'),
    (nn.Sigmoid, 'sigmoid'),
    (nn.Softplus, 'softplus'),
    (nn.Softshrink, 'softshrink'),
    (nn.Softsign, 'softsign'),
    (nn.Tanh, 'tanh'),
    (nn.Tanhshrink, 'tanhshrink'),
    (nn.Threshold, 'threshold'),
):
    config.external_configurable(_layer_cls, _layer_alias, module='T.nn')
# Drop the loop variables so they do not linger in the module namespace.
del _layer_cls, _layer_alias

# constants
# Register each dtype attribute of `T` with config.constant under its bare
# (unprefixed) name.
for _dtype_key, _dtype_value in (
    ('float16', T.float16),
    ('float32', T.float32),
    ('float64', T.float64),
    ('int8', T.int8),
    ('int16', T.int16),
    ('int32', T.int32),
    ('int64', T.int64),
    ('complex32', T.complex32),
    ('complex64', T.complex64),
    ('complex128', T.complex128),
    ('float', T.float),
    ('short', T.short),
    ('long', T.long),
    ('half', T.half),
    ('uint8', T.uint8),
    ('int', T.int),
):
    config.constant(_dtype_key, _dtype_value)
# Drop the loop variables so they do not linger in the module namespace.
del _dtype_key, _dtype_value