def test_set_parameters():
    """Verify that set_parameters installs the values of a supplied parameter.

    A freshly initialized parameter named 'source_target_embed_weight' is
    passed to the model; afterwards the model's parameter of that name must
    hold the same data as the supplied one.
    """
    # NOTE(review): another definition named test_set_parameters appears later
    # in the visible source; if both live in one module, the later one shadows
    # this one.
    weight_name = 'source_target_embed_weight'

    net = mock_model()
    net.initialize(init='xavier', ctx=mx.cpu(0))

    replacement = mx.gluon.Parameter(weight_name, shape=(20, 4))
    replacement.initialize(init='xavier', ctx=mx.cpu(0))
    net.set_parameters({weight_name: replacement})

    stored = net.params[weight_name].data()
    assert mx.test_utils.same(stored, replacement.data())
def test_set_parameters_shape():
    """Verify that set_parameters rejects a parameter with a different shape.

    The supplied parameter has shape (10, 10); the expected error message
    states that the model holds shape (20, 4). An AssertionError with that
    exact message must be raised.
    """
    weight_name = 'source_target_embed_weight'

    net = mock_model()
    net.initialize(init='xavier', ctx=mx.cpu(0))

    wrong_shape = mx.gluon.Parameter(weight_name, shape=(10, 10))
    wrong_shape.initialize(init='xavier', ctx=mx.cpu(0))

    with pytest.raises(AssertionError) as excinfo:
        net.set_parameters({weight_name: wrong_shape})

    expected_message = (
        "Parameter 'source_target_embed_weight' has shape '(20, 4)' in the model but shape "
        "'(10, 10)' in the new_params dictionary."
    )
    assert str(excinfo.value) == expected_message
def test_set_parameters_allow_missing():
    """Verify the allow_missing switch of set_parameters.

    With allow_missing=True an empty dictionary is accepted and the model
    still lists 'source_target_embed_weight'. With allow_missing=False the
    same call must raise an AssertionError naming the missing parameter.
    """
    # NOTE(review): another definition named test_set_parameters_allow_missing
    # appears later in the visible source; if both live in one module, the
    # later one shadows this one.
    net = mock_model()
    net.initialize(init='xavier', ctx=mx.cpu(0))

    # Lenient mode: nothing supplied, nothing complained about.
    net.set_parameters({}, allow_missing=True)
    assert 'source_target_embed_weight' in net.params

    # Strict mode: the absent parameter is reported.
    with pytest.raises(AssertionError) as excinfo:
        net.set_parameters({}, allow_missing=False)

    expected_message = (
        "Parameter 'source_target_embed_weight' is missing in new_params dictionary. "
        "Set allow_missing=True to ignore missing parameters."
    )
    assert str(excinfo.value) == expected_message
def test_set_parameters_allow_missing():
    """Verify the allow_missing switch of set_parameters.

    The test is skipped when mxnet cannot be imported. With
    allow_missing=True an empty dictionary is accepted and
    'embedding_source.weight' is still among the collected parameters. With
    allow_missing=False the same call must raise an AssertionError naming
    the missing parameter.
    """
    # NOTE(review): another definition named test_set_parameters_allow_missing
    # appears earlier in the visible source; if both live in one module, this
    # one shadows it.
    mx = pytest.importorskip('mxnet')

    net = mock_model()
    net.initialize(init='xavier', ctx=mx.cpu(0))

    # Lenient mode: nothing supplied, nothing complained about.
    net.set_parameters({}, allow_missing=True)
    assert 'embedding_source.weight' in net.collect_params()

    # Strict mode: the absent parameter is reported.
    with pytest.raises(AssertionError) as excinfo:
        net.set_parameters({}, allow_missing=False)

    expected_message = (
        "Parameter 'embedding_source.weight' is missing in new_params dictionary. "
        "Set allow_missing=True to ignore missing parameters."
    )
    assert str(excinfo.value) == expected_message
def test_set_parameters_context():
    """Verify that set_parameters works across differing contexts.

    The model is initialized on cpu(0) and cpu(1) while the supplied
    parameter lives on cpu(2). After the call, the model's parameter data on
    each of its two contexts must match the supplied parameter's data.
    """
    # NOTE(review): another definition named test_set_parameters_context
    # appears later in the visible source; if both live in one module, the
    # later one shadows this one.
    weight_name = 'source_target_embed_weight'

    net = mock_model()
    net.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])

    replacement = mx.gluon.Parameter(weight_name, shape=(20, 4))
    replacement.initialize(init='xavier', ctx=mx.cpu(2))
    net.set_parameters({weight_name: replacement})

    for device_id in (0, 1):
        on_model_device = net.params[weight_name].data(mx.cpu(device_id))
        on_source_device = replacement.data(mx.cpu(2))
        assert mx.test_utils.same(on_model_device, on_source_device)
def test_set_parameters_context():
    """Verify that set_parameters works across differing contexts.

    The test is skipped when mxnet cannot be imported. The model is
    initialized on cpu(0) and cpu(1) while the supplied parameter lives on
    cpu(2). After the call, the model's 'embedding_source.weight' on each of
    its two contexts must match the supplied data (compared as numpy arrays).
    """
    # NOTE(review): another definition named test_set_parameters_context
    # appears earlier in the visible source; if both live in one module, this
    # one shadows it.
    mx = pytest.importorskip('mxnet')
    weight_name = 'embedding_source.weight'

    net = mock_model()
    net.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])

    replacement = mx.gluon.Parameter(weight_name, shape=(20, 4))
    replacement.initialize(init='xavier', ctx=mx.cpu(2))
    net.set_parameters({weight_name: replacement})

    for device_id in (0, 1):
        model_param = net.collect_params()[weight_name]
        on_model_device = model_param.data(mx.cpu(device_id)).asnumpy()
        on_source_device = replacement.data(mx.cpu(2)).asnumpy()
        assert mx.test_utils.same(on_model_device, on_source_device)
