def __init__(self, latent_dim, ode_func_layers, ode_func_units, input_dim, decoder_units):
    """Build the ODE-RNN sub-modules.

    Args:
        latent_dim: size of the latent/hidden state.
        ode_func_layers: number of layers in the latent-dynamics MLP.
        ode_func_units: hidden units per layer of the dynamics MLP.
        input_dim: dimensionality of the observations.
        decoder_units: hidden width of the decoder (also used for the GRU unit).
    """
    super(ODE_RNN, self).__init__()
    self.latent_dim = latent_dim

    # Latent dynamics: an MLP drives the ODE, integrated with fixed-step Euler.
    dynamics_net = utils.create_net(latent_dim, latent_dim, n_layers=ode_func_layers, n_units=ode_func_units, nonlinear=nn.Tanh)
    utils.init_network_weights(dynamics_net)
    self.ode_solver = DiffeqSolver(ODEFunc(ode_func_net=dynamics_net), "euler", odeint_rtol=1e-3, odeint_atol=1e-4)

    # Decoder emits 2 * input_dim values: a mean and a raw (pre-Softplus) std per dim.
    self.decoder = nn.Sequential(
        nn.Linear(latent_dim, decoder_units),
        nn.Tanh(),
        nn.Linear(decoder_units, input_dim * 2),
    )
    utils.init_network_weights(self.decoder)

    self.gru_unit = GRU_Unit(latent_dim, input_dim, n_units=decoder_units)
    self.sigma_fn = nn.Softplus()
def get_odernn_encoder(input_dim, latent_dim, odernn_hypers, device):
    """Construct an ODE-RNN recognition network that produces z0.

    Args:
        input_dim: observation dimensionality. The encoder consumes twice this
            width — presumably observations concatenated with a mask; verify
            against the caller.
        latent_dim: latent state size (also used as z0_dim).
        odernn_hypers: dict with keys 'odefunc_n_layers', 'odefunc_n_units',
            'encoder_gru_units'.
        device: torch device for all sub-modules.

    Returns:
        An Encoder_z0_ODE_RNN moved to `device`.
    """
    masked_dim = input_dim * 2
    dynamics_net = create_net(latent_dim, latent_dim, n_layers=odernn_hypers['odefunc_n_layers'], n_units=odernn_hypers['odefunc_n_units'])
    dynamics = ODEFunc(input_dim=masked_dim, latent_dim=latent_dim, ode_func_net=dynamics_net, device=device).to(device)
    solver = DiffeqSolver(masked_dim, dynamics, 'dopri5', latent_dim, odeint_rtol=1e-3, odeint_atol=1e-4, device=device)
    return Encoder_z0_ODE_RNN(
        latent_dim, masked_dim, solver,
        z0_dim=latent_dim,
        n_gru_units=odernn_hypers['encoder_gru_units'],
        device=device,
    ).to(device)
def get_diffeq_solver(ode_latents, ode_units, rec_layers, ode_method, ode_type="linear", device=torch.device("cpu")):
    """Build a DiffeqSolver for an autonomous latent ODE.

    Args:
        ode_latents: latent state dimensionality.
        ode_units: hidden units of the dynamics MLP (used when ode_type == "linear").
        rec_layers: number of layers of the dynamics MLP (used when ode_type == "linear").
        ode_method: integration method name passed to DiffeqSolver (e.g. "euler").
        ode_type: "linear" for an MLP dynamics net, "gru" for a GRU-ODE cell.
        device: torch device the ODE function is moved to.

    Returns:
        A DiffeqSolver wrapping the chosen autonomous dynamics (input_dim=0).

    Raises:
        ValueError: if ode_type is neither "linear" nor "gru".
    """
    if ode_type == "linear":
        ode_func_net = utils.create_net(ode_latents, ode_latents, n_layers=int(rec_layers), n_units=int(ode_units), nonlinear=nn.Tanh)
    elif ode_type == "gru":
        ode_func_net = FullGRUODECell_Autonomous(ode_latents, bias=True)
    else:
        # ValueError is the idiomatic exception for a bad argument value; as a
        # subclass of Exception it stays compatible with existing broad handlers.
        raise ValueError("Invalid ODE-type. Choose linear or gru.")
    rec_ode_func = ODEFunc(input_dim=0, latent_dim=ode_latents, ode_func_net=ode_func_net, device=device).to(device)
    z0_diffeq_solver = DiffeqSolver(0, rec_ode_func, ode_method, ode_latents, odeint_rtol=1e-3, odeint_atol=1e-4, device=device)
    return z0_diffeq_solver
