def test_starnet2017(self):
    """Smoke-test StarNet2017 on random data: train one epoch, predict, save.

    Only the shape of the prediction is checked against the targets; the
    values are Gaussian noise and are not inspected. The trained model is
    written out under the name ``'starnet2017'``.
    """
    n_samples, n_features, n_targets = 200, 7514, 25
    # Synthetic standard-normal inputs and targets (inputs drawn first, then targets).
    # NOTE(review): an earlier comment asked to keep the data size large
    # (">800 data points to prevent issues"), yet only 200 samples are drawn
    # here -- confirm which is intended.
    inputs = np.random.normal(0, 1, (n_samples, n_features))
    targets = np.random.normal(0, 1, (n_samples, n_targets))

    print("======StarNet2017======")
    model = StarNet2017()
    # A single epoch keeps the test short; ErrorOnNaN is presumably meant to
    # fail the run if training produces NaNs -- confirm against its definition.
    model.max_epochs = 1
    model.callbacks = ErrorOnNaN()
    model.train(inputs, targets)

    predicted = model.test(inputs)
    np.testing.assert_array_equal(predicted.shape, targets.shape)

    model.save(name='starnet2017')
def test_starnet2017(self):
    """Exercise StarNet2017 end to end: train, predict, Jacobian, save, reload, fine-tune.

    Uses ``random_xdata`` / ``random_ydata``, which are not defined in this
    method (presumably module-level fixtures -- confirm in the enclosing file).

    Checks:
      * prediction shape matches the target array shape;
      * the Jacobian of the first 10 inputs has shape
        ``(10, n_targets, n_features)``;
      * a model reloaded via ``load_folder`` reproduces the original
        predictions exactly;
      * the reloaded model can be trained further (fine-tuned).
    """
    print("======StarNet2017======")
    starnet2017 = StarNet2017()
    starnet2017.max_epochs = 1
    starnet2017.callbacks = ErrorOnNaN()
    starnet2017.train(random_xdata, random_ydata)

    prediction = starnet2017.test(random_xdata)
    # Small subset keeps the Jacobian computation cheap.
    jacobian_input = random_xdata[:10]
    jacobian = starnet2017.jacobian(jacobian_input)
    np.testing.assert_array_equal(prediction.shape, random_ydata.shape)
    # Expected layout: (samples, targets, input features).
    np.testing.assert_array_equal(
        jacobian.shape,
        [jacobian_input.shape[0], random_ydata.shape[1], random_xdata.shape[1]],
    )
    starnet2017.save(name='starnet2017')

    starnet2017_loaded = load_folder("starnet2017")
    prediction_loaded = starnet2017_loaded.test(random_xdata)
    # StarNet2017 is deterministic, so a save/load round trip must reproduce
    # the predictions bit-for-bit.
    np.testing.assert_array_equal(prediction, prediction_loaded)

    # Fine-tuning test: continue training the *reloaded* model. The NaN guard
    # must be attached to the model actually being trained (previously it was
    # set on the original `starnet2017`, leaving this run unguarded).
    starnet2017_loaded.max_epochs = 1
    starnet2017_loaded.callbacks = ErrorOnNaN()
    starnet2017_loaded.train(random_xdata, random_ydata)