Exemplo n.º 1
0
def test_set_parameters():
    """Check that set_parameters copies a replacement parameter's values into the model."""
    param_name = 'source_target_embed_weight'
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))
    replacement = mx.gluon.Parameter(param_name, shape=(20, 4))
    replacement.initialize(init='xavier', ctx=mx.cpu(0))
    model.set_parameters({param_name: replacement})
    updated = model.params[param_name].data()
    assert mx.test_utils.same(updated, replacement.data())
Exemplo n.º 2
0
def test_set_parameters_shape():
    """Check that set_parameters rejects a parameter whose shape differs from the model's."""
    param_name = 'source_target_embed_weight'
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))
    # (10, 10) deliberately differs from the (20, 4) shape named in the expected message.
    mismatched = mx.gluon.Parameter(param_name, shape=(10, 10))
    mismatched.initialize(init='xavier', ctx=mx.cpu(0))
    with pytest.raises(AssertionError) as exc_info:
        model.set_parameters({param_name: mismatched})
    expected = ("Parameter 'source_target_embed_weight' has shape '(20, 4)' in the model but shape "
                "'(10, 10)' in the new_params dictionary.")
    assert str(exc_info.value) == expected
Exemplo n.º 3
0
def test_set_parameters_allow_missing():
    """Check that an empty new_params dict is accepted only when allow_missing=True."""
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))
    # Tolerated case: the parameter is still registered in model.params afterwards.
    model.set_parameters({}, allow_missing=True)
    assert 'source_target_embed_weight' in model.params
    # Strict case: the missing parameter is reported through an AssertionError.
    with pytest.raises(AssertionError) as exc_info:
        model.set_parameters({}, allow_missing=False)
    expected = ("Parameter 'source_target_embed_weight' is missing in new_params dictionary. "
                "Set allow_missing=True to ignore missing parameters.")
    assert str(exc_info.value) == expected
Exemplo n.º 4
0
def test_set_parameters_allow_missing():
    """Check that an empty new_params dict is accepted only when allow_missing=True.

    The test is skipped when mxnet is not installed.
    """
    mx = pytest.importorskip('mxnet')
    param_name = 'embedding_source.weight'
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))
    # Tolerated case: the parameter remains visible via collect_params().
    model.set_parameters({}, allow_missing=True)
    assert param_name in model.collect_params()
    # Strict case: the missing parameter is reported through an AssertionError.
    with pytest.raises(AssertionError) as exc_info:
        model.set_parameters({}, allow_missing=False)
    expected = ("Parameter 'embedding_source.weight' is missing in new_params dictionary. "
                "Set allow_missing=True to ignore missing parameters.")
    assert str(exc_info.value) == expected
Exemplo n.º 5
0
def test_set_parameters_context():
    """Check that a parameter living on cpu(2) is copied to both of the model's contexts."""
    param_name = 'source_target_embed_weight'
    model = mock_model()
    model.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
    source = mx.gluon.Parameter(param_name, shape=(20, 4))
    # The source parameter is on a context the model itself was not initialized on.
    source.initialize(init='xavier', ctx=mx.cpu(2))
    model.set_parameters({param_name: source})
    for device_id in (0, 1):
        on_model = model.params[param_name].data(mx.cpu(device_id))
        assert mx.test_utils.same(on_model, source.data(mx.cpu(2)))
Exemplo n.º 6
0
def test_set_parameters_context():
    """Check that a parameter living on cpu(2) is copied to both of the model's contexts.

    The test is skipped when mxnet is not installed. Values are compared as NumPy arrays.
    """
    mx = pytest.importorskip('mxnet')
    param_name = 'embedding_source.weight'
    model = mock_model()
    model.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
    source = mx.gluon.Parameter(param_name, shape=(20, 4))
    source.initialize(init='xavier', ctx=mx.cpu(2))
    model.set_parameters({param_name: source})
    for device_id in (0, 1):
        on_model = model.collect_params()[param_name].data(mx.cpu(device_id)).asnumpy()
        assert mx.test_utils.same(on_model, source.data(mx.cpu(2)).asnumpy())