def create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device):
    """Assemble a LatentGraphODE model: GNN encoder, graph-ODE generator, decoder.

    Args:
        args: experiment namespace; fields read: latents, rec_dims, ode_dims,
            n_heads, rec_layers, gen_layers, dropout, z0_encoder, rec_attention,
            augment_dim, odenet, solver.
        input_dim: per-node observation dimensionality.
        z0_prior: prior over the initial latent state.
        obsrv_std: observation noise std for the likelihood.
        device: torch device the sub-modules are moved to.

    Returns:
        A LatentGraphODE instance on `device`.
    """
    z_dim = args.latents  # ode output dimension
    # Optionally augment the ODE state with extra dimensions (ANODE-style).
    ode_state_dim = z_dim + args.augment_dim if args.augment_dim > 0 else z_dim

    # Encoder: GNN mapping observations to per-ball latents, [b, n_ball, e].
    encoder_z0 = GNN(in_dim=input_dim, n_hid=args.rec_dims, out_dim=z_dim,
                     n_heads=args.n_heads, n_layers=args.rec_layers,
                     dropout=args.dropout, conv_name=args.z0_encoder,
                     aggregate=args.rec_attention)

    # Generative dynamics: a GNN defines the graph ODE vector field.
    dynamics_net = GNN(in_dim=ode_state_dim, n_hid=args.ode_dims, out_dim=ode_state_dim,
                       n_heads=args.n_heads, n_layers=args.gen_layers,
                       dropout=args.dropout, conv_name=args.odenet,
                       aggregate="add")
    gen_ode_func = GraphODEFunc(ode_func_net=dynamics_net, device=device).to(device)
    diffeq_solver = DiffeqSolver(gen_ode_func, args.solver, args=args,
                                 odeint_rtol=1e-2, odeint_atol=1e-2, device=device)

    decoder = Decoder(z_dim, input_dim).to(device)

    return LatentGraphODE(
        input_dim=input_dim,
        latent_dim=args.latents,
        encoder_z0=encoder_z0,
        decoder=decoder,
        diffeq_solver=diffeq_solver,
        z0_prior=z0_prior,
        device=device,
        obsrv_std=obsrv_std,
    ).to(device)
def create_ode_rnn_encoder(args, device):
    """Create an ODE-RNN model used as an encoder.

    Args:
        args: parsed arguments (from parse_arguments in ctfp_tools); fields read:
            input_size, rec_size, rec_layers, units, latent_size, gru_units.
        device: cpu or gpu device to run the model on.

    Returns:
        An Encoder_z0_ODE_RNN moved to `device`.
    """
    # The encoder input concatenates the mask with the observations.
    masked_input_dim = args.input_size * 2

    dynamics_net = utils.create_net(
        args.rec_size,
        args.rec_size,
        n_layers=args.rec_layers,
        n_units=args.units,
        nonlinear=nn.Tanh,
    )
    dynamics = ODEFunc(
        input_dim=masked_input_dim,
        latent_dim=args.rec_size,
        ode_func_net=dynamics_net,
        device=device,
    ).to(device)
    solver = DiffeqSolver(
        masked_input_dim,
        dynamics,
        "euler",
        args.latent_size,
        odeint_rtol=1e-3,
        odeint_atol=1e-4,
        device=device,
    )
    return Encoder_z0_ODE_RNN(
        args.rec_size,
        masked_input_dim,
        solver,
        z0_dim=args.latent_size,
        n_gru_units=args.gru_units,
        device=device,
    ).to(device)
