def test_set_parameters():
    """After set_parameters, the model's parameter holds the same data as the supplied one."""
    model = mock_model()
    cpu = mx.cpu(0)
    model.initialize(init='xavier', ctx=cpu)

    key = 'source_target_embed_weight'
    replacement = mx.gluon.Parameter(key, shape=(20, 4))
    replacement.initialize(init='xavier', ctx=cpu)

    model.set_parameters({key: replacement})

    assert mx.test_utils.same(model.params[key].data(), replacement.data())
def test_set_parameters_shape():
    """set_parameters raises AssertionError when the new parameter's shape differs from the model's."""
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))

    mismatched = mx.gluon.Parameter('source_target_embed_weight', shape=(10, 10))
    mismatched.initialize(init='xavier', ctx=mx.cpu(0))

    expected_msg = ("Parameter 'source_target_embed_weight' has shape '(20, 4)' in the model but shape "
                    "'(10, 10)' in the new_params dictionary.")
    with pytest.raises(AssertionError) as excinfo:
        model.set_parameters({'source_target_embed_weight': mismatched})
    assert str(excinfo.value) == expected_msg
def test_set_parameters_allow_missing():
    """An empty new_params dict is accepted with allow_missing=True and rejected otherwise."""
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))

    # Permissive mode: nothing supplied, the model keeps its own parameter.
    model.set_parameters({}, allow_missing=True)
    assert 'source_target_embed_weight' in model.params

    # Strict mode: the missing parameter triggers an AssertionError.
    expected_msg = ("Parameter 'source_target_embed_weight' is missing in new_params dictionary. "
                    "Set allow_missing=True to ignore missing parameters.")
    with pytest.raises(AssertionError) as excinfo:
        model.set_parameters({}, allow_missing=False)
    assert str(excinfo.value) == expected_msg
def test_set_parameters_allow_missing_pt():
    """PyTorch model: an empty new_params dict is accepted only with allow_missing=True.

    With allow_missing=True the model must still expose its source embedding weight afterwards;
    with allow_missing=False an AssertionError naming the missing parameter is expected.
    """
    model = mock_model_pt()
    sockeye.model_pt.initialize_parameters(model)
    model.set_parameters({}, allow_missing=True)
    # Query the parameters *after* the call: a dict snapshot taken beforehand keeps its keys
    # regardless of what set_parameters does, so it could never make this assertion fail.
    assert 'embedding_source.embedding.weight' in dict(model.named_parameters())
    with pytest.raises(AssertionError) as e:
        model.set_parameters({}, allow_missing=False)
    assert str(e.value) == "Parameter 'embedding_source.embedding.weight' is missing in new_params dictionary. " \
                           "Set allow_missing=True to ignore missing parameters."
def test_set_parameters_allow_missing():
    """MXNet model: empty new_params passes with allow_missing=True, fails with allow_missing=False."""
    mx = pytest.importorskip('mxnet')
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))

    weight_key = 'embedding_source.weight'
    model.set_parameters({}, allow_missing=True)
    assert weight_key in model.collect_params()

    expected_msg = ("Parameter 'embedding_source.weight' is missing in new_params dictionary. "
                    "Set allow_missing=True to ignore missing parameters.")
    with pytest.raises(AssertionError) as excinfo:
        model.set_parameters({}, allow_missing=False)
    assert str(excinfo.value) == expected_msg
def test_set_parameters_context():
    """A parameter living on cpu(2) is written to the model copies on cpu(0) and cpu(1)."""
    model = mock_model()
    model.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])

    source_ctx = mx.cpu(2)
    replacement = mx.gluon.Parameter('source_target_embed_weight', shape=(20, 4))
    replacement.initialize(init='xavier', ctx=source_ctx)

    model.set_parameters({'source_target_embed_weight': replacement})

    # Every context the model was initialized on must now match the source data.
    for target_ctx in (mx.cpu(0), mx.cpu(1)):
        assert mx.test_utils.same(
            model.params['source_target_embed_weight'].data(target_ctx),
            replacement.data(source_ctx))
def test_set_parameters_pt():
    """PyTorch model: set_parameters writes the supplied values into the named parameter.

    The parameter handle is taken *before* the call, so the check also confirms that the
    existing parameter object (not just a replacement) ends up holding the new values.
    """
    model = mock_model_pt()
    sockeye.model_pt.initialize_parameters(model)
    model_params = dict(model.named_parameters())
    param = pt.nn.Parameter(pt.ones(20, 4))
    name = 'output_layer.weight'
    model.set_parameters({name: param})
    # torch.testing.assert_allclose is deprecated; assert_close with check_dtype=False keeps
    # the old helper's lenient dtype comparison while using the same default tolerances.
    pt.testing.assert_close(model_params[name].data, param.data, check_dtype=False)
def test_set_parameters_context():
    """MXNet model on two CPU contexts receives a parameter initialized on a third CPU context."""
    mx = pytest.importorskip('mxnet')
    model = mock_model()
    model.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])

    key = 'embedding_source.weight'
    replacement = mx.gluon.Parameter(key, shape=(20, 4))
    replacement.initialize(init='xavier', ctx=mx.cpu(2))

    model.set_parameters({key: replacement})

    for device_id in range(2):
        model_copy = model.collect_params()[key].data(mx.cpu(device_id)).asnumpy()
        source_copy = replacement.data(mx.cpu(2)).asnumpy()
        assert mx.test_utils.same(model_copy, source_copy)
