Example no. 1 (score: 0)
    def __init__(self, latent_dim, ode_func_layers, ode_func_units, input_dim,
                 decoder_units):
        """Build the ODE-RNN components: dynamics net, solver, GRU cell, decoder.

        Args:
            latent_dim: size of the hidden state evolved by the ODE.
            ode_func_layers: number of layers in the ODE dynamics network.
            ode_func_units: hidden units per layer of the dynamics network.
            input_dim: dimensionality of each observation.
            decoder_units: hidden width of the decoder and of the GRU unit.
        """
        super(ODE_RNN, self).__init__()

        # Network parameterizing the latent dynamics dh/dt = f(h).
        dynamics_net = utils.create_net(latent_dim,
                                        latent_dim,
                                        n_layers=ode_func_layers,
                                        n_units=ode_func_units,
                                        nonlinear=nn.Tanh)
        utils.init_network_weights(dynamics_net)

        # Fixed-step Euler solver advancing the hidden state between
        # observation times.
        self.ode_solver = DiffeqSolver(ODEFunc(ode_func_net=dynamics_net),
                                       "euler",
                                       odeint_rtol=1e-3,
                                       odeint_atol=1e-4)

        # Decoder emits 2 * input_dim values: mean and pre-softplus std.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, decoder_units),
            nn.Tanh(),
            nn.Linear(decoder_units, input_dim * 2),
        )
        utils.init_network_weights(self.decoder)

        self.gru_unit = GRU_Unit(latent_dim, input_dim, n_units=decoder_units)
        self.latent_dim = latent_dim
        # Maps the decoder's raw std output to a positive value.
        self.sigma_fn = nn.Softplus()
Example no. 2 (score: 0)
def get_odernn_encoder(input_dim, latent_dim, odernn_hypers, device):
    """Construct an ODE-RNN encoder that produces z0 from masked sequences.

    Args:
        input_dim: dimensionality of each observation.
        latent_dim: size of the encoder's latent state (also its z0 dim).
        odernn_hypers: dict with 'odefunc_n_layers', 'odefunc_n_units' and
            'encoder_gru_units'.
        device: torch device the modules are moved to.

    Returns:
        An Encoder_z0_ODE_RNN instance on `device`.
    """
    # Dynamics network for the recognition ODE: R^latent -> R^latent.
    dynamics = create_net(latent_dim,
                          latent_dim,
                          n_layers=odernn_hypers['odefunc_n_layers'],
                          n_units=odernn_hypers['odefunc_n_units'])

    # Observations are concatenated with their mask, hence the doubling.
    enc_input_dim = input_dim * 2

    ode_func = ODEFunc(input_dim=enc_input_dim,
                       latent_dim=latent_dim,
                       ode_func_net=dynamics,
                       device=device).to(device)

    solver = DiffeqSolver(enc_input_dim,
                          ode_func,
                          'dopri5',
                          latent_dim,
                          odeint_rtol=1e-3,
                          odeint_atol=1e-4,
                          device=device)

    return Encoder_z0_ODE_RNN(latent_dim,
                              enc_input_dim,
                              solver,
                              z0_dim=latent_dim,
                              n_gru_units=odernn_hypers['encoder_gru_units'],
                              device=device).to(device)
Example no. 3 (score: 0)
def get_diffeq_solver(ode_latents,
                      ode_units,
                      rec_layers,
                      ode_method,
                      ode_type="linear",
                      device=torch.device("cpu")):
    """Build a DiffeqSolver whose dynamics are an MLP or an autonomous GRU cell.

    Args:
        ode_latents: dimensionality of the latent state.
        ode_units: hidden units per layer (linear ODE type only).
        rec_layers: number of layers (linear ODE type only).
        ode_method: integration method name passed to the solver.
        ode_type: "linear" for a tanh MLP, "gru" for FullGRUODECell_Autonomous.
        device: torch device the ODE function is moved to.

    Returns:
        A DiffeqSolver over the chosen dynamics.

    Raises:
        Exception: if `ode_type` is neither "linear" nor "gru".
    """
    if ode_type == "linear":
        # Tanh MLP dynamics: R^ode_latents -> R^ode_latents.
        dynamics = utils.create_net(ode_latents,
                                    ode_latents,
                                    n_layers=int(rec_layers),
                                    n_units=int(ode_units),
                                    nonlinear=nn.Tanh)
    elif ode_type == "gru":
        dynamics = FullGRUODECell_Autonomous(ode_latents, bias=True)
    else:
        raise Exception("Invalid ODE-type. Choose linear or gru.")

    rec_ode_func = ODEFunc(input_dim=0,
                           latent_dim=ode_latents,
                           ode_func_net=dynamics,
                           device=device).to(device)

    return DiffeqSolver(0,
                        rec_ode_func,
                        ode_method,
                        ode_latents,
                        odeint_rtol=1e-3,
                        odeint_atol=1e-4,
                        device=device)