def create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device, classif_per_tp=False, n_labels=1):
    """Assemble a LatentODE model: recognition encoder, generative ODE, decoder.

    Args:
        args: experiment namespace; fields read: latents, poisson, units,
            gen_layers, rec_dims, rec_layers, gru_units, z0_encoder, classif,
            linear_classif, dataset.
        input_dim: observation dimensionality.
        z0_prior: prior distribution over the initial latent state.
        obsrv_std: observation noise std used by the likelihood.
        device: torch device the sub-modules are moved to.
        classif_per_tp: classify at every time point instead of per sequence.
        n_labels: number of classification labels.

    Returns:
        A LatentODE instance on `device`.

    Raises:
        ValueError: if args.z0_encoder is not "odernn", "rnn", or "trans".
    """
    dim = args.latents
    if args.poisson:
        lambda_net = utils.create_net(dim, input_dim, n_layers=1, n_units=args.units, nonlinear=nn.Tanh)
        # ODE function produces the gradient for latent state and for poisson rate,
        # so the generative latent dimension is doubled.
        ode_func_net = utils.create_net(dim * 2, args.latents * 2, n_layers=args.gen_layers, n_units=args.units, nonlinear=nn.Tanh)
        gen_ode_func = ODEFunc_w_Poisson(input_dim=input_dim, latent_dim=args.latents * 2, ode_func_net=ode_func_net, lambda_net=lambda_net, device=device).to(device)
    else:
        ode_func_net = utils.create_net(dim, args.latents, n_layers=args.gen_layers, n_units=args.units, nonlinear=nn.Tanh)
        gen_ode_func = ODEFunc(input_dim=input_dim, latent_dim=args.latents, ode_func_net=ode_func_net, device=device).to(device)

    z0_diffeq_solver = None
    n_rec_dims = args.rec_dims
    enc_input_dim = int(input_dim) * 2  # we concatenate the mask
    gen_data_dim = input_dim

    z0_dim = args.latents
    if args.poisson:
        z0_dim += args.latents  # predict the initial poisson rate as well

    if args.z0_encoder == "odernn":
        ode_func_net = utils.create_net(n_rec_dims, n_rec_dims, n_layers=args.rec_layers, n_units=args.units, nonlinear=nn.Tanh)
        rec_ode_func = ODEFunc(input_dim=enc_input_dim, latent_dim=n_rec_dims, ode_func_net=ode_func_net, device=device).to(device)
        z0_diffeq_solver = DiffeqSolver(enc_input_dim, rec_ode_func, "euler", args.latents, odeint_rtol=1e-3, odeint_atol=1e-4, device=device)
        encoder_z0 = Encoder_z0_ODE_RNN(n_rec_dims, enc_input_dim, z0_diffeq_solver, z0_dim=z0_dim, n_gru_units=args.gru_units, device=device).to(device)
    elif args.z0_encoder == "rnn":
        encoder_z0 = Encoder_z0_RNN(z0_dim, enc_input_dim, lstm_output_size=n_rec_dims, device=device).to(device)
    elif args.z0_encoder == "trans":
        encoder_z0 = TransformerLayer(args.latents * 2, input_dim, nhidden=20).to(device)
    else:
        # ValueError is the idiomatic exception for a bad argument value; as a
        # subclass of Exception it stays compatible with existing broad handlers.
        raise ValueError("Unknown encoder for Latent ODE model: " + args.z0_encoder)

    decoder = Decoder(args.latents, gen_data_dim).to(device)
    diffeq_solver = DiffeqSolver(gen_data_dim, gen_ode_func, 'dopri5', args.latents, odeint_rtol=1e-3, odeint_atol=1e-4, device=device)
    model = LatentODE(
        input_dim=gen_data_dim,
        latent_dim=args.latents,
        encoder_z0=encoder_z0,
        decoder=decoder,
        diffeq_solver=diffeq_solver,
        z0_prior=z0_prior,
        device=device,
        obsrv_std=obsrv_std,
        use_poisson_proc=args.poisson,
        use_binary_classif=args.classif,
        linear_classifier=args.linear_classif,
        classif_per_tp=classif_per_tp,
        n_labels=n_labels,
        # Physionet jointly trains classification with reconstruction.
        train_classif_w_reconstr=(args.dataset == "physionet"),
    ).to(device)
    return model