def test_set_parameters():
    """Setting output_layer.weight is reflected in the output and both embedding weights.

    All three parameters are expected to equal the supplied data afterwards.
    """
    mx = pytest.importorskip('mxnet')
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))
    model_params = model.collect_params()

    replacement = mx.gluon.Parameter('output_layer.weight', shape=(20, 4))
    replacement.initialize(init='xavier', ctx=mx.cpu(0))
    model.set_parameters({replacement.name: replacement})

    for key in ('output_layer.weight', 'embedding_source.weight', 'embedding_target.weight'):
        assert mx.test_utils.same(model_params[key].data(), replacement.data())
def test_set_parameters_uninitialized():
    """set_parameters requires both the new parameter and the model to be initialized."""
    key = 'source_target_embed_weight'

    # Case 1: initialized model, uninitialized incoming parameter.
    initialized_model = mock_model()
    initialized_model.initialize(init='xavier', ctx=mx.cpu(0))
    incoming = mx.gluon.Parameter(key, shape=(20, 4))
    with pytest.raises(AssertionError) as excinfo:
        initialized_model.set_parameters({key: incoming})
    assert str(excinfo.value) == \
        "Parameter 'source_target_embed_weight' is not initialized in new_params dictionary."

    # Case 2: initialized incoming parameter, but the model itself was never initialized.
    incoming.initialize(init='xavier', ctx=mx.cpu(0))
    fresh_model = mock_model()
    with pytest.raises(AssertionError) as excinfo:
        fresh_model.set_parameters({key: incoming})
    assert str(excinfo.value) == ("Parameter 'source_target_embed_weight' must be initialized before it can be "
                                  "reset using set_parameters.")
def test_set_parameters_ignore_extra():
    """Unknown keys in new_params are skipped with ignore_extra=True and raise ValueError otherwise."""
    model = mock_model()
    model.initialize(init='xavier', ctx=mx.cpu(0))

    known = mx.gluon.Parameter('source_target_embed_weight', shape=(20, 4))
    known.initialize(init='xavier', ctx=mx.cpu(0))
    extra = mx.gluon.Parameter('q', shape=(1, 1))
    extra.initialize(init='xavier', ctx=mx.cpu(0))
    new_params = {'source_target_embed_weight': known, 'q': extra}

    model.set_parameters(new_params, ignore_extra=True)
    assert 'source_target_embed_weight' in model.params
    assert 'q' not in model.params

    # NOTE: the expected message spells "preset"; it must match the implementation verbatim.
    expected_msg = ("Parameter 'q' in new_params dictionary is not preset in ParameterDict. "
                    "Set ignore_extra=True to ignore.")
    with pytest.raises(ValueError) as excinfo:
        model.set_parameters(new_params, ignore_extra=False)
    assert str(excinfo.value) == expected_msg
def test_set_parameters_ignore_extra():
    """MXNet model: extra keys are ignored with ignore_extra=True and raise ValueError otherwise."""
    mx = pytest.importorskip('mxnet')
    model = mock_model()
    cpu = mx.cpu(0)
    model.initialize(init='xavier', ctx=cpu)

    known = mx.gluon.Parameter('embedding_source.weight', shape=(20, 4))
    known.initialize(init='xavier', ctx=cpu)
    extra = mx.gluon.Parameter('q', shape=(1, 1))
    extra.initialize(init='xavier', ctx=cpu)
    new_params = {'embedding_source.weight': known, 'q': extra}

    model.set_parameters(new_params, ignore_extra=True)
    assert 'embedding_source.weight' in model.collect_params()
    assert 'q' not in model.collect_params()

    # NOTE: "preset" (sic) is the implementation's wording and must be matched exactly.
    expected_msg = ("Parameter 'q' in new_params dictionary is not preset in ParameterDict. "
                    "Set ignore_extra=True to ignore.")
    with pytest.raises(ValueError) as excinfo:
        model.set_parameters(new_params, ignore_extra=False)
    assert str(excinfo.value) == expected_msg
def test_set_parameters_ignore_extra_pt():
    """PyTorch model: unknown keys are skipped with ignore_extra=True and raise ValueError otherwise.

    After the permissive call the model must still expose the known parameter and must not
    have gained a parameter for the extra key.
    """
    model = mock_model_pt()
    sockeye.model_pt.initialize_parameters(model)
    known_name = 'embedding_source.embedding.weight'
    extra_name = 'q'  # not a parameter of the model
    params = {known_name: pt.nn.Parameter(pt.ones(20, 4)),
              extra_name: pt.nn.Parameter(pt.zeros(1, 1))}
    model.set_parameters(params, ignore_extra=True)
    # Collect the parameters *after* set_parameters; a snapshot taken before the call would
    # keep its original keys, making both membership checks pass unconditionally.
    model_params = dict(model.named_parameters())
    assert known_name in model_params
    assert extra_name not in model_params
    with pytest.raises(ValueError) as e:
        model.set_parameters(params, ignore_extra=False)
    assert str(e.value) == "Parameter 'q' in new_params dictionary is not present in ParameterDict. " \
                           "Set ignore_extra=True to ignore."