Example no. 4 (score: 0)
def create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device):
    """Assemble a latent graph-ODE model: GNN encoder, graph ODE, decoder.

    Args:
        args: experiment configuration namespace (latents, rec_dims,
            ode_dims, n_heads, layer counts, dropout, solver, etc.).
        input_dim: per-node feature dimensionality.
        z0_prior: prior distribution over the initial latent state.
        obsrv_std: observation noise std used by the likelihood.
        device: torch device the modules are moved to.

    Returns:
        A LatentGraphODE instance on `device`.
    """
    # Dimensions.
    latent_dim = args.latents  # ODE output dimension
    rec_dim = args.rec_dims
    ode_dim = args.ode_dims  # ODE GCN dimension

    # Recognition GNN: maps observed sequences to z0.  Output is [b, n_ball, e].
    encoder_z0 = GNN(in_dim=input_dim,
                     n_hid=rec_dim,
                     out_dim=latent_dim,
                     n_heads=args.n_heads,
                     n_layers=args.rec_layers,
                     dropout=args.dropout,
                     conv_name=args.z0_encoder,
                     aggregate=args.rec_attention)

    # Optionally augment the latent state with extra dimensions.
    ode_input_dim = (latent_dim + args.augment_dim
                     if args.augment_dim > 0 else latent_dim)

    # GNN parameterizing the generative latent dynamics.
    ode_func_net = GNN(in_dim=ode_input_dim,
                       n_hid=ode_dim,
                       out_dim=ode_input_dim,
                       n_heads=args.n_heads,
                       n_layers=args.gen_layers,
                       dropout=args.dropout,
                       conv_name=args.odenet,
                       aggregate="add")

    gen_ode_func = GraphODEFunc(ode_func_net=ode_func_net,
                                device=device).to(device)

    diffeq_solver = DiffeqSolver(gen_ode_func,
                                 args.solver,
                                 args=args,
                                 odeint_rtol=1e-2,
                                 odeint_atol=1e-2,
                                 device=device)

    # Decoder back to observation space.
    decoder = Decoder(latent_dim, input_dim).to(device)

    return LatentGraphODE(
        input_dim=input_dim,
        latent_dim=args.latents,
        encoder_z0=encoder_z0,
        decoder=decoder,
        diffeq_solver=diffeq_solver,
        z0_prior=z0_prior,
        device=device,
        obsrv_std=obsrv_std,
    ).to(device)
def create_ode_rnn_encoder(args, device):
    """Create an ODE-RNN model used as an encoder.

    Args:
        args: the arguments from parse_arguments in ctfp_tools.
        device: cpu or gpu to run the model.

    Returns:
        An Encoder_z0_ODE_RNN instance on `device`.
    """
    # Observations are concatenated with their mask, doubling the input size.
    enc_input_dim = args.input_size * 2

    # Tanh MLP dynamics over the recognition hidden state.
    dynamics_net = utils.create_net(
        args.rec_size,
        args.rec_size,
        n_layers=args.rec_layers,
        n_units=args.units,
        nonlinear=nn.Tanh,
    )

    rec_ode_func = ODEFunc(
        input_dim=enc_input_dim,
        latent_dim=args.rec_size,
        ode_func_net=dynamics_net,
        device=device,
    ).to(device)

    # Fixed-step Euler solver for the recognition ODE.
    z0_solver = DiffeqSolver(
        enc_input_dim,
        rec_ode_func,
        "euler",
        args.latent_size,
        odeint_rtol=1e-3,
        odeint_atol=1e-4,
        device=device,
    )

    return Encoder_z0_ODE_RNN(
        args.rec_size,
        enc_input_dim,
        z0_solver,
        z0_dim=args.latent_size,
        n_gru_units=args.gru_units,
        device=device,
    ).to(device)
