def __init__(self, latent_dim, layers=1):
        """
        Affine model with adaptive (state-dependent) covariance.

        :param latent_dim: dimensionality of the latent state
        :param layers: number of hidden layers in each network
        """
        super(GaussPriorNonlinearAdaptive, self).__init__()

        self.latent_dim = latent_dim
        # Both the mean and the log-variance are functions of the state.
        self.adaptive_mean = create_net(latent_dim, latent_dim, n_layers=layers)
        self.log_var_net = create_net(latent_dim, latent_dim, n_layers=layers)
    def __init__(self, latent_dim, layers=1):
        """
        affine model with constant covariance
        :param latent_dim:
        """
        super(GaussPriorNonlinearConstant, self).__init__()

        self.adaptive_mean = create_net(latent_dim, latent_dim)
        # self.log_prob_f = lambda dy_dt, y: self.adaptive_mean(y) + torch.exp(self.log_var)

        self.latent_dim = latent_dim
        self.adaptive_mean = create_net(latent_dim, latent_dim, layers)
        self.log_var = torch.nn.Parameter(torch.randn(latent_dim))
# Beispiel #3
# 0
    def __init__(self, latent_dim, ode_func_layers, ode_func_units, input_dim,
                 decoder_units):
        super(ODE_RNN, self).__init__()

        ode_func_net = utils.create_net(latent_dim,
                                        latent_dim,
                                        n_layers=ode_func_layers,
                                        n_units=ode_func_units,
                                        nonlinear=nn.Tanh)

        utils.init_network_weights(ode_func_net)

        rec_ode_func = ODEFunc(ode_func_net=ode_func_net)

        self.ode_solver = DiffeqSolver(rec_ode_func,
                                       "euler",
                                       odeint_rtol=1e-3,
                                       odeint_atol=1e-4)

        self.decoder = nn.Sequential(nn.Linear(latent_dim, decoder_units),
                                     nn.Tanh(),
                                     nn.Linear(decoder_units, input_dim * 2))

        utils.init_network_weights(self.decoder)

        self.gru_unit = GRU_Unit(latent_dim, input_dim, n_units=decoder_units)

        self.latent_dim = latent_dim

        self.sigma_fn = nn.Softplus()
# Beispiel #4
# 0
def get_diffeq_solver(ode_latents,
                      ode_units,
                      rec_layers,
                      ode_method,
                      ode_type="linear",
                      device=torch.device("cpu")):
    """
    Build a DiffeqSolver whose dynamics are either an MLP ("linear")
    or an autonomous GRU-ODE cell ("gru").

    :param ode_latents: dimensionality of the ODE latent state
    :param ode_units: hidden units per layer of the MLP dynamics
    :param rec_layers: number of layers of the MLP dynamics
    :param ode_method: integration method name passed to the solver
    :param ode_type: "linear" for an MLP, "gru" for a GRU-ODE cell
    :param device: torch device the ODE function is moved to
    :raises Exception: on an unknown ``ode_type``
    """
    if ode_type == "linear":
        dynamics = utils.create_net(ode_latents, ode_latents,
                                    n_layers=int(rec_layers),
                                    n_units=int(ode_units),
                                    nonlinear=nn.Tanh)
    elif ode_type == "gru":
        dynamics = FullGRUODECell_Autonomous(ode_latents, bias=True)
    else:
        raise Exception("Invalid ODE-type. Choose linear or gru.")

    # input_dim=0: the dynamics are autonomous (no control input).
    ode_func = ODEFunc(input_dim=0,
                       latent_dim=ode_latents,
                       ode_func_net=dynamics,
                       device=device).to(device)

    return DiffeqSolver(0, ode_func, ode_method, ode_latents,
                        odeint_rtol=1e-3, odeint_atol=1e-4,
                        device=device)
