Ejemplo n.º 1
0
def create_ODERNN_model(disable_bias=True):
    """Build an ODE-RNN model from the module-level configuration.

    The function reads ``input_dim``, ``latent_dim`` and ``device`` from the
    enclosing module scope; it takes no data-dependent arguments.

    Args:
        disable_bias: When true (the default, which matches the previously
            hard-coded behaviour), every learnable ``bias`` parameter found
            in the model is removed after construction.

    Returns:
        The constructed ``ODE_RNN`` instance, moved to ``device``.
    """
    # Observation noise standard deviation handed to the model.
    obsrv_std = torch.Tensor([0.1]).to(device)

    # The recognition ODE evolves a hidden state of the same size as the
    # latent state.
    n_ode_gru_dims = latent_dim

    ode_func_net = utils.create_net(n_ode_gru_dims,
                                    n_ode_gru_dims,
                                    n_layers=2,
                                    n_units=100,
                                    nonlinear=nn.Tanh)

    rec_ode_func = ODEFunc(input_dim=input_dim,
                           latent_dim=n_ode_gru_dims,
                           ode_func_net=ode_func_net,
                           device=device).to(device)

    z0_diffeq_solver = DiffeqSolver(input_dim,
                                    rec_ode_func,
                                    "dopri5",
                                    latent_dim,
                                    odeint_rtol=1e-3,
                                    odeint_atol=1e-4,
                                    device=device)

    model = ODE_RNN(input_dim=input_dim,
                    latent_dim=latent_dim,
                    n_gru_units=50,
                    n_units=50,
                    device=device,
                    z0_diffeq_solver=z0_diffeq_solver,
                    concat_mask=True,
                    obsrv_std=obsrv_std,
                    use_binary_classif=False,
                    classif_per_tp=False,
                    n_labels=1,
                    train_classif_w_reconstr=False).to(device)

    if disable_bias:
        for module in model.modules():
            # Only drop real learnable biases. A module may expose a
            # non-parameter ``bias`` attribute (for example a boolean flag);
            # overwriting that with None would not remove any weights.
            if isinstance(getattr(module, 'bias', None), nn.Parameter):
                module.bias = None

    return model
Ejemplo n.º 2
0
# Model: an ODE-RNN assembled from the module-level input_dim, latent_dim
# and device.

# Observation noise standard deviation passed to the model.
obsrv_std = torch.Tensor([0.1]).to(device)

# Normal distribution with mean 0 and standard deviation 1.
z0_prior = Normal(
    torch.Tensor([0.0]).to(device),
    torch.Tensor([1.]).to(device),
)

gru_units = 40
n_ode_gru_dims = latent_dim

# Network that parameterises the ODE dynamics.
ode_func_net = utils.create_net(
    n_ode_gru_dims,
    n_ode_gru_dims,
    n_layers=2,
    n_units=50,
    nonlinear=nn.Tanh,
)

rec_ode_func = ODEFunc(
    input_dim=input_dim,
    latent_dim=n_ode_gru_dims,
    ode_func_net=ode_func_net,
    device=device,
)
rec_ode_func = rec_ode_func.to(device)

# Solver wrapping the ODE function ("dopri5" with the given tolerances).
z0_diffeq_solver = DiffeqSolver(
    input_dim,
    rec_ode_func,
    "dopri5",
    latent_dim,
    odeint_rtol=1e-3,
    odeint_atol=1e-4,
    device=device,
)

model = ODE_RNN(
    input_dim=input_dim,
    latent_dim=latent_dim,
    n_gru_units=50,
    n_units=50,
    device=device,
    z0_diffeq_solver=z0_diffeq_solver,
    concat_mask=True,
    obsrv_std=obsrv_std,
    use_binary_classif=False,
    classif_per_tp=False,
    n_labels=1,
    train_classif_w_reconstr=False,
)
model = model.to(device)
##################################################################
# Training