# NOTE(review): incomplete fragment — the ODE_RNN(...) call below is NOT closed
# in this chunk; the remaining keyword arguments and the closing paren live
# outside the visible source. Code left byte-identical.
# Build the recognition ODE: an MLP vector field integrated with fixed-step Euler.
ode_func_net = utils.create_net(n_ode_gru_dims, n_ode_gru_dims, n_layers=args.rec_layers, n_units=args.units, nonlinear=nn.Tanh)
rec_ode_func = ODEFunc(input_dim=input_dim, latent_dim=n_ode_gru_dims, ode_func_net=ode_func_net, device=device).to(device)
z0_diffeq_solver = DiffeqSolver(input_dim, rec_ode_func, "euler", args.latents, odeint_rtol=1e-3, odeint_atol=1e-4, device=device)
# Standalone ODE-RNN model; concat_mask=True presumably appends the observation
# mask to the inputs — TODO confirm against the ODE_RNN constructor.
model = ODE_RNN(input_dim, n_ode_gru_dims, device=device, z0_diffeq_solver=z0_diffeq_solver, n_gru_units=args.gru_units, concat_mask=True, obsrv_std=obsrv_std, use_binary_classif=args.classif, classif_per_tp=classif_per_tp, n_labels=n_labels,
class ODE_RNN(nn.Module):
    """Class for standalone ODE-RNN model. Makes predictions forward in time."""

    def __init__(self, latent_dim, ode_func_layers, ode_func_units, input_dim, decoder_units):
        # latent_dim: hidden-state size; ode_func_layers / ode_func_units shape
        # the dynamics MLP; decoder_units: width of the decoder and GRU unit.
        super(ODE_RNN, self).__init__()
        ode_func_net = utils.create_net(latent_dim, latent_dim, n_layers=ode_func_layers, n_units=ode_func_units, nonlinear=nn.Tanh)
        utils.init_network_weights(ode_func_net)
        rec_ode_func = ODEFunc(ode_func_net=ode_func_net)
        # Fixed-step Euler integrator evolves the hidden state between observations.
        self.ode_solver = DiffeqSolver(rec_ode_func, "euler", odeint_rtol=1e-3, odeint_atol=1e-4)
        # Decoder emits 2 * input_dim values: per-dim mean and raw (pre-Softplus) std.
        self.decoder = nn.Sequential(nn.Linear(latent_dim, decoder_units), nn.Tanh(), nn.Linear(decoder_units, input_dim * 2))
        utils.init_network_weights(self.decoder)
        self.gru_unit = GRU_Unit(latent_dim, input_dim, n_units=decoder_units)
        self.latent_dim = latent_dim
        self.sigma_fn = nn.Softplus()  # maps raw std outputs to positive values

    def forward(self, data, mask, mask_first, time_steps, extrap_time=float('inf'), use_sampling=False):
        """Run the ODE-RNN over `time_steps` and decode a Gaussian per step.

        Args:
            data: observations, shape (batch, n_time_steps, input_dim).
            mask: per-step mask; mask[:, i] gates the GRU update at step i.
            mask_first: multiplies the hidden state after each update —
                presumably zeroes it before a sequence's first observation;
                TODO confirm against the caller. Indexed at i - 1.
            time_steps: 1-D tensor of (increasing) observation times.
            extrap_time: past this time, predictions are fed back instead of data.
            use_sampling: if True, with probability 0.5 feed the previous
                prediction (instead of the previous observation) into the GRU.

        Returns:
            (outputs, outputs_std): decoded means and Softplus'd stds, each of
            shape (batch, len(time_steps) - 1, input_dim) — one per step after
            the first.
        """
        batch_size, n_time_steps, n_dims = data.size()
        # Hidden state (and its std) start at zero, placed on data's device.
        prev_hidden = torch.zeros((batch_size, self.latent_dim))
        prev_hidden_std = torch.zeros((batch_size, self.latent_dim))
        if data.is_cuda:
            prev_hidden = prev_hidden.to(data.get_device())
            prev_hidden_std = prev_hidden_std.to(data.get_device())
        # Gaps shorter than 1/50 of the full interval get a single Euler step;
        # longer gaps are handed to the solver over several intermediate points.
        interval_length = time_steps[-1] - time_steps[0]
        minimum_step = interval_length / 50
        outputs = []
        outputs_std = []
        prev_observation = data[:, 0]
        if use_sampling:
            prev_output = data[:, 0]
        for i in range(1, len(time_steps)):
            # Make one step.
            if time_steps[i] - time_steps[i - 1] < minimum_step:
                # Manual single Euler step: h + f(t, h) * dt.
                inc = self.ode_solver.ode_func(time_steps[i - 1], prev_hidden)
                ode_sol = prev_hidden + inc * (time_steps[i] - time_steps[i - 1])
                ode_sol = torch.stack((prev_hidden, ode_sol), 1)
            # Several steps.
            else:
                num_intermediate_steps = max(
                    2, ((time_steps[i] - time_steps[i - 1]) / minimum_step).int())
                time_points = torch.linspace(time_steps[i - 1], time_steps[i],
                                             num_intermediate_steps)
                ode_sol = self.ode_solver(prev_hidden.unsqueeze(0), time_points)[0]
            # Hidden state evolved to time_steps[i] (last point of the ODE solution).
            hidden_ode = ode_sol[:, -1]
            x_i = prev_observation
            # Scheduled sampling: sometimes condition on the model's own output.
            if use_sampling and np.random.uniform(
                    0, 1) < 0.5 and time_steps[i] <= extrap_time:
                x_i = prev_output
            mask_i = mask[:, i]
            # GRU update conditioned on the (masked) input at step i.
            output_hidden, hidden, hidden_std = self.gru_unit(
                hidden_ode, prev_hidden_std, x_i, mask_i)
            hidden = mask_first[:, i - 1] * hidden
            hidden_std = mask_first[:, i - 1] * hidden_std
            prev_hidden, prev_hidden_std = hidden, hidden_std
            # Decode a Gaussian: first half of the decoder output is the mean,
            # second half the raw std (made positive via Softplus).
            mean, std = torch.chunk(self.decoder(output_hidden), chunks=2, dim=-1)
            outputs += [mean]
            outputs_std += [self.sigma_fn(std)]
            if use_sampling:
                prev_output = prev_output * (1 - mask_i) + mask_i * outputs[-1]
            # Before extrap_time track observed data; afterwards feed back the
            # model's own predictions (extrapolation mode).
            if time_steps[i] <= extrap_time:
                prev_observation = prev_observation * (
                    1 - mask_i) + mask_i * data[:, i]
            else:
                prev_observation = prev_observation * (
                    1 - mask_i) + mask_i * outputs[-1]
        outputs = torch.stack(outputs, 1)
        outputs_std = torch.stack(outputs_std, 1)
        return outputs, outputs_std

    @property
    def num_params(self):
        """Number of parameters."""
        # NOTE(review): np.sum over 0-d torch tensors returns a tensor, not a
        # plain int.
        return np.sum(
            [torch.tensor(param.shape).prod() for param in self.parameters()])