def __init__(self, latent_dim, ode_func_layers, ode_func_units, input_dim, decoder_units):
    """Build the ODE-RNN components: latent ODE solver, output decoder and GRU update cell.

    Args:
        latent_dim: Size of the latent/hidden state evolved by the ODE.
        ode_func_layers: Number of layers of the MLP (built by utils.create_net)
            that defines the latent dynamics.
        ode_func_units: Hidden units per layer of that dynamics MLP.
        input_dim: Number of observed features.
        decoder_units: Hidden width of the decoder; also passed to GRU_Unit as n_units.
    """
    super(ODE_RNN, self).__init__()

    # Latent dynamics: latent_dim -> latent_dim MLP with Tanh activations,
    # initialised before being wrapped in the ODE function.
    dynamics_net = utils.create_net(
        latent_dim,
        latent_dim,
        n_layers=ode_func_layers,
        n_units=ode_func_units,
        nonlinear=nn.Tanh,
    )
    utils.init_network_weights(dynamics_net)
    # Fixed-step "euler" integration; rtol/atol are forwarded to the solver as given.
    self.ode_solver = DiffeqSolver(
        ODEFunc(ode_func_net=dynamics_net),
        "euler",
        odeint_rtol=1e-3,
        odeint_atol=1e-4,
    )

    # Decoder maps a latent state to 2 * input_dim values.
    # NOTE(review): presumably a mean and a raw scale (passed through sigma_fn) per
    # feature — confirm against forward().
    decoder_layers = [
        nn.Linear(latent_dim, decoder_units),
        nn.Tanh(),
        nn.Linear(decoder_units, input_dim * 2),
    ]
    self.decoder = nn.Sequential(*decoder_layers)
    utils.init_network_weights(self.decoder)

    # GRU cell that updates the latent state from incoming observations.
    self.gru_unit = GRU_Unit(latent_dim, input_dim, n_units=decoder_units)
    self.latent_dim = latent_dim
    # Softplus keeps whatever it is applied to strictly positive.
    self.sigma_fn = nn.Softplus()
def get_odernn_encoder(input_dim, latent_dim, odernn_hypers, device):
    """Construct an ODE-RNN z0 encoder driven by a 'dopri5' latent ODE solver.

    The encoder consumes features of width ``input_dim * 2``.
    NOTE(review): presumably observations concatenated with their mask — confirm
    at the call site.

    Args:
        input_dim: Number of observed features (before doubling).
        latent_dim: Size of the latent state; also used as the z0 dimension.
        odernn_hypers: Mapping providing 'odefunc_n_layers', 'odefunc_n_units'
            and 'encoder_gru_units'.
        device: Device the ODE function and encoder are moved to.

    Returns:
        An Encoder_z0_ODE_RNN moved to ``device``.
    """
    doubled_dim = input_dim * 2

    dynamics_net = create_net(
        latent_dim,
        latent_dim,
        n_layers=odernn_hypers['odefunc_n_layers'],
        n_units=odernn_hypers['odefunc_n_units'],
    )
    ode_func = ODEFunc(
        input_dim=doubled_dim,
        latent_dim=latent_dim,
        ode_func_net=dynamics_net,
        device=device,
    ).to(device)
    solver = DiffeqSolver(
        doubled_dim,
        ode_func,
        'dopri5',
        latent_dim,
        odeint_rtol=1e-3,
        odeint_atol=1e-4,
        device=device,
    )

    return Encoder_z0_ODE_RNN(
        latent_dim,
        doubled_dim,
        solver,
        z0_dim=latent_dim,
        n_gru_units=odernn_hypers['encoder_gru_units'],
        device=device,
    ).to(device)