Example no. 6 (score: 0)
def create_LatentODE_model(args,
                           input_dim,
                           z0_prior,
                           obsrv_std,
                           device,
                           classif_per_tp=False,
                           n_labels=1):
    """Assemble a LatentODE model: encoder -> latent ODE -> decoder.

    Args:
        args: experiment configuration (latents, rec_dims, layer counts,
            units, poisson/classif flags, z0_encoder choice, dataset, ...).
        input_dim: dimensionality of each observation.
        z0_prior: prior distribution over the initial latent state.
        obsrv_std: observation noise std used by the likelihood.
        device: torch device the modules are moved to.
        classif_per_tp: whether classification is done per time point.
        n_labels: number of classification labels.

    Returns:
        A LatentODE instance on `device`.

    Raises:
        Exception: if args.z0_encoder is not "odernn", "rnn" or "trans".
    """
    dim = args.latents
    if args.poisson:
        # Poisson branch: additionally model an observation-rate process;
        # the latent state is doubled to carry the rate alongside z.
        lambda_net = utils.create_net(dim,
                                      input_dim,
                                      n_layers=1,
                                      n_units=args.units,
                                      nonlinear=nn.Tanh)

        # ODE function produces the gradient for latent state and for
        # poisson rate jointly.
        ode_func_net = utils.create_net(dim * 2,
                                        args.latents * 2,
                                        n_layers=args.gen_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        gen_ode_func = ODEFunc_w_Poisson(input_dim=input_dim,
                                         latent_dim=args.latents * 2,
                                         ode_func_net=ode_func_net,
                                         lambda_net=lambda_net,
                                         device=device).to(device)
    else:
        ode_func_net = utils.create_net(dim,
                                        args.latents,
                                        n_layers=args.gen_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        gen_ode_func = ODEFunc(input_dim=input_dim,
                               latent_dim=args.latents,
                               ode_func_net=ode_func_net,
                               device=device).to(device)

    z0_diffeq_solver = None
    n_rec_dims = args.rec_dims
    enc_input_dim = int(input_dim) * 2  # mask is concatenated to the data
    gen_data_dim = input_dim

    # With poisson, z0 also carries the predicted initial poisson rate.
    z0_dim = args.latents + (args.latents if args.poisson else 0)

    if args.z0_encoder == "odernn":
        # Recognition dynamics over the encoder hidden state.
        rec_net = utils.create_net(n_rec_dims,
                                   n_rec_dims,
                                   n_layers=args.rec_layers,
                                   n_units=args.units,
                                   nonlinear=nn.Tanh)

        rec_ode_func = ODEFunc(input_dim=enc_input_dim,
                               latent_dim=n_rec_dims,
                               ode_func_net=rec_net,
                               device=device).to(device)

        z0_diffeq_solver = DiffeqSolver(enc_input_dim,
                                        rec_ode_func,
                                        "euler",
                                        args.latents,
                                        odeint_rtol=1e-3,
                                        odeint_atol=1e-4,
                                        device=device)

        encoder_z0 = Encoder_z0_ODE_RNN(n_rec_dims,
                                        enc_input_dim,
                                        z0_diffeq_solver,
                                        z0_dim=z0_dim,
                                        n_gru_units=args.gru_units,
                                        device=device).to(device)
    elif args.z0_encoder == "rnn":
        encoder_z0 = Encoder_z0_RNN(z0_dim,
                                    enc_input_dim,
                                    lstm_output_size=n_rec_dims,
                                    device=device).to(device)
    elif args.z0_encoder == "trans":
        # Transformer-based encoder variant.
        encoder_z0 = TransformerLayer(args.latents * 2, input_dim,
                                      nhidden=20).to(device)
    else:
        raise Exception("Unknown encoder for Latent ODE model: " +
                        args.z0_encoder)

    decoder = Decoder(args.latents, gen_data_dim).to(device)

    # Adaptive dopri5 solver for the generative latent ODE.
    diffeq_solver = DiffeqSolver(gen_data_dim,
                                 gen_ode_func,
                                 'dopri5',
                                 args.latents,
                                 odeint_rtol=1e-3,
                                 odeint_atol=1e-4,
                                 device=device)

    return LatentODE(
        input_dim=gen_data_dim,
        latent_dim=args.latents,
        encoder_z0=encoder_z0,
        decoder=decoder,
        diffeq_solver=diffeq_solver,
        z0_prior=z0_prior,
        device=device,
        obsrv_std=obsrv_std,
        use_poisson_proc=args.poisson,
        use_binary_classif=args.classif,
        linear_classifier=args.linear_classif,
        classif_per_tp=classif_per_tp,
        n_labels=n_labels,
        train_classif_w_reconstr=(args.dataset == "physionet")).to(device)
Example no. 7 (score: 0)
        # NOTE(review): truncated snippet — the enclosing function definition
        # (which binds args, input_dim, n_ode_gru_dims, obsrv_std, device,
        # classif_per_tp and n_labels) and the closing of the ODE_RNN(...)
        # call are not visible here.
        # Dynamics network for the recognition ODE over the GRU hidden state.
        ode_func_net = utils.create_net(n_ode_gru_dims,
                                        n_ode_gru_dims,
                                        n_layers=args.rec_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        rec_ode_func = ODEFunc(input_dim=input_dim,
                               latent_dim=n_ode_gru_dims,
                               ode_func_net=ode_func_net,
                               device=device).to(device)

        # Fixed-step Euler solver that evolves the hidden state between
        # observation times.
        z0_diffeq_solver = DiffeqSolver(input_dim,
                                        rec_ode_func,
                                        "euler",
                                        args.latents,
                                        odeint_rtol=1e-3,
                                        odeint_atol=1e-4,
                                        device=device)

        # Standalone ODE-RNN model (mask concatenated to the input).
        model = ODE_RNN(
            input_dim,
            n_ode_gru_dims,
            device=device,
            z0_diffeq_solver=z0_diffeq_solver,
            n_gru_units=args.gru_units,
            concat_mask=True,
            obsrv_std=obsrv_std,
            use_binary_classif=args.classif,
            classif_per_tp=classif_per_tp,
            n_labels=n_labels,
Example no. 8 (score: 0)
class ODE_RNN(nn.Module):
    """Class for standalone ODE-RNN model. Makes predictions forward in time."""
    def __init__(self, latent_dim, ode_func_layers, ode_func_units, input_dim,
                 decoder_units):
        """Build the dynamics net, ODE solver, GRU update cell and decoder.

        Args:
            latent_dim: size of the hidden state evolved by the ODE.
            ode_func_layers: number of layers in the ODE dynamics network.
            ode_func_units: hidden units per layer of the dynamics network.
            input_dim: dimensionality of each observation.
            decoder_units: hidden width of the decoder and of the GRU unit.
        """
        super(ODE_RNN, self).__init__()

        # Network parameterizing the latent dynamics dh/dt = f(h).
        ode_func_net = utils.create_net(latent_dim,
                                        latent_dim,
                                        n_layers=ode_func_layers,
                                        n_units=ode_func_units,
                                        nonlinear=nn.Tanh)

        utils.init_network_weights(ode_func_net)

        rec_ode_func = ODEFunc(ode_func_net=ode_func_net)

        # Fixed-step Euler solver advancing the hidden state between
        # observation times.
        self.ode_solver = DiffeqSolver(rec_ode_func,
                                       "euler",
                                       odeint_rtol=1e-3,
                                       odeint_atol=1e-4)

        # Decoder emits 2 * input_dim values per step: the predicted mean
        # and the pre-softplus std (split with torch.chunk in forward).
        self.decoder = nn.Sequential(nn.Linear(latent_dim, decoder_units),
                                     nn.Tanh(),
                                     nn.Linear(decoder_units, input_dim * 2))

        utils.init_network_weights(self.decoder)

        self.gru_unit = GRU_Unit(latent_dim, input_dim, n_units=decoder_units)

        self.latent_dim = latent_dim

        # Maps the decoder's raw std output to a positive value.
        self.sigma_fn = nn.Softplus()

    def forward(self,
                data,
                mask,
                mask_first,
                time_steps,
                extrap_time=float('inf'),
                use_sampling=False):
        """Run the ODE-RNN over a batch of time series.

        Args:
            data: tensor of shape (batch, n_time_steps, n_dims).
            mask: observation mask indexed as mask[:, i]; nonzero entries
                mark observed values that update the fed-back observation.
            mask_first: multiplier applied to the updated hidden state at
                step i (indexed [:, i - 1]); zero entries zero-out the
                state.  NOTE(review): exact semantics depend on the caller
                — confirm against the data pipeline.
            time_steps: 1-D sequence of time points; outputs are produced
                for time_steps[1:].
            extrap_time: after this time the model feeds back its own
                predictions instead of the data (extrapolation regime).
            use_sampling: if True, with probability 0.5 condition the GRU
                on the previous model output instead of the previous
                observation (only while time_steps[i] <= extrap_time).

        Returns:
            Tuple (outputs, outputs_std), each of shape
            (batch, len(time_steps) - 1, output_dim).
        """
        batch_size, n_time_steps, n_dims = data.size()

        # Initial hidden state (and its std) is all zeros.
        prev_hidden = torch.zeros((batch_size, self.latent_dim))
        prev_hidden_std = torch.zeros((batch_size, self.latent_dim))

        if data.is_cuda:
            prev_hidden = prev_hidden.to(data.get_device())
            prev_hidden_std = prev_hidden_std.to(data.get_device())

        # Cap the ODE integration step at 1/50 of the full time interval.
        interval_length = time_steps[-1] - time_steps[0]
        minimum_step = interval_length / 50

        outputs = []
        outputs_std = []

        prev_observation = data[:, 0]

        if use_sampling:
            prev_output = data[:, 0]

        for i in range(1, len(time_steps)):

            # Make one step.  For gaps shorter than minimum_step a single
            # explicit-Euler increment avoids calling the full solver.
            if time_steps[i] - time_steps[i - 1] < minimum_step:
                inc = self.ode_solver.ode_func(time_steps[i - 1], prev_hidden)

                ode_sol = prev_hidden + inc * (time_steps[i] -
                                               time_steps[i - 1])
                ode_sol = torch.stack((prev_hidden, ode_sol), 1)
            # Several steps.
            else:
                num_intermediate_steps = max(
                    2,
                    ((time_steps[i] - time_steps[i - 1]) / minimum_step).int())

                time_points = torch.linspace(time_steps[i - 1], time_steps[i],
                                             num_intermediate_steps)
                ode_sol = self.ode_solver(prev_hidden.unsqueeze(0),
                                          time_points)[0]

            # Hidden state evolved to time_steps[i].
            hidden_ode = ode_sol[:, -1]

            x_i = prev_observation

            # Scheduled-sampling-style feedback: sometimes condition the
            # GRU on the model's own previous output.
            if use_sampling and np.random.uniform(
                    0, 1) < 0.5 and time_steps[i] <= extrap_time:
                x_i = prev_output

            mask_i = mask[:, i]

            # GRU update combines the ODE-evolved state with the input.
            output_hidden, hidden, hidden_std = self.gru_unit(
                hidden_ode, prev_hidden_std, x_i, mask_i)

            # Zero-out the state where mask_first is 0.
            hidden = mask_first[:, i - 1] * hidden
            hidden_std = mask_first[:, i - 1] * hidden_std

            prev_hidden, prev_hidden_std = hidden, hidden_std

            # Decoder emits mean and raw std in one tensor; split in half.
            mean, std = torch.chunk(self.decoder(output_hidden),
                                    chunks=2,
                                    dim=-1)

            outputs += [mean]
            outputs_std += [self.sigma_fn(std)]

            if use_sampling:
                # Replace only the entries that were observed at step i.
                prev_output = prev_output * (1 - mask_i) + mask_i * outputs[-1]

            # Before extrap_time feed observed data forward; afterwards feed
            # the model's own predictions (pure extrapolation).
            if time_steps[i] <= extrap_time:
                prev_observation = prev_observation * (
                    1 - mask_i) + mask_i * data[:, i]
            else:
                prev_observation = prev_observation * (
                    1 - mask_i) + mask_i * outputs[-1]

        outputs = torch.stack(outputs, 1)
        outputs_std = torch.stack(outputs_std, 1)

        return outputs, outputs_std

    @property
    def num_params(self):
        """Number of parameters."""
        # Sums per-parameter element counts; returns a numpy scalar.
        return np.sum(
            [torch.tensor(param.shape).prod() for param in self.parameters()])