예제 #1
0
    def __init__(self, latent_dim, ode_func_layers, ode_func_units, input_dim,
                 decoder_units):
        """Assemble the ODE-RNN sub-modules.

        Args:
            latent_dim: size of the hidden state the ODE dynamics act on; the
                dynamics network maps ``latent_dim`` -> ``latent_dim``.
            ode_func_layers: number of layers of the dynamics network.
            ode_func_units: hidden width of the dynamics network.
            input_dim: number of observed features; the decoder emits
                ``2 * input_dim`` values.
            decoder_units: hidden width of the decoder MLP; also passed as
                ``n_units`` to ``GRU_Unit``.
        """
        super(ODE_RNN, self).__init__()

        # Dynamics network (latent -> latent, Tanh activations), initialised
        # with the project's weight-init helper.
        dynamics_net = utils.create_net(latent_dim, latent_dim,
                                        n_layers=ode_func_layers,
                                        n_units=ode_func_units,
                                        nonlinear=nn.Tanh)
        utils.init_network_weights(dynamics_net)

        # NOTE(review): "euler" is presumably a fixed-step method, in which case
        # the rtol/atol values below would be ignored — confirm in DiffeqSolver.
        self.ode_solver = DiffeqSolver(ODEFunc(ode_func_net=dynamics_net),
                                       "euler",
                                       odeint_rtol=1e-3,
                                       odeint_atol=1e-4)

        # Readout MLP with 2 * input_dim outputs; presumably a mean and a raw
        # scale per feature, the latter mapped through sigma_fn — TODO confirm.
        readout_layers = [
            nn.Linear(latent_dim, decoder_units),
            nn.Tanh(),
            nn.Linear(decoder_units, input_dim * 2),
        ]
        self.decoder = nn.Sequential(*readout_layers)
        utils.init_network_weights(self.decoder)

        self.gru_unit = GRU_Unit(latent_dim, input_dim, n_units=decoder_units)
        self.latent_dim = latent_dim
        self.sigma_fn = nn.Softplus()
예제 #2
0
def get_odernn_encoder(input_dim, latent_dim, odernn_hypers, device):
    """Build an ``Encoder_z0_ODE_RNN`` whose hidden state evolves by an ODE.

    Args:
        input_dim: number of observed features. The encoder and its ODE
            function are given an input width of ``2 * input_dim``
            (presumably values concatenated with a mask — TODO confirm).
        latent_dim: size of the recurrent hidden state and of z0.
        odernn_hypers: mapping with keys ``'odefunc_n_layers'``,
            ``'odefunc_n_units'`` and ``'encoder_gru_units'``.
        device: torch device the modules are moved to.

    Returns:
        The encoder module, on ``device``.
    """
    enc_width = input_dim * 2
    hypers = odernn_hypers

    dynamics = create_net(latent_dim, latent_dim,
                          n_layers=hypers['odefunc_n_layers'],
                          n_units=hypers['odefunc_n_units'])
    ode_func = ODEFunc(input_dim=enc_width, latent_dim=latent_dim,
                       ode_func_net=dynamics, device=device).to(device)

    # Adaptive 'dopri5' integration between observation times.
    solver = DiffeqSolver(enc_width, ode_func, 'dopri5', latent_dim,
                          odeint_rtol=1e-3, odeint_atol=1e-4, device=device)

    return Encoder_z0_ODE_RNN(latent_dim, enc_width, solver,
                              z0_dim=latent_dim,
                              n_gru_units=hypers['encoder_gru_units'],
                              device=device).to(device)