def test_set_parameters():
    """Verify that set_parameters installs the values of a supplied parameter.

    The test is skipped when mxnet cannot be imported. Only
    'output_layer.weight' is supplied, yet the test expects
    'output_layer.weight', 'embedding_source.weight' and
    'embedding_target.weight' to all hold the new data afterwards
    (presumably the mock model ties these weights — confirm against
    mock_model).
    """
    # NOTE(review): another definition named test_set_parameters appears
    # earlier in the visible source; if both live in one module, this one
    # shadows it.
    mx = pytest.importorskip('mxnet')

    net = mock_model()
    net.initialize(init='xavier', ctx=mx.cpu(0))
    collected = net.collect_params()

    replacement = mx.gluon.Parameter('output_layer.weight', shape=(20, 4))
    replacement.initialize(init='xavier', ctx=mx.cpu(0))
    net.set_parameters({replacement.name: replacement})

    assert mx.test_utils.same(
        collected['output_layer.weight'].data(), replacement.data())
    assert mx.test_utils.same(
        collected['embedding_source.weight'].data(), replacement.data())
    assert mx.test_utils.same(
        collected['embedding_target.weight'].data(), replacement.data())
def test_set_parameters_uninitialized():
    """Verify that set_parameters requires both sides to be initialized.

    Two failure cases are checked, each expecting an AssertionError with an
    exact message:

    1. The model is initialized but the supplied parameter is not.
    2. The supplied parameter is initialized but a fresh model is not.
    """
    weight_name = 'source_target_embed_weight'

    # Case 1: initialized model, uninitialized replacement parameter.
    net = mock_model()
    net.initialize(init='xavier', ctx=mx.cpu(0))
    replacement = mx.gluon.Parameter(weight_name, shape=(20, 4))

    with pytest.raises(AssertionError) as excinfo:
        net.set_parameters({weight_name: replacement})

    expected_message = "Parameter 'source_target_embed_weight' is not initialized in new_params dictionary."
    assert str(excinfo.value) == expected_message

    # Case 2: initialized replacement parameter, uninitialized fresh model.
    replacement.initialize(init='xavier', ctx=mx.cpu(0))
    net = mock_model()

    with pytest.raises(AssertionError) as excinfo:
        net.set_parameters({weight_name: replacement})

    expected_message = (
        "Parameter 'source_target_embed_weight' must be initialized before it can be reset using "
        "set_parameters."
    )
    assert str(excinfo.value) == expected_message
def test_set_parameters_ignore_extra():
    """Verify the ignore_extra switch of set_parameters.

    The supplied dictionary contains a known parameter plus an extra one
    named 'q'. With ignore_extra=True the call succeeds and 'q' is not added
    to the model's parameters. With ignore_extra=False a ValueError with an
    exact message must be raised.
    """
    # NOTE(review): another definition named test_set_parameters_ignore_extra
    # appears later in the visible source; if both live in one module, the
    # later one shadows this one.
    weight_name = 'source_target_embed_weight'

    net = mock_model()
    net.initialize(init='xavier', ctx=mx.cpu(0))

    known = mx.gluon.Parameter(weight_name, shape=(20, 4))
    known.initialize(init='xavier', ctx=mx.cpu(0))
    unknown = mx.gluon.Parameter('q', shape=(1, 1))
    unknown.initialize(init='xavier', ctx=mx.cpu(0))
    new_params = {weight_name: known, 'q': unknown}

    # Lenient mode: the extra entry is silently dropped.
    net.set_parameters(new_params, ignore_extra=True)
    assert weight_name in net.params
    assert 'q' not in net.params

    # Strict mode: the extra entry is reported.
    with pytest.raises(ValueError) as excinfo:
        net.set_parameters(new_params, ignore_extra=False)

    expected_message = (
        "Parameter 'q' in new_params dictionary is not preset in ParameterDict. "
        "Set ignore_extra=True to ignore."
    )
    assert str(excinfo.value) == expected_message
def test_set_parameters_ignore_extra():
    """Verify the ignore_extra switch of set_parameters.

    The test is skipped when mxnet cannot be imported. The supplied
    dictionary contains 'embedding_source.weight' plus an extra parameter
    named 'q'. With ignore_extra=True the call succeeds and 'q' is not among
    the collected parameters. With ignore_extra=False a ValueError with an
    exact message must be raised.
    """
    # NOTE(review): another definition named test_set_parameters_ignore_extra
    # appears earlier in the visible source; if both live in one module, this
    # one shadows it.
    mx = pytest.importorskip('mxnet')
    weight_name = 'embedding_source.weight'

    net = mock_model()
    net.initialize(init='xavier', ctx=mx.cpu(0))

    known = mx.gluon.Parameter(weight_name, shape=(20, 4))
    known.initialize(init='xavier', ctx=mx.cpu(0))
    unknown = mx.gluon.Parameter('q', shape=(1, 1))
    unknown.initialize(init='xavier', ctx=mx.cpu(0))
    new_params = {weight_name: known, 'q': unknown}

    # Lenient mode: the extra entry is silently dropped.
    net.set_parameters(new_params, ignore_extra=True)
    assert weight_name in net.collect_params()
    assert 'q' not in net.collect_params()

    # Strict mode: the extra entry is reported.
    with pytest.raises(ValueError) as excinfo:
        net.set_parameters(new_params, ignore_extra=False)

    expected_message = (
        "Parameter 'q' in new_params dictionary is not preset in ParameterDict. "
        "Set ignore_extra=True to ignore."
    )
    assert str(excinfo.value) == expected_message