def get_diffeq_solver(ode_latents, ode_units, rec_layers, ode_method, ode_type="linear",
                      device=torch.device("cpu")):
    """Build a DiffeqSolver for autonomous latent dynamics (ODE function with input_dim=0).

    Args:
        ode_latents: Size of the latent state evolved by the ODE.
        ode_units: Hidden units per MLP layer (converted with ``int``). Only used
            when ``ode_type == "linear"``.
        rec_layers: Number of MLP layers (converted with ``int``). Only used when
            ``ode_type == "linear"``.
        ode_method: Integration method name forwarded to DiffeqSolver.
        ode_type: ``"linear"`` selects an MLP from utils.create_net with Tanh
            activations (despite the name it is not a linear map); ``"gru"``
            selects FullGRUODECell_Autonomous.
        device: Device the ODE function is moved to and passed to DiffeqSolver.

    Returns:
        A DiffeqSolver wrapping the ODE function, with rtol=1e-3 and atol=1e-4.

    Raises:
        ValueError: If ``ode_type`` is neither ``"linear"`` nor ``"gru"``.
            ValueError subclasses Exception, so callers catching Exception still work.
    """
    if ode_type == "linear":
        ode_func_net = utils.create_net(ode_latents, ode_latents, n_layers=int(rec_layers),
                                        n_units=int(ode_units), nonlinear=nn.Tanh)
    elif ode_type == "gru":
        # ode_units / rec_layers are intentionally unused for the GRU-style dynamics.
        ode_func_net = FullGRUODECell_Autonomous(ode_latents, bias=True)
    else:
        raise ValueError("Invalid ODE-type. Choose linear or gru.")

    rec_ode_func = ODEFunc(input_dim=0, latent_dim=ode_latents, ode_func_net=ode_func_net,
                           device=device).to(device)
    z0_diffeq_solver = DiffeqSolver(0, rec_ode_func, ode_method, ode_latents,
                                    odeint_rtol=1e-3, odeint_atol=1e-4, device=device)
    return z0_diffeq_solver
def create_ode_rnn_encoder(args, device):
    """Create an ODE-RNN model to be used as an encoder.

    Args:
        args: Parsed arguments (from parse_arguments in ctfp_tools); this reads
            input_size, rec_size, rec_layers, units, latent_size and gru_units.
        device: Device (CPU or GPU) the model is built on.

    Returns:
        An Encoder_z0_ODE_RNN moved to ``device``.
    """
    # The input is concatenated with its mask, doubling the feature width.
    enc_input_dim = args.input_size * 2
    rec_size = args.rec_size

    dynamics_net = utils.create_net(
        rec_size,
        rec_size,
        n_layers=args.rec_layers,
        n_units=args.units,
        nonlinear=nn.Tanh,
    )
    recognition_func = ODEFunc(
        input_dim=enc_input_dim,
        latent_dim=rec_size,
        ode_func_net=dynamics_net,
        device=device,
    ).to(device)
    solver = DiffeqSolver(
        enc_input_dim,
        recognition_func,
        "euler",
        args.latent_size,
        odeint_rtol=1e-3,
        odeint_atol=1e-4,
        device=device,
    )

    return Encoder_z0_ODE_RNN(
        rec_size,
        enc_input_dim,
        solver,
        z0_dim=args.latent_size,
        n_gru_units=args.gru_units,
        device=device,
    ).to(device)