예제 #3
0
def get_diffeq_solver(ode_latents,
                      ode_units,
                      rec_layers,
                      ode_method,
                      ode_type="linear",
                      device=torch.device("cpu")):
    """Build a ``DiffeqSolver`` over an ODE on an ``ode_latents``-dim state.

    Args:
        ode_latents: dimensionality of the state the ODE evolves.
        ode_units: hidden width of the MLP dynamics (``ode_type="linear"``
            only); converted with ``int``.
        rec_layers: number of MLP layers (``ode_type="linear"`` only);
            converted with ``int``.
        ode_method: integration method name forwarded to ``DiffeqSolver``.
        ode_type: ``"linear"`` builds a Tanh MLP via ``utils.create_net``
            (despite the name it is not a linear map); ``"gru"`` builds a
            ``FullGRUODECell_Autonomous`` cell.
        device: torch device the ODE function is moved to and the solver is
            given.

    Returns:
        The configured ``DiffeqSolver``.

    Raises:
        ValueError: if ``ode_type`` is neither ``"linear"`` nor ``"gru"``.
    """
    if ode_type == "linear":
        ode_func_net = utils.create_net(ode_latents,
                                        ode_latents,
                                        n_layers=int(rec_layers),
                                        n_units=int(ode_units),
                                        nonlinear=nn.Tanh)
    elif ode_type == "gru":
        ode_func_net = FullGRUODECell_Autonomous(ode_latents, bias=True)
    else:
        # ValueError subclasses Exception, so callers that caught the old
        # generic Exception keep working; the message now names the bad value.
        raise ValueError(
            f"Invalid ODE-type {ode_type!r}. Choose linear or gru.")

    # input_dim=0 is passed to both ODEFunc and DiffeqSolver.
    rec_ode_func = ODEFunc(input_dim=0,
                           latent_dim=ode_latents,
                           ode_func_net=ode_func_net,
                           device=device).to(device)

    z0_diffeq_solver = DiffeqSolver(0,
                                    rec_ode_func,
                                    ode_method,
                                    ode_latents,
                                    odeint_rtol=1e-3,
                                    odeint_atol=1e-4,
                                    device=device)

    return z0_diffeq_solver
def create_ode_rnn_encoder(args, device):
    """Create the ODE-RNN model used as the z0 encoder.

    Args:
        args: parsed arguments from ``parse_arguments`` in ``ctfp_tools``;
            reads ``input_size``, ``rec_size``, ``rec_layers``, ``units``,
            ``latent_size`` and ``gru_units``.
        device: torch device (CPU or GPU) to run the model on.

    Returns:
        An ``Encoder_z0_ODE_RNN`` module on ``device``.
    """
    # The mask is concatenated with the input, doubling the encoder width.
    enc_input_dim = args.input_size * 2
    rec_size = args.rec_size

    recurrent_net = utils.create_net(rec_size, rec_size,
                                     n_layers=args.rec_layers,
                                     n_units=args.units,
                                     nonlinear=nn.Tanh)
    recurrent_dynamics = ODEFunc(input_dim=enc_input_dim,
                                 latent_dim=rec_size,
                                 ode_func_net=recurrent_net,
                                 device=device).to(device)

    # NOTE(review): the solver is given args.latent_size as its latent dim
    # while the dynamics act on rec_size dims — confirm this is intended.
    solver = DiffeqSolver(enc_input_dim, recurrent_dynamics, "euler",
                          args.latent_size,
                          odeint_rtol=1e-3, odeint_atol=1e-4, device=device)

    return Encoder_z0_ODE_RNN(rec_size, enc_input_dim, solver,
                              z0_dim=args.latent_size,
                              n_gru_units=args.gru_units,
                              device=device).to(device)
