def __init__(self):
    """Register one Parameter per initializer variant under test.

    Every Parameter is produced by ``init.initializer`` over the same shape,
    ``parameter_shape`` (presumably a module-level constant defined outside
    this view). The first two attributes use the string-name form of the
    initializer; the remaining ones pass an Initializer instance.
    """
    super(ParameterNet, self).__init__()

    def make_param(spec, label):
        # Initialize the data first, then wrap it in a named Parameter.
        return Parameter(init.initializer(spec, parameter_shape), name=label)

    # Initializers selected by their string name.
    self.para_xavier_uniform = make_param('xavier_uniform', "xavier_uniform")
    self.para_he_uniform = make_param('he_uniform', "he_uniform")
    # Initializers passed as instances; each is constructed right before use.
    self.para_xavier_uniform2 = make_param(init.XavierUniform(), "xavier_uniform2")
    self.para_he_uniform2 = make_param(init.HeUniform(), "he_uniform2")
    self.para_truncated_normal = make_param(init.TruncatedNormal(), "truncated_normal")
    self.para_normal = make_param(init.Normal(), "normal")
    self.para_uniform = make_param(init.Uniform(), "uniform")
def test_init_he_uniform():
    """Check that He-uniform initialized tensors stay within the He bound.

    Both call forms of ``init.initializer`` are covered: an ``init.HeUniform()``
    instance and the ``'he_uniform'`` string, each on a 2-D and a 4-D shape.
    The bound asserted is ``sqrt(2 / fan_in) * sqrt(3)``. Here ``fan_in`` is
    ``shape[1]`` times the product of all dimensions after the second (1 when
    there are none).
    """
    # List display evaluates left to right, so each initializer call happens
    # in the same order as four sequential assignments would.
    tensors = [
        init.initializer(init.HeUniform(), [20, 22], ms.float32),
        init.initializer(init.HeUniform(), [20, 22, 5, 5], ms.float32),
        init.initializer('he_uniform', [20, 22, 5, 5], ms.float32),
        init.initializer('he_uniform', [20, 22], ms.float32),
    ]
    for tensor in tensors:
        dims = tensor.asnumpy().shape
        # Product of trailing dims; the initial value 1 covers 2-D shapes.
        receptive_field = reduce(lambda acc, d: acc * d, dims[2:], 1)
        fan_in = dims[1] * receptive_field
        bound = math.sqrt(2 / fan_in) * math.sqrt(3)
        assert _check_uniform(tensor, -bound, bound)
def test_init_he_uniform_error():
    """Check that He-uniform initialization rejects a 1-D shape with ValueError.

    Presumably rejected because the fan-in calculation needs a second
    dimension (``shape[1]``), which a 1-D shape does not have.
    """
    one_dim_shape = [6]
    # Build the initializer inside the context so that any ValueError it
    # raises also counts.
    with py.raises(ValueError):
        init.initializer(init.HeUniform(), one_dim_shape, ms.float32)
step_size = dataset.get_dataset_size() print("step" + str(step_size)) # define net net = resnet(class_num=config.class_num) # init weight if args_opt.pre_trained: param_dict = load_checkpoint(args_opt.pre_trained) load_param_into_net(net, param_dict) else: for _, cell in net.cells_and_names(): if isinstance(cell, nn.Conv2d): cell.weight.set_data( weight_init.initializer(weight_init.HeUniform(), cell.weight.shape, cell.weight.dtype)) if isinstance(cell, nn.Dense): cell.weight.set_data( weight_init.initializer(weight_init.HeNormal(), cell.weight.shape, cell.weight.dtype)) # init lr lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,