def status(epoch, train_props, cv_props=None):
Ejemplo n.º 3
0
def create_LatentODE_model(args,
                           input_dim,
                           z0_prior,
                           obsrv_std,
                           device,
                           classif_per_tp=False,
                           n_labels=1):
    """Assemble a ``LatentODE`` model from command-line style options.

    Args:
        args: Options object. The attributes read here are ``latents``,
            ``poisson``, ``units``, ``gen_layers``, ``rec_layers``,
            ``rec_dims``, ``gru_units``, ``z0_encoder``, ``classif``,
            ``linear_classif`` and ``dataset``.
        input_dim: Dimensionality of the observed data.
        z0_prior: Prior over the initial latent state, forwarded to the model.
        obsrv_std: Observation noise standard deviation, forwarded to the
            model.
        device: Device every sub-module is moved to.
        classif_per_tp: Forwarded to ``LatentODE``.
        n_labels: Forwarded to ``LatentODE``.

    Returns:
        The constructed ``LatentODE`` instance, moved to ``device``.

    Raises:
        ValueError: If ``args.z0_encoder`` is neither ``"odernn"`` nor
            ``"rnn"``.
    """
    # Validate the encoder choice up front so a bad option is reported
    # before any network is built or moved to the device.
    if args.z0_encoder not in ("odernn", "rnn"):
        raise ValueError(
            "Unknown encoder for Latent ODE model: {}".format(args.z0_encoder))

    latents = args.latents

    # --- Generative ODE function -------------------------------------
    if args.poisson:
        lambda_net = utils.create_net(latents,
                                      input_dim,
                                      n_layers=1,
                                      n_units=args.units,
                                      nonlinear=nn.Tanh)

        # ODE function produces the gradient for latent state and for
        # poisson rate, hence the doubled dimension.
        ode_func_net = utils.create_net(latents * 2,
                                        latents * 2,
                                        n_layers=args.gen_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        gen_ode_func = ODEFunc_w_Poisson(input_dim=input_dim,
                                         latent_dim=latents * 2,
                                         ode_func_net=ode_func_net,
                                         lambda_net=lambda_net,
                                         device=device).to(device)
    else:
        ode_func_net = utils.create_net(latents,
                                        latents,
                                        n_layers=args.gen_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        gen_ode_func = ODEFunc(input_dim=input_dim,
                               latent_dim=latents,
                               ode_func_net=ode_func_net,
                               device=device).to(device)

    # --- Encoder for the initial latent state ------------------------
    n_rec_dims = args.rec_dims
    enc_input_dim = int(input_dim) * 2  # we concatenate the mask
    gen_data_dim = input_dim

    # With the Poisson process the encoder also predicts the initial
    # poisson rate, which doubles the size of z0.
    z0_dim = latents * 2 if args.poisson else latents

    if args.z0_encoder == "odernn":
        ode_func_net = utils.create_net(n_rec_dims,
                                        n_rec_dims,
                                        n_layers=args.rec_layers,
                                        n_units=args.units,
                                        nonlinear=nn.Tanh)

        rec_ode_func = ODEFunc(input_dim=enc_input_dim,
                               latent_dim=n_rec_dims,
                               ode_func_net=ode_func_net,
                               device=device).to(device)

        z0_diffeq_solver = DiffeqSolver(enc_input_dim,
                                        rec_ode_func,
                                        "euler",
                                        latents,
                                        odeint_rtol=1e-3,
                                        odeint_atol=1e-4,
                                        device=device)

        encoder_z0 = Encoder_z0_ODE_RNN(n_rec_dims,
                                        enc_input_dim,
                                        z0_diffeq_solver,
                                        z0_dim=z0_dim,
                                        n_gru_units=args.gru_units,
                                        device=device).to(device)
    else:
        # Only "rnn" can reach this branch: the value was validated above.
        encoder_z0 = Encoder_z0_RNN(z0_dim,
                                    enc_input_dim,
                                    lstm_output_size=n_rec_dims,
                                    device=device).to(device)

    # --- Decoder, generative solver and the full model ---------------
    decoder = Decoder(latents, gen_data_dim).to(device)

    diffeq_solver = DiffeqSolver(gen_data_dim,
                                 gen_ode_func,
                                 'dopri5',
                                 latents,
                                 odeint_rtol=1e-3,
                                 odeint_atol=1e-4,
                                 device=device)

    model = LatentODE(
        input_dim=gen_data_dim,
        latent_dim=latents,
        encoder_z0=encoder_z0,
        decoder=decoder,
        diffeq_solver=diffeq_solver,
        z0_prior=z0_prior,
        device=device,
        obsrv_std=obsrv_std,
        use_poisson_proc=args.poisson,
        use_binary_classif=args.classif,
        linear_classifier=args.linear_classif,
        classif_per_tp=classif_per_tp,
        n_labels=n_labels,
        train_classif_w_reconstr=(args.dataset == "physionet")).to(device)

    return model
def create_LatentODE_model(input_dim, z0_prior, obsrv_std, device = device, classif_per_tp = False, n_labels = 1, latents=latent_dim, disable_bias=True):
    """Build a ``LatentODE`` model with an ODE-RNN encoder and fixed sizes.

    Args:
        input_dim: Dimensionality of the observed data.
        z0_prior: Prior over the initial latent state, forwarded to the model.
        obsrv_std: Observation noise standard deviation, forwarded to the
            model.
        device: Device every sub-module is moved to (defaults to the
            module-level ``device``).
        classif_per_tp: Forwarded to ``LatentODE``.
        n_labels: Forwarded to ``LatentODE``.
        latents: Size of the latent state (defaults to the module-level
            ``latent_dim``). It is used consistently for the generative ODE
            network, the ODE function, the encoder output, the decoder and
            the solvers.
        disable_bias: When true, every learnable ``bias`` parameter found in
            the model is removed after construction.

    Returns:
        The constructed ``LatentODE`` instance, moved to ``device``.
    """
    # --- Generative ODE function -------------------------------------
    # Input and output sizes must both follow ``latents``; mixing in the
    # global ``latent_dim`` breaks as soon as a caller overrides ``latents``.
    ode_func_net = utils.create_net(latents, latents,
        n_layers = 2, n_units = 100, nonlinear = nn.Tanh)

    gen_ode_func = ODEFunc(
        input_dim = input_dim,
        latent_dim = latents,
        ode_func_net = ode_func_net,
        device = device).to(device)

    # --- Encoder for the initial latent state ------------------------
    n_rec_dims = 50 # rec_dims: default 20
    enc_input_dim = int(input_dim) * 2 # we concatenate the mask
    gen_data_dim = input_dim

    z0_dim = latents

    ode_func_net = utils.create_net(n_rec_dims, n_rec_dims,
        n_layers = 2, n_units = 100, nonlinear = nn.Tanh)

    rec_ode_func = ODEFunc(
        input_dim = enc_input_dim,
        latent_dim = n_rec_dims,
        ode_func_net = ode_func_net,
        device = device).to(device)

    z0_diffeq_solver = DiffeqSolver(enc_input_dim, rec_ode_func, "dopri5", latents,
        odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device)

    encoder_z0 = Encoder_z0_ODE_RNN(n_rec_dims, enc_input_dim, z0_diffeq_solver,
        z0_dim = z0_dim, n_gru_units = 100, device = device).to(device)

    # --- Decoder, generative solver and the full model ---------------
    decoder = Decoder(latents, input_dim=gen_data_dim).to(device)

    # The generative solver uses tighter tolerances than the encoder solver.
    diffeq_solver = DiffeqSolver(gen_data_dim, gen_ode_func, 'dopri5', latents,
        odeint_rtol = 1e-5, odeint_atol = 1e-6, device = device)

    model = LatentODE(
        input_dim = gen_data_dim,
        latent_dim = latents,
        encoder_z0 = encoder_z0,
        decoder = decoder,
        diffeq_solver = diffeq_solver,
        z0_prior = z0_prior,
        device = device,
        obsrv_std = obsrv_std,
        use_poisson_proc = False,
        use_binary_classif = False,
        linear_classifier = False,
        classif_per_tp = classif_per_tp,
        n_labels = n_labels,
        train_classif_w_reconstr = False
        ).to(device)

    if disable_bias:
        for module in model.modules():
            # Only drop real learnable biases. A module may expose a
            # non-parameter ``bias`` attribute (for example a boolean flag);
            # overwriting that with None would not remove any weights.
            if isinstance(getattr(module, 'bias', None), nn.Parameter):
                module.bias = None

    return model
Ejemplo n.º 5
0
		if args.poisson:
			print("Poisson process likelihood not implemented for ODE-RNN: ignoring --poisson")

		if args.extrap:
			raise Exception("Extrapolation for ODE-RNN not implemented")

		ode_func_net = utils.create_net(n_ode_gru_dims, n_ode_gru_dims, 
			n_layers = args.rec_layers, n_units = args.units, nonlinear = nn.Tanh)

		rec_ode_func = ODEFunc(
			input_dim = input_dim, 
			latent_dim = n_ode_gru_dims,
			ode_func_net = ode_func_net,
			device = device).to(device)

		z0_diffeq_solver = DiffeqSolver(input_dim, rec_ode_func, "euler", args.latents, 
			odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device)
	
		model = ODE_RNN(input_dim, n_ode_gru_dims, device = device, 
			z0_diffeq_solver = z0_diffeq_solver, n_gru_units = args.gru_units,
			concat_mask = True, obsrv_std = obsrv_std,
			use_binary_classif = args.classif,
			classif_per_tp = classif_per_tp,
			n_labels = n_labels,
			train_classif_w_reconstr = (args.dataset == "physionet")
			).to(device)
	elif args.latent_ode:
		model = create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device, 
			classif_per_tp = classif_per_tp,
			n_labels = n_labels)
	else:
		raise Exception("Model not specified")