예제 #5
0
def create_LatentODE_model(args,
                           input_dim,
                           z0_prior,
                           obsrv_std,
                           device,
                           classif_per_tp=False,
                           n_labels=1):
    """Build a ``LatentODE`` model: generative ODE, z0 encoder, decoder, solver.

    Args:
        args: parsed experiment options. Reads ``z0_encoder``, ``latents``,
            ``poisson``, ``units``, ``gen_layers``, ``rec_dims``,
            ``rec_layers`` and ``gru_units`` (odernn encoder only),
            ``classif``, ``linear_classif`` and ``dataset``.
        input_dim: number of observed features per time point.
        z0_prior: forwarded unchanged to ``LatentODE`` (presumably the prior
            over the initial latent state).
        obsrv_std: forwarded unchanged to ``LatentODE``.
        device: torch device every sub-module is moved to.
        classif_per_tp: forwarded to ``LatentODE``.
        n_labels: forwarded to ``LatentODE``.

    Returns:
        A ``LatentODE`` module on ``device``.

    Raises:
        ValueError: if ``args.z0_encoder`` is not ``"odernn"``, ``"rnn"`` or
            ``"trans"``. This is checked before any network is built.
    """
    # Fail fast, before constructing any sub-network. ValueError subclasses
    # Exception, so callers catching the old generic Exception still work. The
    # f-string also handles a non-str value (e.g. None), where the previous
    # ``"..." + args.z0_encoder`` raised a TypeError that hid the real error.
    if args.z0_encoder not in ("odernn", "rnn", "trans"):
        raise ValueError("Unknown encoder for Latent ODE model: "
                         f"{args.z0_encoder}")

    latents = args.latents

    # --- Generative dynamics -------------------------------------------
    if args.poisson:
        # Latent state -> input_dim outputs; presumably the per-feature
        # Poisson rate consumed by ODEFunc_w_Poisson — TODO confirm.
        lambda_net = utils.create_net(latents,
                                      input_dim,
                                      n_layers=1,
                                      n_units=args.units,
                                      nonlinear=nn.Tanh)

        # ODE function produces the gradient for latent state and for poisson
        # rate, so it acts on a state twice the latent size.
        ode_func_net = utils.create_net(latents * 2,
                                        latents * 2,
                                        n_layers=args.gen_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        gen_ode_func = ODEFunc_w_Poisson(input_dim=input_dim,
                                         latent_dim=latents * 2,
                                         ode_func_net=ode_func_net,
                                         lambda_net=lambda_net,
                                         device=device).to(device)
    else:
        ode_func_net = utils.create_net(latents,
                                        latents,
                                        n_layers=args.gen_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        gen_ode_func = ODEFunc(input_dim=input_dim,
                               latent_dim=latents,
                               ode_func_net=ode_func_net,
                               device=device).to(device)

    n_rec_dims = args.rec_dims
    enc_input_dim = int(input_dim) * 2  # we concatenate the mask
    gen_data_dim = input_dim

    z0_dim = latents
    if args.poisson:
        z0_dim += latents  # predict the initial poisson rate

    # --- z0 encoder ------------------------------------------------------
    if args.z0_encoder == "odernn":
        rec_ode_func_net = utils.create_net(n_rec_dims,
                                            n_rec_dims,
                                            n_layers=args.rec_layers,
                                            n_units=args.units,
                                            nonlinear=nn.Tanh)

        rec_ode_func = ODEFunc(input_dim=enc_input_dim,
                               latent_dim=n_rec_dims,
                               ode_func_net=rec_ode_func_net,
                               device=device).to(device)

        z0_diffeq_solver = DiffeqSolver(enc_input_dim,
                                        rec_ode_func,
                                        "euler",
                                        latents,
                                        odeint_rtol=1e-3,
                                        odeint_atol=1e-4,
                                        device=device)

        encoder_z0 = Encoder_z0_ODE_RNN(n_rec_dims,
                                        enc_input_dim,
                                        z0_diffeq_solver,
                                        z0_dim=z0_dim,
                                        n_gru_units=args.gru_units,
                                        device=device).to(device)

    elif args.z0_encoder == "rnn":
        encoder_z0 = Encoder_z0_RNN(z0_dim,
                                    enc_input_dim,
                                    lstm_output_size=n_rec_dims,
                                    device=device).to(device)

    else:  # "trans": the only remaining option after the check above
        # NOTE(review): unlike the other encoders this one is sized from
        # 2 * args.latents and the raw input_dim rather than z0_dim and
        # enc_input_dim, and nhidden=20 is hard-coded — confirm intended.
        encoder_z0 = TransformerLayer(latents * 2, input_dim,
                                      nhidden=20).to(device)

    # --- Decoder, generative solver and full model -----------------------
    decoder = Decoder(latents, gen_data_dim).to(device)

    diffeq_solver = DiffeqSolver(gen_data_dim,
                                 gen_ode_func,
                                 'dopri5',
                                 latents,
                                 odeint_rtol=1e-3,
                                 odeint_atol=1e-4,
                                 device=device)

    model = LatentODE(
        input_dim=gen_data_dim,
        latent_dim=latents,
        encoder_z0=encoder_z0,
        decoder=decoder,
        diffeq_solver=diffeq_solver,
        z0_prior=z0_prior,
        device=device,
        obsrv_std=obsrv_std,
        use_poisson_proc=args.poisson,
        use_binary_classif=args.classif,
        linear_classifier=args.linear_classif,
        classif_per_tp=classif_per_tp,
        n_labels=n_labels,
        train_classif_w_reconstr=(args.dataset == "physionet")).to(device)

    return model
예제 #6
0
        if args.poisson:
            print(
                "Poisson process likelihood not implemented for ODE-RNN: ignoring --poisson"
            )

        if args.extrap:
            raise Exception("Extrapolation for ODE-RNN not implemented")

        ode_func_net = utils.create_net(n_ode_gru_dims,
                                        n_ode_gru_dims,
                                        n_layers=args.rec_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        rec_ode_func = ODEFunc(input_dim=input_dim,
                               latent_dim=n_ode_gru_dims,
                               ode_func_net=ode_func_net,
                               device=device).to(device)

        z0_diffeq_solver = DiffeqSolver(input_dim,
                                        rec_ode_func,
                                        "euler",
                                        args.latents,
                                        odeint_rtol=1e-3,
                                        odeint_atol=1e-4,
                                        device=device)

        model = ODE_RNN(
            input_dim,
            n_ode_gru_dims,
            device=device,
            z0_diffeq_solver=z0_diffeq_solver,