def __init__(self,
                 input_dim,
                 latent_dim,
                 device,
                 concat_mask=False,
                 obsrv_std=0.1,
                 use_binary_classif=False,
                 linear_classifier=False,
                 classif_per_tp=False,
                 input_space_decay=False,
                 cell="gru",
                 n_units=100,
                 n_labels=1,
                 train_classif_w_reconstr=False):
        """Classic RNN baseline: an RNN cell encoder plus an MLP decoder.

        input_dim: dimensionality of each observation.
        latent_dim: size of the RNN hidden state.
        concat_mask: if True, the observation mask is concatenated to the
            input, doubling the encoder input dimensionality.
        cell: "gru" for a plain GRU cell (fed delta-t as an extra feature)
            or "expdecay" for a GRU cell with exponential decay.
        input_space_decay: if True, learn per-feature decay weights/biases
            applied in input space.

        Raises:
            Exception: if `cell` is not one of the supported cell names.
        """
        super(Classic_RNN,
              self).__init__(input_dim,
                             latent_dim,
                             device,
                             obsrv_std=obsrv_std,
                             use_binary_classif=use_binary_classif,
                             classif_per_tp=classif_per_tp,
                             linear_classifier=linear_classifier,
                             n_labels=n_labels,
                             train_classif_w_reconstr=train_classif_w_reconstr)

        self.concat_mask = concat_mask

        encoder_dim = int(input_dim)
        if concat_mask:
            # Data and mask are concatenated feature-wise.
            encoder_dim = encoder_dim * 2

        # Two-layer MLP mapping the hidden state back to data space.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, n_units),
            nn.Tanh(),
            nn.Linear(n_units, input_dim),
        )

        utils.init_network_weights(self.decoder)

        if cell == "gru":
            self.rnn_cell = GRUCell(encoder_dim + 1,
                                    latent_dim)  # +1 for delta t
        elif cell == "expdecay":
            self.rnn_cell = GRUCellExpDecay(input_size=encoder_dim,
                                            input_size_for_decay=input_dim,
                                            hidden_size=latent_dim,
                                            device=device)
        else:
            raise Exception("Unknown RNN cell: {}".format(cell))

        if input_space_decay:
            # BUG FIX: `Parameter(torch.Tensor(...)).to(device)` returns a
            # plain (non-leaf) tensor when a device move occurs, so these
            # weights were never registered as module parameters (and hence
            # never trained), and `torch.Tensor(1, n)` is uninitialized
            # memory.  Create the Parameters directly on the target device,
            # zero-initialized.
            self.w_input_decay = Parameter(
                torch.zeros(1, int(input_dim), device=device))
            self.b_input_decay = Parameter(
                torch.zeros(1, int(input_dim), device=device))
        self.input_space_decay = input_space_decay

        # Identity: the last hidden state is used directly as z0.
        self.z0_net = lambda hidden_state: hidden_state
    def __init__(self,
                 latent_dim,
                 input_dim,
                 lstm_output_size=20,
                 use_delta_t=True,
                 device=torch.device("cpu")):
        """GRU encoder over a sequence; its last hidden state parameterizes z0.

        latent_dim: size of the latent variable z0.
        input_dim: dimensionality of each observation.
        lstm_output_size: hidden size of the GRU (name kept for
            backward compatibility).
        use_delta_t: if True, one extra time-delta feature is expected per
            input step, so the GRU input dimensionality grows by one.
        """
        super(Encoder_z0_RNN, self).__init__()

        self.gru_rnn_output_size = lstm_output_size
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.device = device
        self.use_delta_t = use_delta_t

        # Map the GRU hidden state to latent_dim * 2 outputs (presumably the
        # mean and std of z0, given the name — TODO confirm against caller).
        self.hiddens_to_z0 = nn.Sequential(
            nn.Linear(self.gru_rnn_output_size, 50),
            nn.Tanh(),
            nn.Linear(50, latent_dim * 2),
        )

        utils.init_network_weights(self.hiddens_to_z0)

        # NOTE: self.input_dim is deliberately mutated to account for the
        # extra delta-t channel fed to the GRU.  (A dead local copy of
        # input_dim that was never used has been removed here.)
        if use_delta_t:
            self.input_dim += 1
        self.gru_rnn = GRU(self.input_dim, self.gru_rnn_output_size).to(device)
    def __init__(self, latent_dim, input_dim):
        """Map latent states (where the ODE is solved) back to data space."""
        super(Decoder, self).__init__()

        # A single affine layer suffices; no output nonlinearity.
        self.decoder = nn.Sequential(nn.Linear(latent_dim, input_dim))
        utils.init_network_weights(self.decoder)
# Beispiel #4
# 0
	def __init__(self, input_dim, latent_dim, ode_func_net, device = torch.device("cpu")):
		"""Wrap a network as the right-hand side of an ODE.

		input_dim: dimensionality of the input.
		latent_dim: dimensionality used for the ODE; analog of a continuous
			latent state.
		ode_func_net: network whose weights are (re)initialized here and
			stored as the gradient network.
		"""
		super(ODEFunc, self).__init__()

		self.device = device
		self.input_dim = input_dim

		# Keep a handle to the supplied network after initializing it in place.
		self.gradient_net = ode_func_net
		utils.init_network_weights(self.gradient_net)
    def __init__(self,
                 input_size,
                 input_size_for_decay,
                 hidden_size,
                 device,
                 bias=True):
        """GRU cell augmented with a learned decay over part of the input.

        input_size_for_decay: number of input features fed to the small
            decay network (which outputs a single value).
        """
        # Standard GRU gate machinery: 3 chunks (reset / update / new).
        super(GRUCellExpDecay, self).__init__(input_size=input_size,
                                              hidden_size=hidden_size,
                                              bias=bias,
                                              num_chunks=3)

        self.input_size_for_decay = input_size_for_decay
        self.device = device
        # Single linear layer mapping the decay inputs to one scalar.
        self.decay = nn.Sequential(nn.Linear(input_size_for_decay, 1))
        utils.init_network_weights(self.decay)
