Example no. 1
0
    def __init__(
        self,
        input_dim,
        latent_dim,
        device=torch.device("cpu"),
        z0_diffeq_solver=None,
        n_gru_units=100,
        n_units=100,
        concat_mask=False,
        obsrv_std=0.1,
        use_binary_classif=False,
        classif_per_tp=False,
        n_labels=1,
        train_classif_w_reconstr=False,
    ):

        Baseline.__init__(
            self,
            input_dim,
            latent_dim,
            device=device,
            obsrv_std=obsrv_std,
            use_binary_classif=use_binary_classif,
            classif_per_tp=classif_per_tp,
            n_labels=n_labels,
            train_classif_w_reconstr=train_classif_w_reconstr,
        )

        ode_rnn_encoder_dim = latent_dim

        # ODE-RNN encoder over the observed values concatenated with their mask
        self.ode_gru = Encoder_z0_ODE_RNN(
            latent_dim=ode_rnn_encoder_dim,
            input_dim=input_dim * 2,  # input and the mask
            z0_diffeq_solver=z0_diffeq_solver,
            n_gru_units=n_gru_units,
            device=device,
        ).to(device)

        self.z0_diffeq_solver = z0_diffeq_solver

        # decoder: map the latent state back to observation space
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, n_units),
            nn.Tanh(),
            nn.Linear(n_units, input_dim),
        )

        utils.init_network_weights(self.decoder)
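
The constructor above wires an ODE-RNN encoder (continuous latent dynamics between observations, a GRU update at each observation, values concatenated with their mask) to a small MLP decoder. Below is a minimal, self-contained sketch of that recurrence in plain PyTorch; SimpleODEFunc, ode_rnn_step, and the fixed-step Euler integration are illustrative stand-ins and not the original Encoder_z0_ODE_RNN / z0_diffeq_solver implementation.

# Hypothetical sketch of the ODE-RNN step the constructor above sets up.
import torch
import torch.nn as nn


class SimpleODEFunc(nn.Module):
    """Stand-in for the latent dynamics learned by the diffeq solver."""

    def __init__(self, latent_dim, n_units=100):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, n_units),
            nn.Tanh(),
            nn.Linear(n_units, latent_dim),
        )

    def forward(self, z):
        return self.net(z)


def ode_rnn_step(z, x, mask, dt, ode_func, gru_cell, n_euler_steps=4):
    """Evolve z with the ODE over the gap dt, then update it with the observation."""
    # Continuous evolution between observations (simple Euler integration here;
    # the real model delegates this to an adaptive solver such as dopri5).
    h = dt / n_euler_steps
    for _ in range(n_euler_steps):
        z = z + h * ode_func(z)
    # Discrete update at the observation: the input is (values, mask) concatenated,
    # mirroring input_dim * 2 in the constructor above.
    return gru_cell(torch.cat([x, mask], dim=-1), z)


if __name__ == "__main__":
    input_dim, latent_dim = 3, 10
    ode_func = SimpleODEFunc(latent_dim)
    gru_cell = nn.GRUCell(input_dim * 2, latent_dim)
    decoder = nn.Sequential(
        nn.Linear(latent_dim, 100), nn.Tanh(), nn.Linear(100, input_dim)
    )

    z = torch.zeros(8, latent_dim)   # batch of 8 latent states
    x = torch.randn(8, input_dim)    # observed values
    mask = torch.ones(8, input_dim)  # all features observed
    z = ode_rnn_step(z, x, mask, dt=0.5, ode_func=ode_func, gru_cell=gru_cell)
    x_hat = decoder(z)               # reconstruction, shape (8, 3)
    print(x_hat.shape)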
Example no. 2
0
    def __init__(self,
                 input_dim,
                 latent_dim,
                 device=torch.device("cpu"),
                 z0_diffeq_solver=None,
                 n_gru_units=100,
                 n_units=100,
                 concat_mask=False,
                 obsrv_std=0.1,
                 use_binary_classif=False,
                 classif_per_tp=False,
                 n_labels=1,
                 train_classif_w_reconstr=False,
                 RNNcell='gru_small',
                 stacking=None,
                 linear_classifier=False,
                 ODE_sharing=True,
                 RNN_sharing=False,
                 include_topper=False,
                 linear_topper=False,
                 use_BN=True,
                 resnet=False,
                 ode_type="linear",
                 ode_units=200,
                 rec_layers=1,
                 ode_method="dopri5",
                 stack_order=None,
                 nornnimputation=False,
                 use_pos_encod=False,
                 n_intermediate_tp=2):

        Baseline.__init__(self,
                          input_dim,
                          latent_dim,
                          device=device,
                          obsrv_std=obsrv_std,
                          use_binary_classif=use_binary_classif,
                          classif_per_tp=classif_per_tp,
                          n_labels=n_labels,
                          train_classif_w_reconstr=train_classif_w_reconstr)

        self.include_topper = include_topper
        self.resnet = resnet
        self.use_BN = use_BN
        ode_rnn_encoder_dim = latent_dim

        # when sharing weights, using resnet connections, or when explicitly requested,
        # prepend a "topper" that maps the raw input up to latent_dim
        if ODE_sharing or RNN_sharing or self.resnet or self.include_topper:
            self.include_topper = True
            input_dim_first = latent_dim
        else:
            input_dim_first = input_dim

        if RNNcell == 'lstm':
            # LSTM carries a cell state, so the ODE evolves twice the latent dimension
            ode_latents = int(latent_dim) * 2
        else:
            ode_latents = int(latent_dim)

        # need one Encoder_z0_ODE_RNN per layer
        self.ode_gru = []
        self.z0_diffeq_solver = []
        first_layer = True
        rnn_input = input_dim_first * 2

        if stack_order is None:
            # a list of layer types: ode_rnn, star, gru, gru_small, lstm
            stack_order = ["ode_rnn"] * stacking

        self.stacking = stacking
        if len(stack_order) != stacking:
            # stack_order must contain exactly `stacking` entries
            print(
                "Warning: the specified stack order does not match the number of "
                "stacked layers; using the length of stack_order instead."
            )
            print("Stack-order: ", stack_order)
            print("Stacking argument: ", stacking)
            self.stacking = len(stack_order)

        # build the default ODE solver and RNN cell that are reused when weights are shared
        # ODE solver (note: constructed with ode_type="linear" here, independent of the ode_type argument)
        z0_diffeq_solver = get_diffeq_solver(ode_latents,
                                             ode_units,
                                             rec_layers,
                                             ode_method,
                                             ode_type="linear",
                                             device=device)

        # RNNcell
        if RNNcell == 'gru':
            RNN_update = GRU_unit(latent_dim,
                                  rnn_input,
                                  n_units=n_gru_units,
                                  device=device).to(device)

        elif RNNcell == 'gru_small':
            RNN_update = GRU_standard_unit(latent_dim,
                                           rnn_input,
                                           device=device).to(device)

        elif RNNcell == 'lstm':
            RNN_update = LSTM_unit(latent_dim, rnn_input).to(device)

        elif RNNcell == "star":
            RNN_update = STAR_unit(latent_dim, rnn_input,
                                   n_units=n_gru_units).to(device)

        else:
            raise Exception(
                "Invalid RNN-cell type. Hint: expdecay not available for ODE-RNN"
            )

        # Build one encoder per stacking level and put the layers into the model
        for s in range(self.stacking):

            use_ODE = (stack_order[s] == "ode_rnn")

            if first_layer:
                # input and the mask
                layer_input_dimension = (input_dim_first) * 2
                first_layer = False

            else:
                # otherwise the input is the latent sequence (and mask) of the previous layer
                layer_input_dimension = latent_dim * 2

            # with weight sharing, the same z0 ODE-RNN components are appended for every layer

            if not RNN_sharing:

                if not use_ODE:
                    if use_pos_encod:
                        vertical_rnn_input = layer_input_dimension + 4  # +4 for the 2-dim positional encoding and its mask
                    else:
                        vertical_rnn_input = layer_input_dimension + 2  # +2 for delta t and its mask

                    thisRNNcell = stack_order[s]

                else:
                    vertical_rnn_input = layer_input_dimension
                    thisRNNcell = RNNcell

                if thisRNNcell == 'gru':
                    RNN_update = GRU_unit(latent_dim,
                                          vertical_rnn_input,
                                          n_units=n_gru_units,
                                          device=device).to(device)

                elif thisRNNcell == 'gru_small':
                    RNN_update = GRU_standard_unit(latent_dim,
                                                   vertical_rnn_input,
                                                   device=device).to(device)

                elif thisRNNcell == 'lstm':
                    # two times latent dimension because of the cell state!
                    RNN_update = LSTM_unit(latent_dim * 2,
                                           vertical_rnn_input).to(device)

                elif thisRNNcell == "star":
                    RNN_update = STAR_unit(latent_dim,
                                           vertical_rnn_input,
                                           n_units=n_gru_units).to(device)

                else:
                    raise Exception(
                        "Invalid RNN-cell type. Hint: expdecay not available for ODE-RNN"
                    )

            if not ODE_sharing:

                if RNNcell == 'lstm':
                    ode_latents = int(latent_dim) * 2
                else:
                    ode_latents = int(latent_dim)

                z0_diffeq_solver = get_diffeq_solver(ode_latents,
                                                     ode_units,
                                                     rec_layers,
                                                     ode_method,
                                                     ode_type="linear",
                                                     device=device)

            self.Encoder0 = Encoder_z0_ODE_RNN(
                latent_dim=ode_rnn_encoder_dim,
                input_dim=layer_input_dimension,
                z0_diffeq_solver=z0_diffeq_solver,
                n_gru_units=n_gru_units,
                device=device,
                RNN_update=RNN_update,
                use_BN=use_BN,
                use_ODE=use_ODE,
                nornnimputation=nornnimputation,
                use_pos_encod=use_pos_encod,
                n_intermediate_tp=n_intermediate_tp).to(device)

            self.ode_gru.append(self.Encoder0)

        # construct topper
        if self.include_topper:
            if linear_topper:
                self.topper = nn.Sequential(
                    nn.Linear(input_dim, latent_dim),
                    nn.Tanh(),
                ).to(device)
            else:
                self.topper = nn.Sequential(
                    nn.Linear(input_dim, 100),
                    nn.Tanh(),
                    nn.Linear(100, latent_dim),
                    nn.Tanh(),
                ).to(device)

            utils.init_network_weights(self.topper)

            self.topper_bn = nn.BatchNorm1d(latent_dim)
        """
		self.decoder = nn.Sequential(
			nn.Linear(latent_dim, n_units),
			nn.Tanh(),
			nn.Linear(n_units, input_dim),)

		utils.init_network_weights(self.decoder)
		"""

        z0_dim = latent_dim

        # get the end-of-sequence classifier
        if use_binary_classif:
            if linear_classifier:
                self.classifier = nn.Sequential(
                    nn.Linear(z0_dim, n_labels),
                    nn.Softmax(dim=2),
                )
            else:
                self.classifier = create_classifier(z0_dim, n_labels)
            utils.init_network_weights(self.classifier)

            if self.use_BN:
                self.bn_lasthidden = nn.BatchNorm1d(latent_dim)

        self.device = device
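
The second constructor stacks several encoder layers, optionally lifting the raw input to latent_dim with a "topper" MLP so that weight sharing and residual connections across layers become possible, and feeds each layer's latent sequence to the next. The sketch below is a hypothetical, heavily simplified illustration of that stacking pattern: StackedSketch and the plain nn.GRU layers stand in for the per-layer Encoder_z0_ODE_RNN, and masks, time gaps, and the ODE dynamics are omitted.

# Hypothetical sketch of the topper + stacked-layer pattern built above.
import torch
import torch.nn as nn


class StackedSketch(nn.Module):
    def __init__(self, input_dim, latent_dim, stacking=2, include_topper=True):
        super().__init__()
        # optional topper: lift the raw input to latent_dim so every layer
        # sees the same feature size
        self.topper = (
            nn.Sequential(nn.Linear(input_dim, 100), nn.Tanh(),
                          nn.Linear(100, latent_dim), nn.Tanh())
            if include_topper else None
        )
        first_dim = latent_dim if include_topper else input_dim
        # one recurrent layer per stacking level; layer 0 sees the (topped) input,
        # every later layer sees the previous layer's latent sequence
        self.layers = nn.ModuleList(
            [nn.GRU(first_dim if s == 0 else latent_dim, latent_dim, batch_first=True)
             for s in range(stacking)]
        )

    def forward(self, x):
        h = self.topper(x) if self.topper is not None else x
        for layer in self.layers:
            h, _ = layer(h)   # latent sequence of this layer feeds the next
        return h[:, -1]       # last latent state, e.g. for classification


if __name__ == "__main__":
    model = StackedSketch(input_dim=3, latent_dim=10, stacking=3)
    x = torch.randn(4, 25, 3)     # (batch, time, features)
    print(model(x).shape)         # torch.Size([4, 10])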