def test_call_invalid_model():
    # type: () -> None
    """Check that an unknown ``method`` name is rejected with ``ValueError``."""
    unknown_method = 'invalid'
    with pytest.raises(ValueError):
        set_up_predictor(
            method=unknown_method,
            n_unit=n_unit,
            conv_layers=conv_layers,
            class_num=class_num,
        )
def test_setup_predictor(models_dict):
    # type: (Dict[str, chainer.Link]) -> None
    """Check every ``method`` in ``models_dict`` builds the expected predictor.

    For each ``(method name, expected class)`` pair, the predictor returned by
    ``set_up_predictor`` must be a ``GraphConvPredictor`` whose ``graph_conv``
    attribute is an instance of the expected class.
    """
    # NOTE(review): the dict values are passed to isinstance() as classes, so
    # the type comment above is probably meant to be Type[chainer.Link] -- confirm.
    for method_name, expected_cls in models_dict.items():
        built = set_up_predictor(
            method=method_name,
            n_unit=n_unit,
            conv_layers=conv_layers,
            class_num=class_num,
        )
        assert isinstance(built.graph_conv, expected_cls)
        assert isinstance(built, GraphConvPredictor)
def test_set_up_predictor_with_conv_kwargs():
    # type: () -> None
    """Check that ``conv_kwargs`` values show up on the built ``graph_conv``.

    Builds an ``'nfp'`` predictor with ``max_degree=4`` and
    ``concat_hidden=True``, then checks both attributes on ``graph_conv``.
    """
    extra_conv_kwargs = {
        'max_degree': 4,
        'concat_hidden': True,
    }
    predictor = set_up_predictor(
        method='nfp',
        n_unit=n_unit,
        conv_layers=conv_layers,
        class_num=class_num,
        conv_kwargs=extra_conv_kwargs,
    )
    conv = predictor.graph_conv
    assert conv.max_degree == 4
    assert conv.concat_hidden is True