def solo_train(x, hidden=2, width=2, activation=nn.ReLU, num_epochs=10, lr=0.001, momentum=0.9, corr_len=10):
    """Train a fresh flat network on ``x`` using the second-derivative loss.

    Yields ``(ypred, loss)`` once per epoch, before the backward pass, so a
    caller can inspect predictions as training progresses.

    Raises:
        RuntimeError: if the loss becomes NaN (mis-configured experiment).

    NOTE(review): ``corr_len`` is accepted but never read in this body —
    presumably kept for interface compatibility; confirm against callers.
    """
    net = networks.flat_net(hidden_depth=hidden, width=width, activation=activation)
    optimiser = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    for epoch in range(num_epochs):
        # MSE-term weighting schedule: 1 / (epoch + 1), capped at 0.1.
        mse_weight = min(1 / (epoch + 1), 0.1)
        optimiser.zero_grad()
        ypred = net(x)
        loss = second_deriv(ypred, mse_weight=mse_weight)
        log.debug("e: %s, loss: %s", epoch, loss)
        if torch.isnan(loss):
            raise RuntimeError("NaN loss, poorly configured experiment")
        yield ypred, loss
        loss.backward()
        optimiser.step()
def default_train(x, y, hidden=2, width=2, activation=nn.ReLU, num_epochs=200, lr=0.001, momentum=0.9):
    """Train a flat network to map ``x`` to ``y`` under MSE loss.

    Yields ``(ypred, loss)`` once per epoch, before the backward pass.
    Logs the pre-training loss against the final loss when the loop ends.

    Raises:
        RuntimeError: if the loss becomes NaN (mis-configured experiment).
    """
    model = networks.flat_net(hidden_depth=hidden, width=width, activation=activation)
    criterion = nn.MSELoss()
    optimiser = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    # Loss of the untrained network, kept only for the closing debug line.
    start_loss = criterion(model(x), y)
    loss = 0
    for epoch in range(num_epochs):
        optimiser.zero_grad()
        ypred = model(x)
        loss = criterion(ypred, y)
        log.debug("e: %s, loss: %s", epoch, loss)
        if torch.isnan(loss):
            raise RuntimeError("NaN loss, poorly configured experiment")
        yield ypred, loss
        loss.backward()
        optimiser.step()
    log.debug("First loss %s v final %s", start_loss, loss)
def bee_trainer(xt, yt, width=2, num_epochs=200):
    """Train a depth-1 flat ReLU net on ``(xt, yt)`` with fixed SGD settings.

    Unlike the other trainers, this yields AFTER the optimiser step: each
    epoch produces ``(ypred, weights, biases)`` where the parameter arrays
    are detached numpy copies of the first layer's weight and bias.

    Raises:
        RuntimeError: if the loss becomes NaN (mis-configured experiment).
    """
    model = networks.flat_net(1, width, activation=nn.ReLU)
    sgd = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = torch.nn.MSELoss()
    for _ in range(num_epochs):
        sgd.zero_grad()
        ypred = model(xt)
        loss = criterion(ypred, yt)
        if torch.isnan(loss):
            raise RuntimeError("NaN loss, poorly configured experiment")
        loss.backward()
        sgd.step()
        # The first two parameter tensors are the input layer's weight and bias.
        weight, bias, *_ = model.parameters()
        flat_w = weight.detach().flatten().numpy().copy()
        b = bias.detach().numpy().copy()
        yield ypred, flat_w, b
def sin_experiment():
    """Fit sin(x) over [-6, 6] with a small xTanH network.

    Yields ``(ypred, loss)`` once per epoch, before the backward pass.

    Raises:
        RuntimeError: if the loss becomes NaN (mis-configured experiment).
    """
    xt = torch.linspace(-6, 6, 100)
    yt = torch.sin(xt)
    num_epochs = 10
    model = networks.flat_net(2, 2, activations.xTanH)
    sgd = torch.optim.SGD(model.parameters(), lr=0.002, momentum=0.9)
    criterion = torch.nn.MSELoss()
    for _ in range(num_epochs):
        sgd.zero_grad()
        ypred = model(xt)
        loss = criterion(ypred, yt)
        if torch.isnan(loss):
            raise RuntimeError("NaN loss, poorly configured experiment")
        yield ypred, loss
        loss.backward()
        sgd.step()
def test_flat():
    """Constructing a flat net with valid depth and width should not raise."""
    networks.flat_net(2, 3)
def test_flat_forward():
    """A flat net maps a length-100 input to a length-100 output."""
    xt = torch.linspace(-5, 5, 100)
    nw = networks.flat_net(4, 4)
    # Call the module itself rather than .forward() directly: nn.Module.__call__
    # is the supported entry point (it runs registered hooks) and is what the
    # trainers in this file exercise.
    yt = nw(xt)
    assert yt.size() == (100, )
def test_invalid_flat():
    """A hidden depth of zero must be rejected with ValueError."""
    with pytest.raises(ValueError):
        networks.flat_net(0, 3)