import torch
import torch.nn as nn
from torchdyn.models import NeuralDE, Augmenter  # torchdyn 0.2.x-era API


class NODE(nn.Module):
    def __init__(self, hidden_size, sequence_size, other_module):
        super().__init__()
        self.sequence_size = sequence_size
        self.other_mod = other_module
        # ODE dynamics over the augmented state (hidden state + 10 extra dims);
        # `nodefunc` is assumed to be defined elsewhere in the project
        self.odef = nodefunc(hidden_size + 10)
        self.sspan = torch.linspace(0.0, 1.0, 2)
        self.node = NeuralDE(self.odef, sensitivity="adjoint", solver="dopri5",
                             rtol=0.01, atol=0.01, s_span=self.sspan)
        # augment the input with 10 learned dimensions before integrating
        self.encode = Augmenter(augment_func=nn.Linear(hidden_size, 10))
        # project the augmented trajectory back down to the hidden size
        self.decode = nn.Linear(hidden_size + 10, hidden_size)
        self.loss_func = nn.CrossEntropyLoss()

    def process_first(self, xfirst):
        # integrate over `sequence_size` time points to unroll a full
        # sequence from the first observation
        sspan = torch.linspace(0.0, 1.0, self.sequence_size)
        x = torch.squeeze(xfirst, 0)
        enc = self.encode(x)
        traj = self.node.trajectory(enc, sspan)
        out = self.decode(traj)
        out = torch.unsqueeze(out, 1)
        return out

    def forward(self, x, h):
        return self.other_mod(x, h)
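# --- Hypothetical usage sketch (illustrative, not the project's code) ------
# Assumes `nodefunc(dim)` returns an nn.Module computing the state derivative;
# the toy stand-in below and all sizes are assumptions for demonstration.

def nodefunc(dim):
    # toy vector field: a small MLP from the augmented state to its derivative
    return nn.Sequential(nn.Linear(dim, 64), nn.Tanh(), nn.Linear(64, dim))

hidden_size, sequence_size = 32, 50
cell = nn.GRUCell(hidden_size, hidden_size)       # stand-in for other_module
model = NODE(hidden_size, sequence_size, cell)

x0 = torch.randn(1, 16, hidden_size)              # first step, batch of 16
seq = model.process_first(x0)                     # -> (50, 1, 16, 32)
h = model(torch.randn(16, hidden_size), torch.randn(16, hidden_size))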
import math

import torch
from torchdyn.models import NeuralDE


class NeuralOde(torch.nn.Module):
    """A wrapper of the continuous neural network that represents the ODE."""

    def __init__(self, cartpole, controller, method='dopri5'):
        super().__init__()
        self.cartpole, self.controller = cartpole, controller
        # `device` is assumed to be a global torch.device defined elsewhere
        self.model_of_dyn_system = NeuralDE(
            controller, sensitivity='adjoint', solver=method
        ).to(device)

    def final_state_loss(self, state):
        _, dx, theta = state[:, 0], state[:, 1], state[:, 2]
        # wrap theta into [-pi, +pi]; `pi_mod` is assumed defined elsewhere
        theta = pi_mod(theta + math.pi) - math.pi
        # penalize the pole angle heavily and the cart velocity lightly
        return 4 * theta ** 2 + torch.abs(dx)

    # note: this shadows nn.Module.train(mode); call .eval()/.train() with care
    def train(self, n_epochs=100, batch_size=200, lr_patience=10,
              early_stop_patience=20, epsilon=0.1):
        optimizer = torch.optim.Adam(
            self.model_of_dyn_system.parameters(), lr=.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=lr_patience, factor=0.5)
        steps_since_plat, last_plat = 0, 0
        for i in range(n_epochs):
            optimizer.zero_grad()
            # set up the training scenario from random initial states
            start_state = self.cartpole.sample_state(batch_size).to(device)
            # run the simulation
            final_state = self.model_of_dyn_system(start_state)
            # evaluate performance
            loss = self.final_state_loss(final_state)
            step_loss = loss.mean()
            print("epoch: {}, loss: {}".format(i, step_loss))
            loss.sum().backward()
            optimizer.step()
            scheduler.step(step_loss)
            # early stopping: if stuck on a plateau, quit
            delta_loss = abs(last_plat - step_loss.data)
            if (steps_since_plat >= early_stop_patience
                    and delta_loss <= epsilon):
                break
            elif delta_loss > epsilon:
                last_plat, steps_since_plat = step_loss.data, 0
            steps_since_plat += 1

    def trajectory(self, state, T=1, time_steps=200):
        """Data trajectory from t = 0 to t = T."""
        state = state.to(device)
        t = torch.linspace(0, T, time_steps).to(device)
        # integrate, then detach and drop the batch dimension
        traj = self.model_of_dyn_system.trajectory(state, t)
        return traj.detach().cpu()[:, 0, :]
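# --- Hypothetical usage sketch (illustrative, not the project's code) ------
# `ToyCartpole`, the controller sizes, and the `pi_mod` stand-in below are
# assumptions; the real project supplies its own environment and helpers.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def pi_mod(x):
    # plausible stand-in: wrap angles into [0, 2*pi)
    return torch.remainder(x, 2 * math.pi)

class ToyCartpole:
    # minimal environment stub: random initial states (x, dx, theta, dtheta)
    def sample_state(self, batch_size):
        return torch.rand(batch_size, 4) * 2 - 1

# illustrative controller: a vector field over the 4-D cart-pole state
controller = torch.nn.Sequential(
    torch.nn.Linear(4, 64), torch.nn.Tanh(), torch.nn.Linear(64, 4))

cartpole = ToyCartpole()
ode = NeuralOde(cartpole, controller)
ode.train(n_epochs=50)

traj = ode.trajectory(cartpole.sample_state(1))   # -> (time_steps, state_dim)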