Ejemplo n.º 1
0
def default_recurisive_init(custom_cell):
    """Re-initialize the weights and biases of Conv2d and Dense sub-cells.

    Walks every entry yielded by ``custom_cell.cells_and_names()``. For each
    ``nn.Conv2d`` or ``nn.Dense`` cell the weight data is replaced with values
    drawn from ``KaimingUniform(a=sqrt(5))``. When the cell has a bias, its
    data is replaced using ``init.Uniform(1 / sqrt(fan_in))``, where ``fan_in``
    is the first value returned by ``_calculate_in_and_out`` for the weight.
    Every other cell, batch-norm layers included, is left untouched.

    Args:
        custom_cell: Object exposing ``cells_and_names()``, which yields
            ``(name, cell)`` pairs.
    """
    for _, layer in custom_cell.cells_and_names():
        # Conv2d and Dense share one scheme; anything else keeps its parameters.
        if not isinstance(layer, (nn.Conv2d, nn.Dense)):
            continue

        fresh_weight = init.initializer(
            KaimingUniform(a=math.sqrt(5)),
            layer.weight.shape,
            layer.weight.dtype,
        )
        layer.weight.set_data(fresh_weight)

        if layer.bias is None:
            continue

        # The bias range is derived from the (already re-initialized) weight.
        fan_in, _fan_out = _calculate_in_and_out(layer.weight)
        limit = 1 / math.sqrt(fan_in)
        fresh_bias = init.initializer(
            init.Uniform(limit),
            layer.bias.shape,
            layer.bias.dtype,
        )
        layer.bias.set_data(fresh_bias)
 def __init__(self):
     """Create one parameter per initializer variant under test.

     Each parameter is built from ``init.initializer(<spec>, parameter_shape)``
     and stored as ``self.para_<label>`` with ``name=<label>``. The specs are
     either string aliases ('xavier_uniform', 'he_uniform') or initializer
     instances constructed with their default arguments.
     """
     super(ParameterNet, self).__init__()

     def _make(spec, label):
         # Single place where an initializer spec becomes a named Parameter.
         return Parameter(init.initializer(spec, parameter_shape), name=label)

     # Initializers selected by string alias.
     self.para_xavier_uniform = _make('xavier_uniform', "xavier_uniform")
     self.para_he_uniform = _make('he_uniform', "he_uniform")

     # The same two initializers, selected by passing an instance.
     self.para_xavier_uniform2 = _make(init.XavierUniform(), "xavier_uniform2")
     self.para_he_uniform2 = _make(init.HeUniform(), "he_uniform2")

     # Remaining initializer classes, all with default arguments.
     self.para_truncated_normal = _make(init.TruncatedNormal(), "truncated_normal")
     self.para_normal = _make(init.Normal(), "normal")
     self.para_uniform = _make(init.Uniform(), "uniform")
Ejemplo n.º 3
0
def test_init_uniform():
    """Check a Uniform(scale=10) tensor against the bounds -10 and 10.

    Builds a float32 tensor of shape [5, 4] with ``init.Uniform`` and passes
    it to ``_check_value`` together with ``-scale`` and ``scale``.
    """
    limit = 10
    shape = [5, 4]
    sampled = init.initializer(init.Uniform(scale=limit), shape, ms.float32)
    _check_value(sampled, -limit, limit)