Пример #1
0
def test_identity(tensor_shape):
    """Exercise ``initializers.identity()`` on ``tensor_shape`` via ``_runner``.

    Shapes with more than two dimensions are expected to make ``_runner``
    raise ``ValueError``. Other shapes are run with ``target_max=1.`` and a
    ``target_mean`` equal to the mean of an identity matrix of that shape.
    """
    # A rows x cols identity matrix holds min(rows, cols) ones among
    # rows * cols entries, so its mean is 1 / max(rows, cols).
    # Using ``1. / tensor_shape[0]`` alone is only right when rows >= cols.
    # For a rank-1 shape this reduces to ``1. / tensor_shape[0]``.
    target_mean = 1. / max(tensor_shape[:2])
    if len(tensor_shape) > 2:
        with pytest.raises(ValueError):
            _runner(initializers.identity(), tensor_shape,
                    target_mean=target_mean, target_max=1.)
    else:
        _runner(initializers.identity(), tensor_shape,
                target_mean=target_mean, target_max=1.)
Пример #2
0
def test_identity(tensor_shape):
    """Exercise ``initializers.identity()`` on ``tensor_shape`` via ``_runner``.

    Shapes with more than two dimensions are expected to make ``_runner``
    raise. Other shapes are run with ``target_max=1.`` and a ``target_mean``
    equal to the mean of an identity matrix of that shape.
    """
    # A rows x cols identity matrix holds min(rows, cols) ones among
    # rows * cols entries, so its mean is 1 / max(rows, cols).
    # Using ``1. / tensor_shape[0]`` alone is only right when rows >= cols.
    # For a rank-1 shape this reduces to ``1. / tensor_shape[0]``.
    target_mean = 1. / max(tensor_shape[:2])
    if len(tensor_shape) > 2:
        # NOTE(review): ``Exception`` is very broad; presumably the targeted
        # initializer raises a plain ``Exception``. Narrow this to the
        # concrete type if it raises something more specific.
        with pytest.raises(Exception):
            _runner(initializers.identity(), tensor_shape,
                    target_mean=target_mean, target_max=1.)
    else:
        _runner(initializers.identity(), tensor_shape,
                target_mean=target_mean, target_max=1.)
Пример #3
0
def test_identity(tensor_shape):
    """Run ``initializers.identity()`` through ``_runner`` for ``tensor_shape``.

    A shape of rank greater than two must make ``_runner`` raise
    ``ValueError``. Any other shape is run with ``target_max=1.`` and the
    expected identity-matrix mean as ``target_mean``.
    """
    rows, cols = tensor_shape[0], tensor_shape[1]
    # For a 2D shape the identity matrix has min(rows, cols) ones spread
    # over rows * cols entries.
    expected_mean = (1. * min(tensor_shape)) / (rows * cols)
    run_kwargs = dict(target_mean=expected_mean, target_max=1.)

    if len(tensor_shape) <= 2:
        _runner(initializers.identity(), tensor_shape, **run_kwargs)
        return

    with pytest.raises(ValueError):
        _runner(initializers.identity(), tensor_shape, **run_kwargs)
def test_identity(tensor_shape):
    """Run ``initializers.identity()`` through ``_runner`` for ``tensor_shape``.

    ``ValueError`` is expected when the shape has more than two dimensions,
    or when its largest dimension is not a multiple of its smallest. Every
    other shape is run with ``target_mean=1. / tensor_shape[0]`` and
    ``target_max=1.``.
    """
    too_many_dims = len(tensor_shape) > 2
    # Short-circuit keeps the divisibility check from running on rank > 2.
    should_fail = (too_many_dims
                   or max(tensor_shape) % min(tensor_shape) != 0)

    if not should_fail:
        _runner(initializers.identity(),
                tensor_shape,
                target_mean=1. / tensor_shape[0],
                target_max=1.)
        return

    with pytest.raises(ValueError):
        _runner(initializers.identity(),
                tensor_shape,
                target_mean=1. / tensor_shape[0],
                target_max=1.)
Пример #5
0
def test_identity(tensor_shape):
    """Check ``initializers.identity()`` on ``tensor_shape`` using ``_runner``.

    Rank > 2 shapes must cause ``ValueError``. Lower-rank shapes are checked
    against the identity-matrix mean and a maximum value of 1.
    """
    n_rows = tensor_shape[0]
    n_cols = tensor_shape[1]
    # Identity matrix: min(shape) ones over n_rows * n_cols entries (2D case).
    target_mean = (1. * min(tensor_shape)) / (n_rows * n_cols)

    def run_identity():
        # Single place for the _runner call shared by both branches below.
        _runner(initializers.identity(),
                tensor_shape,
                target_mean=target_mean,
                target_max=1.)

    if len(tensor_shape) > 2:
        with pytest.raises(ValueError):
            run_identity()
    else:
        run_identity()
     initializers.truncated_normal(mean=0.2, stddev=0.003, seed=42),
     dict(class_name="truncated_normal", mean=0.2, stddev=0.003, seed=42),
     id="tn_1",
 ),
 pytest.param(
     initializers.Orthogonal(1.1),
     dict(class_name="orthogonal", gain=1.1, seed=None),
     id="o_0",
 ),
 pytest.param(
     initializers.orthogonal(gain=1.2, seed=42),
     dict(class_name="orthogonal", gain=1.2, seed=42),
     id="o_1",
 ),
 pytest.param(initializers.Identity(1.1), dict(class_name="identity", gain=1.1), id="i_0"),
 pytest.param(initializers.identity(), dict(class_name="identity", gain=1.0), id="i_1"),
 #################### VarianceScaling ####################
 pytest.param(
     initializers.glorot_normal(), dict(class_name="glorot_normal", seed=None), id="gn_0"
 ),
 pytest.param(
     initializers.glorot_uniform(42), dict(class_name="glorot_uniform", seed=42), id="gu_0"
 ),
 pytest.param(initializers.he_normal(), dict(class_name="he_normal", seed=None), id="hn_0"),
 pytest.param(
     initializers.he_uniform(42), dict(class_name="he_uniform", seed=42), id="hu_0"
 ),
 pytest.param(
     initializers.lecun_normal(), dict(class_name="lecun_normal", seed=None), id="ln_0"
 ),
 pytest.param(
     id="tn_1",
 ),
 pytest.param(
     initializers.Orthogonal(1.1),
     dict(class_name="orthogonal", gain=1.1, seed=None),
     id="o_0",
 ),
 pytest.param(
     initializers.orthogonal(gain=1.2, seed=42),
     dict(class_name="orthogonal", gain=1.2, seed=42),
     id="o_1",
 ),
 pytest.param(initializers.Identity(1.1),
              dict(class_name="identity", gain=1.1),
              id="i_0"),
 pytest.param(initializers.identity(),
              dict(class_name="identity", gain=1.0),
              id="i_1"),
 #################### VarianceScaling ####################
 pytest.param(initializers.glorot_normal(),
              dict(class_name="glorot_normal", seed=None),
              id="gn_0"),
 pytest.param(initializers.glorot_uniform(42),
              dict(class_name="glorot_uniform", seed=42),
              id="gu_0"),
 pytest.param(initializers.he_normal(),
              dict(class_name="he_normal", seed=None),
              id="hn_0"),
 pytest.param(initializers.he_uniform(42),
              dict(class_name="he_uniform", seed=42),
              id="hu_0"),