def test_train_05(self):
    """Train interactively with `outputs` supplied, inside the figure check.

    The `train` call runs under `assert_plot_figures_added()`, which
    presumably verifies that new plot figures were created -- confirm
    against that helper's definition.
    """
    model = NonlinearLevelSet(n_layers=2, active_dim=1, lr=0.02, epochs=1)
    train_kwargs = dict(
        inputs=inputs_torch,
        gradients=grad_torch,
        outputs=lift,
        interactive=True,
    )
    with assert_plot_figures_added():
        model.train(**train_kwargs)
def test_save_forward(self):
    """Check that `save_forward` leaves a file at the requested path.

    The file is scheduled for removal once its existence has been
    asserted, so a passing run leaves nothing behind.
    """
    target_path = 'tests/data/saved_forward.pth'
    model = NonlinearLevelSet(n_layers=2, active_dim=1, lr=0.02, epochs=1)
    model.train(inputs=inputs_torch, gradients=grad_torch, interactive=False)

    model.save_forward(target_path)

    file_written = os.path.exists(target_path)
    self.assertTrue(file_written)
    self.addCleanup(os.remove, target_path)
def test_plot_loss(self):
    """Call `plot_loss` after a two-epoch, non-interactive training run.

    The call runs under `assert_plot_figures_added()`, which presumably
    verifies that a figure was created -- confirm against that helper.
    """
    model = NonlinearLevelSet(n_layers=2, active_dim=1, lr=0.02, epochs=2)
    train_kwargs = dict(
        inputs=inputs_torch,
        gradients=grad_torch,
        interactive=False,
    )
    model.train(**train_kwargs)
    with assert_plot_figures_added():
        model.plot_loss()
def test_plot_sufficient_summary_02(self):
    """`plot_sufficient_summary` must raise ValueError when active_dim=2."""
    model = NonlinearLevelSet(n_layers=2, active_dim=2, lr=0.02, epochs=1)
    model.train(inputs=inputs_torch, gradients=grad_torch, interactive=False)
    # Training happens outside the check: only the plotting call is
    # allowed to raise.
    self.assertRaises(
        ValueError,
        model.plot_sufficient_summary,
        inputs=inputs_torch,
        outputs=lift,
    )
def test_plot_sufficient_summary_01(self):
    """Call `plot_sufficient_summary` on a trained active_dim=1 model.

    The call runs under `assert_plot_figures_added()`, which presumably
    verifies that a figure was created -- confirm against that helper.
    """
    model = NonlinearLevelSet(n_layers=2, active_dim=1, lr=0.02, epochs=1)
    model.train(inputs=inputs_torch, gradients=grad_torch, interactive=False)
    plot_kwargs = dict(inputs=inputs_torch, outputs=lift)
    with assert_plot_figures_added():
        model.plot_sufficient_summary(**plot_kwargs)
def test_backward_n_params(self):
    """After training, the backward net must report `n_params` equal to 9."""
    model = NonlinearLevelSet(n_layers=2, active_dim=1, lr=0.02, epochs=1)
    model.train(inputs=inputs_torch, gradients=grad_torch, interactive=False)
    param_count = model.backward.n_params
    self.assertEqual(param_count, 9)
def test_train_04(self):
    """`train` with interactive=True and no `outputs` must raise ValueError."""
    model = NonlinearLevelSet(n_layers=2, active_dim=1, lr=0.02, epochs=1)
    self.assertRaises(
        ValueError,
        model.train,
        inputs=inputs_torch,
        gradients=grad_torch,
        interactive=True,
    )
def test_train_03(self):
    """A single-epoch training run must leave exactly one entry in `loss_vec`."""
    nll = NonlinearLevelSet(n_layers=2, active_dim=1, lr=0.02, epochs=1)
    nll.train(inputs=inputs_torch, gradients=grad_torch, interactive=False)
    # Compare by value: `assertIs` on an int only passes because CPython
    # happens to cache small integers, which is an implementation detail.
    self.assertEqual(len(nll.loss_vec), 1)
def test_train_02(self):
    """After training, `backward` must be a `BackwardNet` instance."""
    model = NonlinearLevelSet(n_layers=2, active_dim=1, lr=0.02, epochs=1)
    train_kwargs = dict(
        inputs=inputs_torch,
        gradients=grad_torch,
        interactive=False,
    )
    model.train(**train_kwargs)
    backward_net = model.backward
    self.assertIsInstance(backward_net, BackwardNet)