Exemplo n.º 7
0
def test_set_parameters():
    """Check that setting 'output_layer.weight' also updates both embedding weights.

    The embedding weights are expected to match as well, presumably because mock_model
    ties these three weights together (TODO confirm in mock_model). The test is skipped
    when mxnet is not installed.
    """
    mx = pytest.importorskip('mxnet')
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))
    model_params = model.collect_params()
    new_weight = mx.gluon.Parameter('output_layer.weight', shape=(20, 4))
    new_weight.initialize(init='xavier', ctx=mx.cpu(0))
    model.set_parameters({new_weight.name: new_weight})
    for key in ('output_layer.weight', 'embedding_source.weight', 'embedding_target.weight'):
        assert mx.test_utils.same(model_params[key].data(), new_weight.data())
Exemplo n.º 8
0
def test_set_parameters_uninitialized():
    """Check that set_parameters requires both sides of the copy to be initialized."""
    param_name = 'source_target_embed_weight'
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))
    incoming = mx.gluon.Parameter(param_name, shape=(20, 4))

    # Case 1: the incoming parameter has not been initialized yet.
    with pytest.raises(AssertionError) as exc_info:
        model.set_parameters({param_name: incoming})
    assert str(exc_info.value) == \
        "Parameter 'source_target_embed_weight' is not initialized in new_params dictionary."

    # Case 2: the incoming parameter is initialized, but a freshly built model is not.
    incoming.initialize(init='xavier', ctx=mx.cpu(0))
    model = mock_model()
    with pytest.raises(AssertionError) as exc_info:
        model.set_parameters({param_name: incoming})
    assert str(exc_info.value) == ("Parameter 'source_target_embed_weight' must be initialized before it "
                                   "can be reset using set_parameters.")
Exemplo n.º 9
0
def test_set_parameters_ignore_extra():
    """Check that unknown keys in new_params are skipped with ignore_extra=True and rejected otherwise."""
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))
    known = mx.gluon.Parameter('source_target_embed_weight', shape=(20, 4))
    known.initialize(init='xavier', ctx=mx.cpu(0))
    extra = mx.gluon.Parameter('q', shape=(1, 1))
    extra.initialize(init='xavier', ctx=mx.cpu(0))
    new_params = {'source_target_embed_weight': known, 'q': extra}

    model.set_parameters(new_params, ignore_extra=True)
    assert 'source_target_embed_weight' in model.params
    assert 'q' not in model.params

    with pytest.raises(ValueError) as exc_info:
        model.set_parameters(new_params, ignore_extra=False)
    # NOTE: "preset" (sic) is kept because the comparison with the raised message is exact.
    expected = ("Parameter 'q' in new_params dictionary is not preset in ParameterDict. "
                "Set ignore_extra=True to ignore.")
    assert str(exc_info.value) == expected
Exemplo n.º 10
0
def test_set_parameters_ignore_extra():
    """Check that unknown keys in new_params are skipped with ignore_extra=True and rejected otherwise.

    The test is skipped when mxnet is not installed.
    """
    mx = pytest.importorskip('mxnet')
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))
    known = mx.gluon.Parameter('embedding_source.weight', shape=(20, 4))
    known.initialize(init='xavier', ctx=mx.cpu(0))
    extra = mx.gluon.Parameter('q', shape=(1, 1))
    extra.initialize(init='xavier', ctx=mx.cpu(0))
    new_params = {'embedding_source.weight': known, 'q': extra}

    model.set_parameters(new_params, ignore_extra=True)
    assert 'embedding_source.weight' in model.collect_params()
    assert 'q' not in model.collect_params()

    with pytest.raises(ValueError) as exc_info:
        model.set_parameters(new_params, ignore_extra=False)
    # NOTE: "preset" (sic) is kept because the comparison with the raised message is exact.
    expected = ("Parameter 'q' in new_params dictionary is not preset in ParameterDict. "
                "Set ignore_extra=True to ignore.")
    assert str(exc_info.value) == expected