Example #1
import torch
import torch.nn as nn
from torchdyn.models import NeuralDE, Augmenter  # torchdyn 0.x API

class NODE(nn.Module):
    def __init__(self, hidden_size, sequence_size, other_module):
        super().__init__()
        self.sequence_size = sequence_size

        self.other_mod = other_module

        # nodefunc is a user-defined vector-field module (defined elsewhere);
        # it operates on the augmented state of size hidden_size + 10
        self.odef = nodefunc(hidden_size + 10)
        self.sspan = torch.linspace(0.0, 1.0, 2)
        self.node = NeuralDE(self.odef,
                             sensitivity="adjoint",
                             solver="dopri5",
                             rtol=0.01,
                             atol=0.01,
                             s_span=self.sspan)
        # append 10 extra state dimensions computed by a linear layer
        # (ANODE-style augmentation), then project back down after the solve
        self.encode = Augmenter(augment_func=nn.Linear(hidden_size, 10))
        self.decode = nn.Linear(hidden_size + 10, hidden_size)

        self.loss_func = nn.CrossEntropyLoss()

    def process_first(self, xfirst):
        # evaluate the ODE at sequence_size evenly spaced points in [0, 1]
        sspan = torch.linspace(0.0, 1.0, self.sequence_size)
        x = torch.squeeze(xfirst, 0)
        # augment, integrate along sspan, then project back to hidden_size
        enc = self.encode(x)
        traj = self.node.trajectory(enc, sspan)
        out = self.decode(traj)
        out = torch.unsqueeze(out, 1)
        return out

    def forward(self, x, h):
        return self.other_mod(x, h)
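
A minimal usage sketch for this module, assuming torchdyn 0.x is installed. ToyFunc below is a hypothetical stand-in for nodefunc, which the original defines elsewhere:

import torch
import torch.nn as nn

class ToyFunc(nn.Module):
    # hypothetical stand-in for the vector-field module nodefunc
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.Tanh())

    def forward(self, x):
        return self.net(x)

nodefunc = ToyFunc  # NODE.__init__ looks this name up at call time

model = NODE(hidden_size=32, sequence_size=16, other_module=nn.GRUCell(32, 32))
xfirst = torch.randn(1, 8, 32)     # leading singleton dim is squeezed away
out = model.process_first(xfirst)
print(out.shape)                   # torch.Size([16, 1, 8, 32])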
Example #2
import math
import torch
from torchdyn.models import NeuralDE  # torchdyn 0.x API

# assumed setup: the original defines these elsewhere
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def pi_mod(x):
    # assumed helper (defined elsewhere in the original): wrap to [0, 2*pi)
    return x % (2 * math.pi)


class NeuralOde(torch.nn.Module):
    """
    A wrapper around the continuous neural network that represents the ODE.
    """
    def __init__(self, cartpole, controller, method='dopri5'):
        super().__init__()
        self.cartpole, self.controller = cartpole, controller

        self.model_of_dyn_system = NeuralDE(
            controller, sensitivity='adjoint', solver=method
        ).to(device)

    def final_state_loss(self, state):
        _, dx, theta = state[:, 0], state[:, 1], state[:, 2]

        # wrap theta into [-pi, +pi]
        theta = pi_mod(theta + math.pi) - math.pi

        # penalize pole angle away from upright plus residual cart velocity
        return 4 * theta**2 + torch.abs(dx)

    def train(self, n_epochs=100, batch_size=200, lr_patience=10,
              early_stop_patience=20, epsilon=0.1):
        # note: this method shadows nn.Module.train(mode)
        optimizer = torch.optim.Adam(
            self.model_of_dyn_system.parameters(), lr=0.1)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=lr_patience, factor=0.5)

        steps_since_plat, last_plat = 0, 0
        for i in range(n_epochs):
            optimizer.zero_grad()

            # set up training scenario: sample a batch of initial states
            start_state = self.cartpole.sample_state(batch_size).to(device)

            # run simulation
            final_state = self.model_of_dyn_system(start_state)

            # evaluate performance
            loss = self.final_state_loss(final_state)
            step_loss = loss.mean()
            print("epoch: {}, loss: {}".format(i, step_loss.item()))

            loss.sum().backward()
            optimizer.step()
            scheduler.step(step_loss)

            # early stopping: if the loss has not moved by more than epsilon
            # for early_stop_patience epochs, assume we are stuck on a minimum
            delta_loss = abs(last_plat - step_loss.item())
            if (steps_since_plat >= early_stop_patience and
                    delta_loss <= epsilon):
                break
            elif delta_loss > epsilon:
                last_plat, steps_since_plat = step_loss.item(), 0
            steps_since_plat += 1

    def trajectory(self, state, T=1, time_steps=200):
        """
        Integrate the system from t = 0 to t = T and return the
        trajectory of the first batch element.
        """
        state = state.to(device)
        t = torch.linspace(0, T, time_steps).to(device)

        # traj has shape (time_steps, batch, state_dim);
        # keep only the first batch element
        traj = self.model_of_dyn_system.trajectory(state, t)
        return traj.detach().cpu()[:, 0, :]
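
A hedged usage sketch under stated assumptions: CartPoleStub and the controller below are stand-ins for the environment and controller network that the original obtains elsewhere. The controller doubles as the ODE vector field, so it must map a state to a state derivative of the same size:

import torch
import torch.nn as nn

class CartPoleStub:
    # hypothetical environment stub; state = (x, dx, theta, dtheta)
    def sample_state(self, batch_size):
        return torch.randn(batch_size, 4)

cartpole = CartPoleStub()
controller = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 4))

ode = NeuralOde(cartpole, controller)
ode.train(n_epochs=10, batch_size=32)

traj = ode.trajectory(cartpole.sample_state(1))
print(traj.shape)  # torch.Size([200, 4])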