	def __init__(self, input_dim, latent_dim, device,
		obsrv_std = 0.01, use_binary_classif = False,
		classif_per_tp = False,
		use_poisson_proc = False,
		linear_classifier = False,
		n_labels = 1,
		train_classif_w_reconstr = False):
		super(Baseline, self).__init__()

		self.input_dim = input_dim
		self.latent_dim = latent_dim
		self.n_labels = n_labels

		self.obsrv_std = torch.Tensor([obsrv_std]).to(device)
		self.device = device

		self.use_binary_classif = use_binary_classif
		self.classif_per_tp = classif_per_tp
		self.use_poisson_proc = use_poisson_proc
		self.linear_classifier = linear_classifier
		self.train_classif_w_reconstr = train_classif_w_reconstr

		z0_dim = latent_dim
		if use_poisson_proc:
			z0_dim += latent_dim

		if use_binary_classif: 
			if linear_classifier:
				self.classifier = nn.Sequential(
					nn.Linear(z0_dim, n_labels))
			else:
				self.classifier = create_classifier(z0_dim, n_labels)
			utils.init_network_weights(self.classifier)
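
For context, a hypothetical sketch of the kind of classifier head that create_classifier might build: a small MLP from the latent summary z0 to per-label logits. The hidden width of 300 is an assumption, not the repository's definition.

import torch.nn as nn

def create_classifier_sketch(z0_dim, n_labels, n_hidden=300):
    # Hypothetical stand-in for create_classifier: latent summary -> label logits.
    return nn.Sequential(
        nn.Linear(z0_dim, n_hidden),
        nn.ReLU(),
        nn.Linear(n_hidden, n_labels),
    )
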
Example #2

    def __init__(self,
                 latent_dim,
                 input_dim,
                 lstm_output_size=20,
                 use_delta_t=True,
                 device=torch.device("cpu")):

        super(Encoder_z0_RNN, self).__init__()

        self.gru_rnn_output_size = lstm_output_size
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.device = device
        self.use_delta_t = use_delta_t

        self.hiddens_to_z0 = nn.Sequential(
            nn.Linear(self.gru_rnn_output_size, 50),
            nn.Tanh(),
            nn.Linear(50, latent_dim * 2),
        )

        utils.init_network_weights(self.hiddens_to_z0)

        # When delta t is appended to the observations, the GRU input grows by one channel.
        if use_delta_t:
            self.input_dim += 1
        self.gru_rnn = GRU(self.input_dim, self.gru_rnn_output_size).to(device)
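
A minimal, self-contained sketch of how an encoder of this shape is typically applied: run the GRU over the observations with the delta-t channel appended, then map the last hidden state through the latent_dim * 2 head and split it into a mean and a std. All tensor names and sizes below are illustrative.

import torch
import torch.nn as nn

latent_dim, input_dim, hidden = 10, 3, 20
gru_rnn = nn.GRU(input_dim + 1, hidden, batch_first=True)   # +1 channel for delta t
hiddens_to_z0 = nn.Sequential(nn.Linear(hidden, 50), nn.Tanh(),
                              nn.Linear(50, latent_dim * 2))

x = torch.randn(4, 25, input_dim)                 # [batch, time, features]
delta_t = torch.rand(4, 25, 1)                    # time gaps between observations
_, h_last = gru_rnn(torch.cat([x, delta_t], dim=-1))

mean, std = hiddens_to_z0(h_last.squeeze(0)).chunk(2, dim=-1)
std = std.abs()                                   # keep the std non-negative
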
Example #3

    def __init__(self, latent_dim, ode_func_layers, ode_func_units, input_dim,
                 decoder_units):
        super(ODE_RNN, self).__init__()

        ode_func_net = utils.create_net(latent_dim,
                                        latent_dim,
                                        n_layers=ode_func_layers,
                                        n_units=ode_func_units,
                                        nonlinear=nn.Tanh)

        utils.init_network_weights(ode_func_net)

        rec_ode_func = ODEFunc(ode_func_net=ode_func_net)

        self.ode_solver = DiffeqSolver(rec_ode_func,
                                       "euler",
                                       odeint_rtol=1e-3,
                                       odeint_atol=1e-4)

        self.decoder = nn.Sequential(nn.Linear(latent_dim, decoder_units),
                                     nn.Tanh(),
                                     nn.Linear(decoder_units, input_dim * 2))

        utils.init_network_weights(self.decoder)

        self.gru_unit = GRU_Unit(latent_dim, input_dim, n_units=decoder_units)

        self.latent_dim = latent_dim

        self.sigma_fn = nn.Softplus()
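
The pieces above (ODE solver, GRU unit, decoder) are combined in the usual ODE-RNN recurrence: evolve the hidden state with the ODE between observation times, then apply the RNN update at each observation. A toy, self-contained sketch, with an explicit Euler step and a plain nn.GRUCell standing in for the repository's DiffeqSolver and GRU_Unit.

import torch
import torch.nn as nn

latent_dim, input_dim, batch = 8, 3, 4
f = nn.Sequential(nn.Linear(latent_dim, latent_dim), nn.Tanh())  # stand-in ODE dynamics
gru_unit = nn.GRUCell(input_dim, latent_dim)                     # stand-in RNN update

def euler_evolve(h, dt, n_steps=5):
    # Evolve dh/dt = f(h) over a gap of length dt with explicit Euler steps.
    step = dt / n_steps
    for _ in range(n_steps):
        h = h + step * f(h)
    return h

h = torch.zeros(batch, latent_dim)
times = torch.tensor([0.0, 0.3, 0.7, 1.0])
xs = torch.randn(batch, len(times), input_dim)   # observations at those times

prev_t = times[0]
for i, t in enumerate(times):
    h = euler_evolve(h, (t - prev_t).item())     # ODE step between observations
    h = gru_unit(xs[:, i], h)                    # RNN update at the observation
    prev_t = t
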
Example #4

	def __init__(self, latent_dim, input_dim, z0_diffeq_solver = None, 
		z0_dim = None, GRU_update = None, 
		n_gru_units = 100, 
		device = torch.device("cpu")):
		
		super(Encoder_z0_ODE_RNN, self).__init__()

		if z0_dim is None:
			self.z0_dim = latent_dim
		else:
			self.z0_dim = z0_dim

		if GRU_update is None:
			self.GRU_update = GRU_unit(latent_dim, input_dim, 
				n_units = n_gru_units, 
				device=device).to(device)
		else:
			self.GRU_update = GRU_update

		self.z0_diffeq_solver = z0_diffeq_solver
		self.latent_dim = latent_dim
		self.input_dim = input_dim
		self.device = device
		self.extra_info = None

		self.transform_z0 = nn.Sequential(
			nn.Linear(latent_dim * 2, 100),
			nn.Tanh(),
			nn.Linear(100, self.z0_dim * 2),
		)
		utils.init_network_weights(self.transform_z0)
Example #5

    def __init__(self,
                 input_dim,
                 latent_dim,
                 device,
                 concat_mask=False,
                 obsrv_std=0.1,
                 use_binary_classif=False,
                 linear_classifier=False,
                 classif_per_tp=False,
                 input_space_decay=False,
                 cell="gru",
                 n_units=100,
                 n_labels=1,
                 train_classif_w_reconstr=False):

        super(Classic_RNN,
              self).__init__(input_dim,
                             latent_dim,
                             device,
                             obsrv_std=obsrv_std,
                             use_binary_classif=use_binary_classif,
                             classif_per_tp=classif_per_tp,
                             linear_classifier=linear_classifier,
                             n_labels=n_labels,
                             train_classif_w_reconstr=train_classif_w_reconstr)

        self.concat_mask = concat_mask

        encoder_dim = int(input_dim)
        if concat_mask:
            encoder_dim = encoder_dim * 2

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, n_units),
            nn.Tanh(),
            nn.Linear(n_units, input_dim),
        )

        #utils.init_network_weights(self.encoder)
        utils.init_network_weights(self.decoder)

        if cell == "gru":
            self.rnn_cell = GRUCell(encoder_dim + 1,
                                    latent_dim)  # +1 for delta t
        elif cell == "expdecay":
            self.rnn_cell = GRUCellExpDecay(input_size=encoder_dim,
                                            input_size_for_decay=input_dim,
                                            hidden_size=latent_dim,
                                            device=device)
        else:
            raise Exception("Unknown RNN cell: {}".format(cell))

        if input_space_decay:
            # Use zero-initialised leaf Parameters created directly on the target device;
            # Parameter(...).to(device) can yield a plain tensor that nn.Module does not register.
            self.w_input_decay = Parameter(torch.zeros(1, int(input_dim), device=self.device))
            self.b_input_decay = Parameter(torch.zeros(1, int(input_dim), device=self.device))
        self.input_space_decay = input_space_decay

        self.z0_net = lambda hidden_state: hidden_state
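
One recurrent step of such a model, assuming concat_mask=True: the observation, its mask, and the time gap are concatenated before the GRU cell, matching the GRUCell(encoder_dim + 1, latent_dim) construction above. Tensor names are illustrative.

import torch
import torch.nn as nn

input_dim, latent_dim, batch = 3, 12, 4
rnn_cell = nn.GRUCell(input_dim * 2 + 1, latent_dim)    # data + mask + delta t

x = torch.randn(batch, input_dim)                       # observations at one time point
mask = torch.randint(0, 2, (batch, input_dim)).float()  # 1 where a value was observed
delta_t = torch.full((batch, 1), 0.1)                   # time since the previous observation
h = torch.zeros(batch, latent_dim)

h = rnn_cell(torch.cat([x * mask, mask, delta_t], dim=-1), h)
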
Example #6

    def __init__(self, in_dim, n_hid, out_dim, n_heads, n_layers, dropout=0.2, conv_name='GTrans', aggregate="add"):
        super(GNN, self).__init__()
        self.gcs = nn.ModuleList()
        self.in_dim    = in_dim
        self.n_hid     = n_hid
        self.drop = nn.Dropout(dropout)
        self.adapt_ws = nn.Linear(in_dim,n_hid)
        self.sequence_w = nn.Linear(n_hid,n_hid) # for encoder
        self.out_w_ode = nn.Linear(n_hid,out_dim)
        self.out_w_encoder = nn.Linear(n_hid,out_dim*2)

        #initialization
        utils.init_network_weights(self.adapt_ws)
        utils.init_network_weights(self.sequence_w)
        utils.init_network_weights(self.out_w_ode)
        utils.init_network_weights(self.out_w_encoder)

        # Normalization
        self.layer_norm = nn.LayerNorm(n_hid)
        self.aggregate = aggregate
        for l in range(n_layers):
            self.gcs.append(GeneralConv(conv_name, n_hid, n_hid,  n_heads, dropout))

        if conv_name == 'GTrans':
            self.temporal_net = TemporalEncoding(n_hid)
            #self.w_transfer = nn.Linear(self.n_hid * 2, self.n_hid, bias=True)
            self.w_transfer = nn.Linear(self.n_hid + 1, self.n_hid, bias=True)
            utils.init_network_weights(self.w_transfer)
Example #7

    def __init__(self, latent_dim, input_dim):
        super(Decoder, self).__init__()
        # decode data from the latent space (where the ODE is solved) back to the data space

        decoder = nn.Sequential(nn.Linear(latent_dim, input_dim), )

        utils.init_network_weights(decoder)
        self.decoder = decoder
Example #8

    def __init__(self, output_dim, input_dim, hidden_dim, layer_num):
        super(Encoder, self).__init__()
        # encode data from the data space into the latent space where the ODE is solved

        encoder = utils.create_net(input_dim, output_dim * 2, layer_num,
                                   hidden_dim)

        utils.init_network_weights(encoder)
        self.encoder = encoder
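
The encoder above emits output_dim * 2 values per element; a minimal sketch of the usual interpretation, assuming the two halves are the mean and std of the initial latent state and that z0 is drawn with the reparameterization trick. The stand-in MLP below replaces utils.create_net.

import torch
import torch.nn as nn

input_dim, output_dim = 5, 10
encoder = nn.Sequential(nn.Linear(input_dim, 64), nn.Tanh(),
                        nn.Linear(64, output_dim * 2))   # stand-in for utils.create_net

x = torch.randn(8, input_dim)
mean, std = encoder(x).chunk(2, dim=-1)
std = std.abs()
z0 = mean + std * torch.randn_like(std)   # reparameterized sample of the initial state
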
Example #9

    def __init__(self, input_dim, latent_dim, ode_func_net, layer_type = "concat", device = torch.device("cpu")):
        super(ODEFunc_att, self).__init__()

        utils.init_network_weights(ode_func_net)
        self.ode_func_list = ode_func_net
        self.layer_type = layer_type

        self.weight = nn.Linear(latent_dim, latent_dim)
        self.register_buffer("_num_evals", torch.tensor(0.))
        self.query = None
Example #10

    def __init__(self, input_dim, latent_dim, ode_func_net, layer_type = "concat", device = torch.device("cpu")):
        """
        input_dim: dimensionality of the input
        latent_dim: dimensionality used for ODE. Analog of a continuous latent state
        """
        super(ODEFunc, self).__init__()

        self.input_dim = input_dim
        self.device = device
        self.layer_type = layer_type
        utils.init_network_weights(ode_func_net)
        self.gradient_net = ode_func_net
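
gradient_net defines the dynamics dy/dt; a minimal sketch of the corresponding forward pass and an ODE solve, assuming the torchdiffeq package (which this family of models commonly uses). The class name, layer sizes, and solver settings are illustrative.

import torch
import torch.nn as nn
from torchdiffeq import odeint   # assumed dependency

class ODEFuncSketch(nn.Module):
    def __init__(self, gradient_net):
        super().__init__()
        self.gradient_net = gradient_net

    def forward(self, t, y):
        # Time-invariant dynamics: dy/dt depends only on the current state y.
        return self.gradient_net(y)

latent_dim = 6
net = nn.Sequential(nn.Linear(latent_dim, 32), nn.Tanh(), nn.Linear(32, latent_dim))
func = ODEFuncSketch(net)

y0 = torch.randn(4, latent_dim)
t = torch.linspace(0.0, 1.0, 10)
ys = odeint(func, y0, t, rtol=1e-3, atol=1e-4, method="euler")   # [10, 4, latent_dim]
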
Example #11

    def __init__(self,
                 input_size,
                 input_size_for_decay,
                 hidden_size,
                 device,
                 bias=True):
        super(GRUCellExpDecay, self).__init__(input_size,
                                              hidden_size,
                                              bias,
                                              num_chunks=3)

        self.device = device
        self.input_size_for_decay = input_size_for_decay
        self.decay = nn.Sequential(nn.Linear(input_size_for_decay, 1), )
        utils.init_network_weights(self.decay)
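
A hedged illustration of the exponential-decay idea behind this cell (GRU-D style): the decay head maps the per-feature time gaps to a rate, and the hidden state is shrunk by exp(-relu(rate)) before a standard GRU update. The exact formula used by the repository may differ.

import torch
import torch.nn as nn

input_dim, hidden_size, batch = 5, 16, 4
decay = nn.Sequential(nn.Linear(input_dim, 1))
cell = nn.GRUCell(input_dim, hidden_size)

x = torch.randn(batch, input_dim)
delta_ts = torch.rand(batch, input_dim)          # time since each feature was last observed
h = torch.randn(batch, hidden_size)

gamma = torch.exp(-torch.relu(decay(delta_ts)))  # decay factor in (0, 1]
h = cell(x, h * gamma)                           # decayed state, then the usual GRU update
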
Example #12

    def __init__(self, bottleneck_function_channels, ode_hidden_state_channels,
                 learn_length_scale, convcnp):
        super(NeuralCDEEncoder, self).__init__()
        self.bottleneck_function_channels = bottleneck_function_channels
        self.ode_hidden_state_channels = ode_hidden_state_channels
        self.sigma = nn.Parameter(
            np.log(convcnp.init_length_scale) *
            torch.ones(self.bottleneck_function_channels),
            requires_grad=learn_length_scale)
        self.sigma_fn = torch.exp
        self.initial_hidden_state_network = torch.nn.Linear(
            self.bottleneck_function_channels, self.ode_hidden_state_channels)
        utils.init_network_weights(self.initial_hidden_state_network)
        self.current_task = None
        self.convcnp = convcnp
Example #13

    def __init__(
        self,
        input_dim,
        latent_dim,
        device=torch.device("cpu"),
        z0_diffeq_solver=None,
        n_gru_units=100,
        n_units=100,
        concat_mask=False,
        obsrv_std=0.1,
        use_binary_classif=False,
        classif_per_tp=False,
        n_labels=1,
        train_classif_w_reconstr=False,
    ):

        Baseline.__init__(
            self,
            input_dim,
            latent_dim,
            device=device,
            obsrv_std=obsrv_std,
            use_binary_classif=use_binary_classif,
            classif_per_tp=classif_per_tp,
            n_labels=n_labels,
            train_classif_w_reconstr=train_classif_w_reconstr,
        )

        ode_rnn_encoder_dim = latent_dim

        self.ode_gru = Encoder_z0_ODE_RNN(
            latent_dim=ode_rnn_encoder_dim,
            input_dim=(input_dim) * 2,  # input and the mask
            z0_diffeq_solver=z0_diffeq_solver,
            n_gru_units=n_gru_units,
            device=device,
        ).to(device)

        self.z0_diffeq_solver = z0_diffeq_solver

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, n_units),
            nn.Tanh(),
            nn.Linear(n_units, input_dim),
        )

        utils.init_network_weights(self.decoder)
Example #14

    def __init__(self,
                 input_dim,
                 latent_dim,
                 ode_func_net,
                 device=torch.device("cpu")):
        """
		input_dim: dimensionality of the input (stored but not used directly here)
		latent_dim: dimensionality used for ODE. Analog of a continuous latent state
		"""
        super(ODEFunc, self).__init__()

        self.input_dim = input_dim
        self.device = device

        #Here: make the initialization different
        utils.init_network_weights(ode_func_net)
        self.gradient_net = ode_func_net
Example #15

    def __init__(self, input_dim, latent_dim, device=torch.device("cpu")):
        """
        input_dim: dimensionality of the input
        latent_dim: dimensionality used for ODE. Analog of a continuous latent state
        """
        super(CDEFunc, self).__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.device = device
        self.interpolation = None

        # Eq. 3 in "Neural Controlled Differential Equations for Irregular Time Series"
        self.cde_func = utils.create_net(latent_dim,
                                         latent_dim * (input_dim + 1),
                                         n_units=10,
                                         n_layers=1)

        utils.init_network_weights(self.cde_func)
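
The head above outputs latent_dim * (input_dim + 1) values because, in a neural CDE, they are reshaped into a matrix and multiplied by the derivative of the time-augmented control path (Eq. 3 of the cited paper). A self-contained sketch of that reshaping; names are illustrative.

import torch
import torch.nn as nn

latent_dim, input_dim, batch = 6, 3, 4
cde_func = nn.Sequential(nn.Linear(latent_dim, 10), nn.Tanh(),
                         nn.Linear(10, latent_dim * (input_dim + 1)))

z = torch.randn(batch, latent_dim)            # current latent state
dX_dt = torch.randn(batch, input_dim + 1, 1)  # derivative of the time-augmented control path

f_z = cde_func(z).view(batch, latent_dim, input_dim + 1)  # per-sample matrix f(z)
dz_dt = torch.bmm(f_z, dX_dt).squeeze(-1)                 # dz/dt = f(z) dX/dt
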
Example #16

    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 derivative,
                 ode_func_channels):
        super(NeuralCDEDecoder, self).__init__()
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.derivative = derivative
        self.sigma_fn = nn.Softplus()
        self.mean_layer = nn.Linear(self.hidden_channels, out_channels)
        self.sigma_layer = nn.Linear(self.hidden_channels, out_channels)
        utils.init_network_weights(self.mean_layer)
        utils.init_network_weights(self.sigma_layer)
        self.cde_func = CDEFunc(self.in_channels,
                                self.hidden_channels,
                                ode_func_channels)
Example #17

    def __init__(self,
                 latent_dim,
                 input_dim,
                 update_gate=None,
                 reset_gate=None,
                 new_state_net=None,
                 n_units=100,
                 device=torch.device("cpu")):
        super(GRU_unit, self).__init__()

        if update_gate is None:
            self.update_gate = nn.Sequential(
                nn.Linear(latent_dim * 2 + input_dim, n_units), nn.Tanh(),
                nn.Linear(n_units, latent_dim), nn.Sigmoid())
            utils.init_network_weights(self.update_gate)
        else:
            self.update_gate = update_gate

        if reset_gate is None:
            self.reset_gate = nn.Sequential(
                nn.Linear(latent_dim * 2 + input_dim, n_units), nn.Tanh(),
                nn.Linear(n_units, latent_dim), nn.Sigmoid())
            utils.init_network_weights(self.reset_gate)
        else:
            self.reset_gate = reset_gate

        if new_state_net is None:
            self.new_state_net = nn.Sequential(
                nn.Linear(latent_dim * 2 + input_dim, n_units), nn.Tanh(),
                nn.Linear(n_units, latent_dim * 2))
            utils.init_network_weights(self.new_state_net)
        else:
            self.new_state_net = new_state_net
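
The forward pass of this unit is not shown in the listing; the sketch below gives the update it typically implements in ODE-RNN encoders, where the state is a (mean, std) pair and the gates act on the concatenation of the state with the observation. Treat the exact blending as an assumption.

import torch
import torch.nn as nn

latent_dim, input_dim, n_units, batch = 8, 3, 100, 4
concat_dim = latent_dim * 2 + input_dim
update_gate = nn.Sequential(nn.Linear(concat_dim, n_units), nn.Tanh(),
                            nn.Linear(n_units, latent_dim), nn.Sigmoid())
reset_gate = nn.Sequential(nn.Linear(concat_dim, n_units), nn.Tanh(),
                           nn.Linear(n_units, latent_dim), nn.Sigmoid())
new_state_net = nn.Sequential(nn.Linear(concat_dim, n_units), nn.Tanh(),
                              nn.Linear(n_units, latent_dim * 2))

y_mean, y_std = torch.randn(batch, latent_dim), torch.rand(batch, latent_dim)
x = torch.randn(batch, input_dim)

y_concat = torch.cat([y_mean, y_std, x], dim=-1)
u = update_gate(y_concat)     # how much of the old state to keep
r = reset_gate(y_concat)      # how much of the old state feeds the proposal
proposal = new_state_net(torch.cat([y_mean * r, y_std * r, x], dim=-1))
new_mean, new_std = proposal.chunk(2, dim=-1)
y_mean = (1 - u) * new_mean + u * y_mean
y_std = (1 - u) * new_std.abs() + u * y_std
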
Example #18

    def __init__(self, latent_dim, input_dim, n_units=100):
        super(GRU_Unit, self).__init__()

        self.update_gate = nn.Sequential(
            nn.Linear(latent_dim * 2 + input_dim, n_units), nn.Tanh(),
            nn.Linear(n_units, latent_dim), nn.Sigmoid())
        utils.init_network_weights(self.update_gate)

        self.reset_gate = nn.Sequential(
            nn.Linear(latent_dim * 2 + input_dim, n_units), nn.Tanh(),
            nn.Linear(n_units, latent_dim), nn.Sigmoid())
        utils.init_network_weights(self.reset_gate)

        self.new_state_net = nn.Sequential(
            nn.Linear(latent_dim * 2 + input_dim, n_units), nn.Tanh(),
            nn.Linear(n_units, latent_dim * 2))
        utils.init_network_weights(self.new_state_net)
Example #19

    def __init__(self,
                 hidden_size,
                 input_size,
                 n_units=0,
                 bias=True,
                 use_BN=False):
        super(STAR_unit, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        if n_units == 0:
            self.x_K = nn.Linear(input_size, hidden_size, bias=bias)
            self.x_z = nn.Linear(input_size, hidden_size, bias=bias)
            self.h_K = nn.Linear(hidden_size, hidden_size, bias=bias)

            init.orthogonal_(self.x_K.weight)
            init.orthogonal_(self.x_z.weight)
            init.orthogonal_(self.h_K.weight)

            self.x_K.bias.data.fill_(0.)
            self.x_z.bias.data.fill_(0.)
            #self.h_K.bias.data.fill_(0.)
        else:

            self.x_K = nn.Sequential(nn.Linear(input_size, n_units), nn.Tanh(),
                                     nn.Linear(n_units, hidden_size))
            utils.init_network_weights(self.x_K, initype="ortho")

            self.x_z = nn.Sequential(nn.Linear(input_size, n_units), nn.Tanh(),
                                     nn.Linear(n_units, hidden_size))
            utils.init_network_weights(self.x_z, initype="ortho")

            self.h_K = nn.Sequential(nn.Linear(hidden_size,
                                               n_units), nn.Tanh(),
                                     nn.Linear(n_units, hidden_size))
            utils.init_network_weights(self.h_K, initype="ortho")

        self.use_BN = use_BN

        if self.use_BN:
            self.bn_x_K = nn.BatchNorm1d(hidden_size)
            self.bn_x_z = nn.BatchNorm1d(hidden_size)
            self.bn_h_K = nn.BatchNorm1d(hidden_size)
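
For orientation, one common formulation of the STAR recurrence (a single gate K blending the old state with an input-only candidate), shown only to indicate how the x_K, x_z and h_K sub-networks above could be combined; the repository's actual forward pass and its use of the batch norms may differ.

import torch
import torch.nn as nn

input_size, hidden_size, batch = 5, 16, 4
x_K = nn.Linear(input_size, hidden_size)
x_z = nn.Linear(input_size, hidden_size)
h_K = nn.Linear(hidden_size, hidden_size)

x = torch.randn(batch, input_size)
h = torch.randn(batch, hidden_size)

K = torch.sigmoid(x_K(x) + h_K(h))    # single gate
z = torch.tanh(x_z(x))                # candidate state computed from the input only
h = torch.tanh((1 - K) * h + K * z)   # gated blend of old state and candidate
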
Example #20

    def __init__(self,
                 input_dim,
                 latent_dim,
                 device=torch.device("cpu"),
                 z0_diffeq_solver=None,
                 n_gru_units=100,
                 n_units=100,
                 concat_mask=False,
                 obsrv_std=0.1,
                 use_binary_classif=False,
                 classif_per_tp=False,
                 n_labels=1,
                 train_classif_w_reconstr=False,
                 RNNcell='gru_small',
                 stacking=None,
                 linear_classifier=False,
                 ODE_sharing=True,
                 RNN_sharing=False,
                 include_topper=False,
                 linear_topper=False,
                 use_BN=True,
                 resnet=False,
                 ode_type="linear",
                 ode_units=200,
                 rec_layers=1,
                 ode_method="dopri5",
                 stack_order=None,
                 nornnimputation=False,
                 use_pos_encod=False,
                 n_intermediate_tp=2):

        Baseline.__init__(self,
                          input_dim,
                          latent_dim,
                          device=device,
                          obsrv_std=obsrv_std,
                          use_binary_classif=use_binary_classif,
                          classif_per_tp=classif_per_tp,
                          n_labels=n_labels,
                          train_classif_w_reconstr=train_classif_w_reconstr)

        self.include_topper = include_topper
        self.resnet = resnet
        self.use_BN = use_BN
        ode_rnn_encoder_dim = latent_dim

        if ODE_sharing or RNN_sharing or self.resnet or self.include_topper:
            self.include_topper = True
            input_dim_first = latent_dim
        else:
            input_dim_first = input_dim

        if RNNcell == 'lstm':
            ode_latents = int(latent_dim) * 2
        else:
            ode_latents = int(latent_dim)

        #need one Encoder_z0_ODE_RNN per layer.
        self.ode_gru = []
        self.z0_diffeq_solver = []
        first_layer = True
        rnn_input = input_dim_first * 2

        if stack_order is None:
            # one entry per layer; valid cell types: ode_rnn, star, gru, gru_small, lstm
            stack_order = ["ode_rnn"] * stacking

        self.stacking = stacking
        # stack_order must contain exactly one entry per stacked layer; if it does not,
        # the length of stack_order takes precedence.
        if len(stack_order) != stacking:
            print("Warning: the specified stack order does not match the number of "
                  "stacked layers; using the length of stack_order instead.")
            print("Stack order: ", stack_order)
            print("Stacking argument: ", stacking)
            self.stacking = len(stack_order)

        # get the default ODE and RNN for the weightsharing
        # ODE stuff
        z0_diffeq_solver = get_diffeq_solver(ode_latents,
                                             ode_units,
                                             rec_layers,
                                             ode_method,
                                             ode_type="linear",
                                             device=device)

        # RNNcell
        if RNNcell == 'gru':
            RNN_update = GRU_unit(latent_dim,
                                  rnn_input,
                                  n_units=n_gru_units,
                                  device=device).to(device)

        elif RNNcell == 'gru_small':
            RNN_update = GRU_standard_unit(latent_dim,
                                           rnn_input,
                                           device=device).to(device)

        elif RNNcell == 'lstm':
            RNN_update = LSTM_unit(latent_dim, rnn_input).to(device)

        elif RNNcell == "star":
            RNN_update = STAR_unit(latent_dim, rnn_input,
                                   n_units=n_gru_units).to(device)

        else:
            raise Exception(
                "Invalid RNN-cell type. Hint: expdecay not available for ODE-RNN"
            )

        # Put the layers into the model
        for s in range(self.stacking):

            use_ODE = (stack_order[s] == "ode_rnn")

            if first_layer:
                # input and the mask
                layer_input_dimension = (input_dim_first) * 2
                first_layer = False

            else:
                # otherwise we just take the latent dimension of the previous layer as the sequence
                layer_input_dimension = latent_dim * 2

            # append the same z0_ODE-RNN for every layer

            if not RNN_sharing:

                if not use_ODE:
                    if use_pos_encod:
                        vertical_rnn_input = layer_input_dimension + 4  # +4 for the 2-dim positional encoding and its mask
                    else:
                        vertical_rnn_input = layer_input_dimension + 2  # +2 for delta t and its mask

                    thisRNNcell = stack_order[s]

                else:
                    vertical_rnn_input = layer_input_dimension
                    thisRNNcell = RNNcell

                if thisRNNcell == 'gru':
                    #pdb.set_trace()
                    RNN_update = GRU_unit(latent_dim,
                                          vertical_rnn_input,
                                          n_units=n_gru_units,
                                          device=device).to(device)

                elif thisRNNcell == 'gru_small':
                    RNN_update = GRU_standard_unit(latent_dim,
                                                   vertical_rnn_input,
                                                   device=device).to(device)

                elif thisRNNcell == 'lstm':
                    # two times latent dimension because of the cell state!
                    RNN_update = LSTM_unit(latent_dim * 2,
                                           vertical_rnn_input).to(device)

                elif thisRNNcell == "star":
                    RNN_update = STAR_unit(latent_dim,
                                           vertical_rnn_input,
                                           n_units=n_gru_units).to(device)

                else:
                    raise Exception(
                        "Invalid RNN-cell type. Hint: expdecay not available for ODE-RNN"
                    )

            if not ODE_sharing:

                if RNNcell == 'lstm':
                    ode_latents = int(latent_dim) * 2
                else:
                    ode_latents = int(latent_dim)

                z0_diffeq_solver = get_diffeq_solver(ode_latents,
                                                     ode_units,
                                                     rec_layers,
                                                     ode_method,
                                                     ode_type="linear",
                                                     device=device)

            self.Encoder0 = Encoder_z0_ODE_RNN(
                latent_dim=ode_rnn_encoder_dim,
                input_dim=layer_input_dimension,
                z0_diffeq_solver=z0_diffeq_solver,
                n_gru_units=n_gru_units,
                device=device,
                RNN_update=RNN_update,
                use_BN=use_BN,
                use_ODE=use_ODE,
                nornnimputation=nornnimputation,
                use_pos_encod=use_pos_encod,
                n_intermediate_tp=n_intermediate_tp).to(device)

            self.ode_gru.append(self.Encoder0)

        # construct topper
        if self.include_topper:
            if linear_topper:
                self.topper = nn.Sequential(
                    nn.Linear(input_dim, latent_dim),
                    nn.Tanh(),
                ).to(device)
            else:
                self.topper = nn.Sequential(
                    nn.Linear(input_dim, 100),
                    nn.Tanh(),
                    nn.Linear(100, latent_dim),
                    nn.Tanh(),
                ).to(device)

            utils.init_network_weights(self.topper)

            self.topper_bn = nn.BatchNorm1d(latent_dim)
        """
		self.decoder = nn.Sequential(
			nn.Linear(latent_dim, n_units),
			nn.Tanh(),
			nn.Linear(n_units, input_dim),)

		utils.init_network_weights(self.decoder)
		"""

        z0_dim = latent_dim

        # get the end-of-sequence classifier
        if use_binary_classif:
            if linear_classifier:
                self.classifier = nn.Sequential(nn.Linear(z0_dim, n_labels),
                                                nn.Softmax(dim=(2)))
            else:
                self.classifier = create_classifier(z0_dim, n_labels)
            utils.init_network_weights(self.classifier)

            if self.use_BN:
                self.bn_lasthidden = nn.BatchNorm1d(latent_dim)

        self.device = device
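
A toy illustration of the stacking idea used above: the topper lifts the raw observations to the latent width, and each stacked layer consumes the latent sequence produced by the previous one. Plain nn.GRU layers stand in for the Encoder_z0_ODE_RNN layers; the call signatures of the real encoders differ.

import torch
import torch.nn as nn

input_dim, latent_dim, stacking, batch, T = 5, 16, 3, 4, 30
topper = nn.Sequential(nn.Linear(input_dim, 100), nn.Tanh(),
                       nn.Linear(100, latent_dim), nn.Tanh())
layers = nn.ModuleList([nn.GRU(latent_dim, latent_dim, batch_first=True)
                        for _ in range(stacking)])

seq = torch.randn(batch, T, input_dim)
seq = topper(seq)            # lift observations to the latent width
for layer in layers:
    seq, _ = layer(seq)      # each layer refines the previous layer's latent sequence
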
Example #21

    def __init__(self, n_heads=2, d_input=6, d_k=6, dropout=0.1, **kwargs):
        super(GTrans, self).__init__(aggr='add', **kwargs)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)

        self.d_input = d_input
        self.d_k = d_k//n_heads
        self.d_q = d_k//n_heads
        self.d_e = d_k//n_heads
        self.d_sqrt = math.sqrt(d_k//n_heads)

        #Attention Layer Initialization
        self.w_k_list_same = nn.ModuleList([nn.Linear(self.d_input, self.d_k, bias=True) for i in range(self.n_heads)])
        self.w_k_list_diff = nn.ModuleList([nn.Linear(self.d_input, self.d_k, bias=True) for i in range(self.n_heads)])
        self.w_q_list = nn.ModuleList([nn.Linear(self.d_input, self.d_q, bias=True) for i in range(self.n_heads)])
        self.w_v_list_same = nn.ModuleList([nn.Linear(self.d_input, self.d_e, bias=True) for i in range(self.n_heads)])
        self.w_v_list_diff = nn.ModuleList([nn.Linear(self.d_input, self.d_k, bias=True) for i in range(self.n_heads)])

        #self.w_transfer = nn.ModuleList([nn.Linear(self.d_input*2, self.d_k, bias=True) for i in range(self.n_heads)])
        self.w_transfer = nn.ModuleList([nn.Linear(self.d_input +1, self.d_k, bias=True) for i in range(self.n_heads)])

        #initialization
        utils.init_network_weights(self.w_k_list_same)
        utils.init_network_weights(self.w_k_list_diff)
        utils.init_network_weights(self.w_q_list)
        utils.init_network_weights(self.w_v_list_same)
        utils.init_network_weights(self.w_v_list_diff)
        utils.init_network_weights(self.w_transfer)


        #Temporal Layer
        self.temporal_net = TemporalEncoding(d_input)

        #Normalization
        self.layer_norm = nn.LayerNorm(d_input)
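
As a rough orientation for the per-head bookkeeping (d_k, d_q, d_sqrt) above, a generic edge-wise scaled dot-product score between sender and receiver features; the actual GTrans message passing (separate same/diff key and value lists, temporal encoding, per-receiver softmax) is considerably more involved.

import math
import torch
import torch.nn as nn

d_input, d_k, n_edges = 6, 3, 10
w_k = nn.Linear(d_input, d_k)
w_q = nn.Linear(d_input, d_k)

h_src = torch.randn(n_edges, d_input)   # sender node features, one row per edge
h_dst = torch.randn(n_edges, d_input)   # receiver node features

scores = (w_k(h_src) * w_q(h_dst)).sum(-1) / math.sqrt(d_k)
alpha = torch.softmax(scores, dim=0)    # in a real layer this is normalised per receiver
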