def create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device, classif_per_tp=False, n_labels=1):
    """Assemble a LatentODE model: generative ODE, z0 encoder, decoder and solver.

    Args:
        args: Parsed arguments; this reads latents, poisson, units, gen_layers,
            rec_dims, z0_encoder ("odernn", "rnn" or "trans"), rec_layers,
            gru_units, classif, linear_classif and dataset.
        input_dim: Number of observed features.
        z0_prior: Prior over the initial latent state, forwarded to LatentODE.
        obsrv_std: Observation noise, forwarded to LatentODE.
        device: Device all sub-modules are moved to.
        classif_per_tp: Forwarded to LatentODE.
        n_labels: Forwarded to LatentODE.

    Returns:
        A LatentODE moved to ``device``.

    Raises:
        Exception: If ``args.z0_encoder`` is not one of the supported encoders.
    """
    n_latents = args.latents

    # ---- Generative dynamics ----
    if args.poisson:
        lambda_net = utils.create_net(n_latents, input_dim, n_layers=1, n_units=args.units,
                                      nonlinear=nn.Tanh)
        # ODE function produces the gradient for latent state and for poisson rate
        gen_net = utils.create_net(n_latents * 2, n_latents * 2, n_layers=args.gen_layers,
                                   n_units=args.units, nonlinear=nn.Tanh)
        gen_ode_func = ODEFunc_w_Poisson(input_dim=input_dim, latent_dim=n_latents * 2,
                                         ode_func_net=gen_net, lambda_net=lambda_net,
                                         device=device).to(device)
    else:
        gen_net = utils.create_net(n_latents, n_latents, n_layers=args.gen_layers,
                                   n_units=args.units, nonlinear=nn.Tanh)
        gen_ode_func = ODEFunc(input_dim=input_dim, latent_dim=n_latents,
                               ode_func_net=gen_net, device=device).to(device)

    # ---- Recognition (z0) encoder ----
    n_rec_dims = args.rec_dims
    enc_input_dim = int(input_dim) * 2  # we concatenate the mask
    gen_data_dim = input_dim
    # With the Poisson process the encoder also predicts the initial poisson rate.
    z0_dim = n_latents * 2 if args.poisson else n_latents

    encoder_kind = args.z0_encoder
    if encoder_kind == "odernn":
        rec_net = utils.create_net(n_rec_dims, n_rec_dims, n_layers=args.rec_layers,
                                   n_units=args.units, nonlinear=nn.Tanh)
        rec_ode_func = ODEFunc(input_dim=enc_input_dim, latent_dim=n_rec_dims,
                               ode_func_net=rec_net, device=device).to(device)
        z0_diffeq_solver = DiffeqSolver(enc_input_dim, rec_ode_func, "euler", n_latents,
                                        odeint_rtol=1e-3, odeint_atol=1e-4, device=device)
        encoder_z0 = Encoder_z0_ODE_RNN(n_rec_dims, enc_input_dim, z0_diffeq_solver,
                                        z0_dim=z0_dim, n_gru_units=args.gru_units,
                                        device=device).to(device)
    elif encoder_kind == "rnn":
        encoder_z0 = Encoder_z0_RNN(z0_dim, enc_input_dim, lstm_output_size=n_rec_dims,
                                    device=device).to(device)
    elif encoder_kind == "trans":
        # Local addition: transformer-based encoder.
        # NOTE(review): uses n_latents * 2 unconditionally rather than z0_dim, unlike the
        # other encoders — confirm this is intended when args.poisson is set.
        encoder_z0 = TransformerLayer(n_latents * 2, input_dim, nhidden=20).to(device)
    else:
        raise Exception("Unknown encoder for Latent ODE model: " + args.z0_encoder)

    # ---- Decoder, generative solver and full model ----
    decoder = Decoder(n_latents, gen_data_dim).to(device)
    diffeq_solver = DiffeqSolver(gen_data_dim, gen_ode_func, 'dopri5', n_latents,
                                 odeint_rtol=1e-3, odeint_atol=1e-4, device=device)

    model = LatentODE(
        input_dim=gen_data_dim,
        latent_dim=n_latents,
        encoder_z0=encoder_z0,
        decoder=decoder,
        diffeq_solver=diffeq_solver,
        z0_prior=z0_prior,
        device=device,
        obsrv_std=obsrv_std,
        use_poisson_proc=args.poisson,
        use_binary_classif=args.classif,
        linear_classifier=args.linear_classif,
        classif_per_tp=classif_per_tp,
        n_labels=n_labels,
        # Classifier is trained jointly with reconstruction only on physionet.
        train_classif_w_reconstr=(args.dataset == "physionet"),
    ).to(device)
    return model
if args.poisson: print( "Poisson process likelihood not implemented for ODE-RNN: ignoring --poisson" ) if args.extrap: raise Exception("Extrapolation for ODE-RNN not implemented") ode_func_net = utils.create_net(n_ode_gru_dims, n_ode_gru_dims, n_layers=args.rec_layers, n_units=args.units, nonlinear=nn.Tanh) rec_ode_func = ODEFunc(input_dim=input_dim, latent_dim=n_ode_gru_dims, ode_func_net=ode_func_net, device=device).to(device) z0_diffeq_solver = DiffeqSolver(input_dim, rec_ode_func, "euler", args.latents, odeint_rtol=1e-3, odeint_atol=1e-4, device=device) model = ODE_RNN( input_dim, n_ode_gru_dims, device=device, z0_diffeq_solver=z0_diffeq_solver,