# Beispiel #5
# 0
    def __init__(self, output_dim, input_dim, hidden_dim, layer_num):
        super(Encoder, self).__init__()
        # decode data from latent space where we are solving an ODE back to the data space

        encoder = utils.create_net(input_dim, output_dim * 2, layer_num,
                                   hidden_dim)

        utils.init_network_weights(encoder)
        self.encoder = encoder
# Beispiel #6
# 0
    def __init__(self, input_dim, latent_dim, device=torch.device("cpu")):
        """
        input_dim: dimensionality of the input
        latent_dim: dimensionality used for ODE. Analog of a continous latent state
        """
        super(CDEFunc, self).__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.device = device
        self.interpolation = None

        # Equ 3 in Neural Controlled Differential Equations for Irregular Time Series
        self.cde_func = utils.create_net(latent_dim,
                                         latent_dim * (input_dim + 1),
                                         n_units=10,
                                         n_layers=1)

        utils.init_network_weights(self.cde_func)
def create_ode_rnn_encoder(args, device):
    """
    Create an ODE-RNN model to be used as an encoder.

    :param args: the arguments from parse_arguments in ctfp_tools
    :param device: cpu or gpu to run the model
    :return: an ``Encoder_z0_ODE_RNN`` instance moved to ``device``
    """
    # Input and mask are concatenated, doubling the observed dimensionality.
    enc_input_dim = 2 * args.input_size

    dynamics_net = utils.create_net(
        args.rec_size,
        args.rec_size,
        n_layers=args.rec_layers,
        n_units=args.units,
        nonlinear=nn.Tanh,
    )

    rec_ode_func = ODEFunc(
        input_dim=enc_input_dim,
        latent_dim=args.rec_size,
        ode_func_net=dynamics_net,
        device=device,
    ).to(device)

    solver = DiffeqSolver(
        enc_input_dim,
        rec_ode_func,
        "euler",
        args.latent_size,
        odeint_rtol=1e-3,
        odeint_atol=1e-4,
        device=device,
    )

    return Encoder_z0_ODE_RNN(
        args.rec_size,
        enc_input_dim,
        solver,
        z0_dim=args.latent_size,
        n_gru_units=args.gru_units,
        device=device,
    ).to(device)