# Beispiel #6
# 0
    def __init__(self,
                 input_dim,
                 latent_dim,
                 device=torch.device("cpu"),
                 z0_diffeq_solver=None,
                 n_gru_units=100,
                 n_units=100,
                 concat_mask=False,
                 obsrv_std=0.1,
                 use_binary_classif=False,
                 classif_per_tp=False,
                 n_labels=1,
                 train_classif_w_reconstr=False):
        """ODE-RNN baseline: ODE-RNN encoder followed by an MLP decoder."""

        Baseline.__init__(self,
                          input_dim,
                          latent_dim,
                          device=device,
                          obsrv_std=obsrv_std,
                          use_binary_classif=use_binary_classif,
                          classif_per_tp=classif_per_tp,
                          n_labels=n_labels,
                          train_classif_w_reconstr=train_classif_w_reconstr)

        self.z0_diffeq_solver = z0_diffeq_solver

        # The encoder consumes the data concatenated with its observation
        # mask, hence the doubled input dimensionality.
        self.ode_gru = Encoder_z0_ODE_RNN(
            latent_dim=latent_dim,
            input_dim=input_dim * 2,  # input and the mask
            z0_diffeq_solver=z0_diffeq_solver,
            n_gru_units=n_gru_units,
            device=device).to(device)

        # Two-layer MLP mapping latent states back to data space.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, n_units),
            nn.Tanh(),
            nn.Linear(n_units, input_dim),
        )
        utils.init_network_weights(self.decoder)
    def __init__(self,
                 latent_dim,
                 input_dim,
                 update_gate=None,
                 reset_gate=None,
                 new_state_net=None,
                 n_units=100,
                 device=torch.device("cpu")):
        """GRU-style update unit.

        Each gate maps latent_dim * 2 + input_dim features to latent_dim
        outputs (the new-state net to latent_dim * 2).  Caller-supplied gate
        networks are used as-is; gates built here are weight-initialized.
        """
        super(GRU_unit, self).__init__()

        def build_gate(out_dim, final_activation=None):
            # Shared two-layer MLP shape for all three gates.
            layers = [
                nn.Linear(latent_dim * 2 + input_dim, n_units),
                nn.Tanh(),
                nn.Linear(n_units, out_dim),
            ]
            if final_activation is not None:
                layers.append(final_activation)
            gate = nn.Sequential(*layers)
            utils.init_network_weights(gate)
            return gate

        self.update_gate = (update_gate if update_gate is not None
                            else build_gate(latent_dim, nn.Sigmoid()))
        self.reset_gate = (reset_gate if reset_gate is not None
                           else build_gate(latent_dim, nn.Sigmoid()))
        self.new_state_net = (new_state_net if new_state_net is not None
                              else build_gate(latent_dim * 2))
        # NOTE(review): `device` is accepted but never used or stored here —
        # confirm whether callers rely on it.
# Beispiel #8
# 0
    def __init__(self,
                 input_dim,
                 latent_dim,
                 z0_prior,
                 device,
                 obsrv_std=0.01,
                 use_binary_classif=False,
                 classif_per_tp=False,
                 use_poisson_proc=False,
                 linear_classifier=False,
                 n_labels=1,
                 train_classif_w_reconstr=False):
        """Shared state for VAE-style models: prior, obs noise, classifier."""
        super(VAE_Baseline, self).__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.device = device
        self.n_labels = n_labels
        self.z0_prior = z0_prior

        # Observation noise std, kept as a one-element tensor on the device.
        self.obsrv_std = torch.Tensor([obsrv_std]).to(device)

        self.use_binary_classif = use_binary_classif
        self.classif_per_tp = classif_per_tp
        self.use_poisson_proc = use_poisson_proc
        self.linear_classifier = linear_classifier
        self.train_classif_w_reconstr = train_classif_w_reconstr

        # Using a Poisson process doubles the z0 dimensionality.
        z0_dim = latent_dim * 2 if use_poisson_proc else latent_dim

        if use_binary_classif:
            self.classifier = (
                nn.Sequential(nn.Linear(z0_dim, n_labels))
                if linear_classifier
                else create_classifier(z0_dim, n_labels))
            utils.init_network_weights(self.classifier)
    def __init__(self,
                 latent_dim,
                 input_dim,
                 z0_diffeq_solver=None,
                 z0_dim=None,
                 GRU_update=None,
                 n_gru_units=100,
                 device=torch.device("cpu")):
        """ODE-RNN encoder producing the parameters of z0.

        GRU_update: optional custom update unit; a default GRU_unit is built
            otherwise.  z0_dim defaults to latent_dim when not given.
        """
        super(Encoder_z0_ODE_RNN, self).__init__()

        self.z0_dim = latent_dim if z0_dim is None else z0_dim

        # Build the default GRU update unless the caller injected one.
        if GRU_update is None:
            GRU_update = GRU_unit(latent_dim,
                                  input_dim,
                                  n_units=n_gru_units,
                                  device=device).to(device)
        self.GRU_update = GRU_update

        self.z0_diffeq_solver = z0_diffeq_solver
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.device = device
        self.extra_info = None

        # Maps latent_dim * 2 hidden features to z0_dim * 2 outputs
        # (presumably mean and std of z0 — TODO confirm against caller).
        self.transform_z0 = nn.Sequential(
            nn.Linear(latent_dim * 2, 100),
            nn.Tanh(),
            nn.Linear(100, self.z0_dim * 2),
        )
        utils.init_network_weights(self.transform_z0)