# Beispiel #8
# 0
def create_LatentODE_model(args,
                           input_dim,
                           z0_prior,
                           obsrv_std,
                           device,
                           classif_per_tp=False,
                           n_labels=1):
    """
    Assemble a LatentODE model: a z0 encoder (ODE-RNN, RNN, or transformer),
    a generative ODE over the latent state, and a decoder to data space.

    :param args: parsed arguments (latents, units, gen_layers, rec_dims,
        rec_layers, gru_units, poisson, z0_encoder, classif, ...)
    :param input_dim: dimensionality of the observed data
    :param z0_prior: prior distribution over the initial latent state
    :param obsrv_std: observation noise std used by the likelihood
    :param device: torch device the sub-modules are moved to
    :param classif_per_tp: classify per time point rather than per sequence
    :param n_labels: number of classification labels
    :raises Exception: on an unknown ``args.z0_encoder``
    :return: a ``LatentODE`` instance on ``device``
    """

    dim = args.latents
    if args.poisson:
        # Net producing the Poisson rate parameters from the latent state.
        lambda_net = utils.create_net(dim,
                                      input_dim,
                                      n_layers=1,
                                      n_units=args.units,
                                      nonlinear=nn.Tanh)

        # ODE function produces the gradient for latent state and for poisson rate
        # (hence the doubled input/output dimensionality).
        ode_func_net = utils.create_net(dim * 2,
                                        args.latents * 2,
                                        n_layers=args.gen_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        gen_ode_func = ODEFunc_w_Poisson(input_dim=input_dim,
                                         latent_dim=args.latents * 2,
                                         ode_func_net=ode_func_net,
                                         lambda_net=lambda_net,
                                         device=device).to(device)
    else:
        # Plain generative dynamics over args.latents dimensions.
        dim = args.latents
        ode_func_net = utils.create_net(dim,
                                        args.latents,
                                        n_layers=args.gen_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        gen_ode_func = ODEFunc(input_dim=input_dim,
                               latent_dim=args.latents,
                               ode_func_net=ode_func_net,
                               device=device).to(device)

    z0_diffeq_solver = None
    n_rec_dims = args.rec_dims
    enc_input_dim = int(input_dim) * 2  # we concatenate the mask
    gen_data_dim = input_dim

    z0_dim = args.latents
    if args.poisson:
        z0_dim += args.latents  # predict the initial poisson rate

    if args.z0_encoder == "odernn":
        # Recognition model: its own ODE dynamics over the recognition dims.
        ode_func_net = utils.create_net(n_rec_dims,
                                        n_rec_dims,
                                        n_layers=args.rec_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        rec_ode_func = ODEFunc(input_dim=enc_input_dim,
                               latent_dim=n_rec_dims,
                               ode_func_net=ode_func_net,
                               device=device).to(device)

        # Cheap fixed-step solver for the recognition pass.
        z0_diffeq_solver = DiffeqSolver(enc_input_dim,
                                        rec_ode_func,
                                        "euler",
                                        args.latents,
                                        odeint_rtol=1e-3,
                                        odeint_atol=1e-4,
                                        device=device)

        encoder_z0 = Encoder_z0_ODE_RNN(n_rec_dims,
                                        enc_input_dim,
                                        z0_diffeq_solver,
                                        z0_dim=z0_dim,
                                        n_gru_units=args.gru_units,
                                        device=device).to(device)

    elif args.z0_encoder == "rnn":
        encoder_z0 = Encoder_z0_RNN(z0_dim,
                                    enc_input_dim,
                                    lstm_output_size=n_rec_dims,
                                    device=device).to(device)

    # my code
    elif args.z0_encoder == "trans":
        encoder_z0 = TransformerLayer(args.latents * 2, input_dim,
                                      nhidden=20).to(device)

    else:
        raise Exception("Unknown encoder for Latent ODE model: " +
                        args.z0_encoder)

    decoder = Decoder(args.latents, gen_data_dim).to(device)

    # Adaptive-step solver for the generative pass.
    diffeq_solver = DiffeqSolver(gen_data_dim,
                                 gen_ode_func,
                                 'dopri5',
                                 args.latents,
                                 odeint_rtol=1e-3,
                                 odeint_atol=1e-4,
                                 device=device)

    model = LatentODE(
        input_dim=gen_data_dim,
        latent_dim=args.latents,
        encoder_z0=encoder_z0,
        decoder=decoder,
        diffeq_solver=diffeq_solver,
        z0_prior=z0_prior,
        device=device,
        obsrv_std=obsrv_std,
        use_poisson_proc=args.poisson,
        use_binary_classif=args.classif,
        linear_classifier=args.linear_classif,
        classif_per_tp=classif_per_tp,
        n_labels=n_labels,
        train_classif_w_reconstr=(args.dataset == "physionet")).to(device)

    return model
# Beispiel #9
# 0
            train_classif_w_reconstr=(args.dataset == "physionet")).to(device)
    elif args.ode_rnn:
        # Create ODE-GRU model
        n_ode_gru_dims = args.latents

        if args.poisson:
            print(
                "Poisson process likelihood not implemented for ODE-RNN: ignoring --poisson"
            )

        if args.extrap:
            raise Exception("Extrapolation for ODE-RNN not implemented")

        ode_func_net = utils.create_net(n_ode_gru_dims,
                                        n_ode_gru_dims,
                                        n_layers=args.rec_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        rec_ode_func = ODEFunc(input_dim=input_dim,
                               latent_dim=n_ode_gru_dims,
                               ode_func_net=ode_func_net,
                               device=device).to(device)

        z0_diffeq_solver = DiffeqSolver(input_dim,
                                        rec_ode_func,
                                        "euler",
                                        args.latents,
                                        odeint_rtol=1e-3,
                                        odeint_atol=1e-4,
                                        device=device)