Exemplo n.º 1
0
    def forward(self, x):
        """Stacked-hourglass forward pass.

        Parameters:
            x (tensor): input image batch.
        Returns:
            list: one entry per stack; each entry is a list with one
            activated head tensor per entry of ``self.out_ch``.
        """
        output = []
        x = torch.selu(self.batch1(self.conv7(x)))
        x = self.max_pool(torch.selu(self.batch2(self.conv3(x))))

        # Every stack except the last produces intermediate supervision
        # heads that are fused back into the feature stream (residual).
        for i in range(self.n_stack - 1):
            init = x
            x = self.hours[i](x)
            x = self.bottles_1[i](x)

            inter_out = [
                self.out_activation[i](self.intermediate_out[i][j](x))
                for j in range(len(self.out_ch))
            ]
            output.append(inter_out)
            x = self.bottles_2[i](x)
            # cat the head outputs directly; the old
            # `[inter for inter in inter_out]` copy was a no-op.
            inter = self.bottles_3[i](torch.cat(inter_out, dim=1))
            x = init + inter + x

        last_hour = self.hours[-1](x)

        last_out = []
        for i in range(len(self.out_ch)):
            out = torch.selu(self.out_batch[i](self.out_front[i](last_hour)))
            out = self.out_activation[i](self.out[i](out))
            last_out.append(out)
        output.append(last_out)

        return output
Exemplo n.º 2
0
    def clean_action(self, state, return_only_action=True):
        """Method to forward propagate through the actor's graph
            Parameters:
                  state (tensor): concatenated [observation, goal] batch;
                      the first ``self.obs_dim`` columns are the observation
                  return_only_action (bool): if True, return only the
                      tanh-squashed mean action
            Returns:
                  action (tensor): tanh(mean), or (mean, log_std) when
                  return_only_action is False
        """
        # Goal + feature processor: embed observation and goal separately.
        obs = torch.selu(self.feature1(state[:, 0:self.obs_dim]))

        # 'goal' was previously named 'dict', which shadowed the builtin.
        goal = torch.selu(self.goal1(state[:, self.obs_dim:]))
        goal = torch.selu(self.goal2(goal))

        x = torch.cat([obs, goal], dim=1)

        # Shared trunk
        x = torch.selu(self.linear1(x))
        x = torch.selu(self.linear2(x))
        mean = self.mean_linear(x)

        if return_only_action:
            return torch.tanh(mean)

        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std,
                              min=self.LOG_SIG_MIN,
                              max=self.LOG_SIG_MAX)
        return mean, log_std
Exemplo n.º 3
0
def dist_dqn(x, theta, a_space=3):
    """Distributional DQN forward pass from a flat parameter vector.

    :param x: (B x 128) state batch tensor
    :param theta: flat parameter tensor laid out as
        [theta1 (128*100) | theta2 (100*25) | a_space heads of 25*51]
    :param a_space: size of the action space
    :return: (B x a_space x 51) softmax value distributions, squeezed
    """
    # layer dimensions
    dim_0, dim_1, dim_2, dim_3 = 128, 100, 25, 51
    t1 = dim_0 * dim_1
    t2 = dim_1 * dim_2
    theta1 = theta[0:t1].reshape(dim_0, dim_1)
    theta2 = theta[t1:t1 + t2].reshape(dim_1, dim_2)
    l1 = torch.selu(x @ theta1)  # (Bx128) -> (Bx100)
    l2 = torch.selu(l1 @ theta2)  # (Bx100) -> (Bx25)

    step = dim_2 * dim_3  # parameters per action head (loop-invariant)
    l3 = []
    for i in range(a_space):
        # Each action has its own (dim_2 x dim_3) output head.
        start = t1 + t2 + i * step
        theta5 = theta[start:start + step].reshape(dim_2, dim_3)
        l3.append(l2 @ theta5)  # (Bx25) -> (Bx51)
    l3 = torch.stack(l3, dim=1)  # (B x a_space x 51)
    return torch.nn.functional.softmax(l3, dim=2).squeeze()
Exemplo n.º 4
0
    def forward(self, observation):
        """Map an observation to a Normal action distribution.

        The mean head is tanh-bounded; the std head is a sigmoid scaled
        by ``self.scale`` and shifted by ``self.min_sigma``.
        """
        mu = torch.tanh(self.l2_mu(torch.selu(self.l1_mu(observation))))

        sigma_hidden = torch.selu(self.l1_sigma(observation))
        sigma = self.min_sigma + torch.sigmoid(
            self.l2_sigma(sigma_hidden)) * self.scale
        return Normal(mu, sigma)
Exemplo n.º 5
0
    def forward(self, z):
        """Decode latent ``z`` through four layers; sigmoid output."""
        hidden = self.layer1(z)
        hidden = torch.selu(self.layer2(hidden))
        hidden = torch.selu(self.layer3(hidden))
        return torch.sigmoid(self.layer4(hidden))
Exemplo n.º 6
0
    def forward(self, x, epoch, n_epochs):
        """Encode ``x`` to a sigmoid code in (0, 1).

        ``epoch`` and ``n_epochs`` are unused here but kept for interface
        compatibility (a disabled hard-binarization step used them).
        """
        code = self.layer1(x)
        code = torch.selu(self.layer2(code))
        code = torch.selu(self.layer3(code))
        return torch.sigmoid(self.layer4(code))
Exemplo n.º 7
0
    def forward(self, x, alpha):
        """Regress pupil/iris ellipse parameters from feature maps.

        ``alpha`` is unused here but kept for interface compatibility.
        Returns a (B x 10) tensor ordered as [pupil center(2),
        pupil axes(2), pupil angle(1), iris center(2), iris axes(2),
        iris angle(1)].
        """
        batch = x.shape[0]
        # x: [B, 192, H/16, W/16]
        feat = self.max_pool(F.leaky_relu(self.c1(x)))  # [B, 256, 7, 9]
        feat = F.leaky_relu(self.c2(feat))  # [B, 256, 5, 7]
        feat = F.leaky_relu(self.c3(feat))  # [B, 32, 3, 5]
        feat = feat.reshape(batch, -1)
        feat = self.l2(torch.selu(self.l1(feat)))

        parts = [
            self.c_actfunc(feat[:, 0:2]),       # pupil center
            self.param_actfunc(feat[:, 2:4]),   # pupil axes
            feat[:, 4].unsqueeze(1),            # pupil angle
            self.c_actfunc(feat[:, 5:7]),       # iris center
            self.param_actfunc(feat[:, 7:9]),   # iris axes
            feat[:, 9].unsqueeze(1),            # iris angle
        ]
        return torch.cat(parts, dim=1)
    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        """Run the SELU MLP over the transformed inputs and return
        per-class log-probabilities."""
        hidden = self.transform_inputs(inputs)
        for hidden_layer in self.layers:
            hidden = torch.selu(hidden_layer(hidden))
        logits = self.output(hidden)
        return torch.log_softmax(logits, dim=1)
Exemplo n.º 9
0
 def encoder(self, x):
     """Run ``x`` through the weight/bias encoder stack with SELU
     activations; apply dropout when configured."""
     for layer_idx, weight in enumerate(self.encode_w):
         x = torch.selu(F.linear(x, weight, self.encode_b[layer_idx]))
     if self._dp_drop_prob > 0:
         x = self.drop(x)
     return x
Exemplo n.º 10
0
    def forward(self, x):
        """Two-head stacked hourglass forward pass.

        Returns a list of [high, low] output pairs: one pair per
        intermediate stack plus the final pair.
        """
        outputs = []
        x = self.conv7(x)
        x = self.max_pool(self.block1(x))
        x = self.block3(self.block2(x))

        for stack in range(self.n_stack - 1):
            skip = x
            x = self.bottles_1[stack](self.hours[stack](x))

            # Intermediate supervision heads (high / low).
            inter_h = self.out_activation[0](self.inter_h_o[stack](x))
            inter_l = self.out_activation[1](self.inter_l_o[stack](x))
            outputs.append([inter_h, inter_l])

            x = self.bottles_2[stack](x)

            # Project the heads back and average them into the stream.
            merged = (self.inter_h_a[stack](inter_h) +
                      self.inter_l_a[stack](inter_l)) / 2
            x = x + skip + merged

        last_hour = self.hours[-1](x)

        out_block = torch.selu(self.batch3(self.block4(last_hour)))
        out_h = self.out_activation[0](self.out_h(self.out_h_f(out_block)))
        out_l = self.out_activation[1](self.out_l(self.out_l_f(out_block)))
        outputs.append([out_h, out_l])

        return outputs
Exemplo n.º 11
0
 def forward(self, x, adj_t):
     """GCN forward: conv -> SELU -> batch-norm -> dropout -> conv ->
     log-softmax over classes."""
     hidden = torch.selu(self.conv1(x, adj_t))
     hidden = self.dropout(self.bn(hidden))
     logits = self.conv2(hidden, adj_t)
     return F.log_softmax(logits, dim=1)
Exemplo n.º 12
0
    def forward(self, inp, action):
        """Method to forward propagate through the critic's graph

             Parameters:
                   inp (tensor): states (observation columns followed by
                       goal columns)
                   action (tensor): actions

             Returns:
                   Q1 (tensor): Qval 1 (V-baselined when self.USE_V)
                   Q2 (tensor): Qval 2 (V-baselined when self.USE_V)
                   None: placeholder third element kept for callers
         """
        # Goal + feature processor ('goal' was previously named 'dict',
        # which shadowed the builtin).
        obs = torch.selu(self.feature1(inp[:, 0:self.obs_dim]))

        goal = torch.selu(self.goal1(inp[:, self.obs_dim:]))
        goal = torch.selu(self.goal2(goal))

        # Concatenate observation+goal+action as critic state
        state = torch.cat([obs, goal, action], 1)

        ###### Q1 HEAD ####
        q1 = torch.selu(self.q1f1(state))
        q1 = torch.selu(self.q1f2(q1))
        q1 = self.q1out(q1)

        ###### Q2 HEAD ####
        q2 = torch.selu(self.q2f1(state))
        q2 = torch.selu(self.q2f2(q2))
        q2 = self.q2out(q2)

        if self.USE_V:
            ###### Value HEAD ####
            # Value is computed from observation+goal only (no action).
            v = torch.selu(self.v1(torch.cat([obs, goal], 1)))
            v = torch.tanh(self.v2(v))
            v = self.vout(v)

            # Baseline both Q estimates with the state value.
            q1 -= v
            q2 -= v

        return q1, q2, None
Exemplo n.º 13
0
    def clean_action(self, state, return_only_action=True):
        """Method to forward propagate through the actor's graph
            Parameters:
                  input (tensor): states
            Returns:
                  action (tensor): actions
        """
        #x = self.knet(state)

        #Goal+Feature Processor
        obs = self.feature1(state[:, 0:97])
        obs = torch.selu(obs)

        dict = self.goal1(state[:, 97:])
        dict = torch.selu(dict)
        dict = self.goal2(dict)
        dict = torch.selu(dict)

        x = torch.cat([obs, dict], axis=1)

        #Shared
        x = torch.selu(self.linear1(x))
        x = torch.selu(self.linear2(x))

        val = self.val(x)
        adv = self.adv(x)

        #DDQN Style baselining
        for i in range(0, self.num_heads * self.num_actions, self.num_heads):

            # if len(state) > 1:
            #     print(val[:,int(i/self.num_heads)].shape, adv[:,i:i+self.num_heads].shape, adv[:,i:i+self.num_heads].mean().shape)
            #     input()

            adv[:, i:i +
                self.num_heads] = val[:, int(i / self.num_heads)].unsqueeze(
                    1) + adv[:, i:i +
                             self.num_heads] - adv[:, i:i +
                                                   self.num_heads].mean()

        if return_only_action: return self.multi_argmax(logits=adv)

        temp = self.temperature(x)
        temp = torch.clamp(temp, min=self.TEMP_MIN, max=self.TEMP_MAX)
        return adv, temp
Exemplo n.º 14
0
 def forward(self, x):
     """Nine-layer SELU MLP: dropout after hidden layers 2-8, plain
     linear output head."""
     x = torch.selu(self.linear1(x))
     stages = [
         (self.linear2, self.dropout1), (self.linear3, self.dropout2),
         (self.linear4, self.dropout3), (self.linear5, self.dropout4),
         (self.linear6, self.dropout5), (self.linear7, self.dropout6),
         (self.linear8, self.dropout7),
     ]
     for linear, dropout in stages:
         x = dropout(torch.selu(linear(x)))
     return self.linear9(x)
Exemplo n.º 15
0
 def decoder(self, z):
     """Decode ``z`` with the transposed (tied) encoder weights: SELU on
     hidden layers, sigmoid on the final layer."""
     for layer_idx, weight in enumerate(list(reversed(self.encode_w))):
         pre = F.linear(input=z,
                        weight=weight.transpose(0, 1),
                        bias=self.decode_b[layer_idx])
         if layer_idx == self._last:
             z = torch.sigmoid(pre)
         else:
             z = torch.selu(pre)
     return z
Exemplo n.º 16
0
    def clean_action(self, state, return_only_action=True):
        """Method to forward propagate through the actor's graph

             Parameters:
                   state (tensor): states; the first 97 columns are the
                       observation, the rest the goal
                   return_only_action (bool): if True return greedy
                       actions only

             Returns:
                   actions, or (actions, None, logits)
         """
        # Goal + feature processor. NOTE(review): the 97-column split is
        # hard-coded; a sibling actor uses self.obs_dim instead.
        obs = torch.selu(self.feature1(state[:, 0:97]))

        # 'goal' was previously named 'dict', which shadowed the builtin.
        goal = torch.selu(self.goal1(state[:, 97:]))
        goal = torch.selu(self.goal2(goal))

        x = torch.cat([obs, goal], dim=1)

        # Shared trunk
        x = torch.selu(self.linear1(x))
        x = torch.selu(self.linear2(x))

        # Dueling baselining: value plus mean-centred advantages.
        val = self.val(x)
        adv = self.adv(x)
        logits = val + adv - adv.mean()

        if return_only_action:
            return self.multi_argmax(logits)
        else:
            return self.multi_argmax(logits), None, logits
Exemplo n.º 17
0
    def forward(self, observation):
        """Map an observation to a Normal action distribution.

        Raises ExplodedGradient if either layer produces NaNs.
        """
        hidden = torch.selu(self.l1(observation))
        if torch.isnan(hidden).any():
            raise ExplodedGradient
        raw = self.l2(hidden)
        if torch.isnan(raw).any():
            raise ExplodedGradient
        mu, sigma = torch.split(raw, self.actions, dim=1)
        bounded_mu = torch.tanh(mu)
        bounded_sigma = self.min_sigma + torch.sigmoid(sigma) * self.scale
        return Normal(bounded_mu, bounded_sigma)
Exemplo n.º 18
0
    def forward(self, Corpus_, batch_inputs, entity_embeddings, relation_embed,
                edge_list, edge_type, edge_embed, edge_list_nhop,
                edge_type_nhop):
        """Multi-head graph-attention pass over entities and relations.

        Combines 1-hop edges with n-hop edges (whose embeddings are the
        sum of the two relation embeddings on the path), adds a
        SELU-squashed head*tail Hadamard term to the edge embeddings, and
        returns the updated entity embeddings together with the projected
        relation embeddings.

        NOTE(review): `Corpus_`, `batch_inputs` and the incoming
        `edge_embed` argument are not read before being overwritten in
        this body — presumably kept for interface compatibility; confirm
        against callers.
        """
        x = entity_embeddings
        # n-hop edge embedding = sum of the two relations on the path.
        edge_embed_nhop = relation_embed[
            edge_type_nhop[:, 0]] + relation_embed[edge_type_nhop[:, 1]]

        #    def forward(self, input, edge, edge_embed, edge_list_nhop, edge_embed_nhop):
        # multi head h': concatenate the outputs of all attention heads.
        x = torch.cat([
            att(x, edge_list, edge_embed, edge_list_nhop, edge_embed_nhop)
            for att in self.attentions
        ],
                      dim=1)
        x = self.dropout_layer(x)

        # Hadamard (element-wise) products of head and tail embeddings for
        # 1-hop edges...
        head_e = x[edge_list[0, :], :]
        tail_e = x[edge_list[1, :], :]
        h_t_hdm_e = head_e * tail_e
        # ...and for n-hop edges: head
        head_n = x[edge_list_nhop[0, :], :]
        # tail
        tail_n = x[edge_list_nhop[1, :], :]
        h_t_hdm_n = head_n * tail_n
        # Project relations into the attended embedding space.
        out_relation_1 = relation_embed.mm(self.W)
        #out_relation_1 = out_relation_1 + h_t_hdm
        # Rebuild edge embeddings with the SELU-squashed Hadamard terms.
        edge_embed = out_relation_1[edge_type] + torch.selu(h_t_hdm_e)
        edge_embed_nhop = out_relation_1[
            edge_type_nhop[:, 0]] + out_relation_1[
                edge_type_nhop[:, 1]] + torch.selu(h_t_hdm_n)

        x = F.elu(
            self.out_att(x, edge_list, edge_embed, edge_list_nhop,
                         edge_embed_nhop))
        # Free the large temporaries before returning (CUDA memory).
        del edge_embed_nhop, edge_embed
        torch.cuda.empty_cache()
        return x, out_relation_1
Exemplo n.º 19
0
 def forward(self, inputs):
     """Apply sequence embedding CNN to inputs.

     Parameters
     ----------
     inputs: torch.Tensor
         Torch tensor of shape (n_sequences, n_in_features, n_sequence_positions), where n_in_features is
         `n_aa_features + n_position_features = 20 + 3 = 23`.

     Returns
     ---------
     max_conv_acts: torch.Tensor
         Sequences embedded to tensor of shape (n_sequences, n_kernels)
     """
     # First convolution over AA+position features, SELU activation,
     # then the additional convolutional stack.
     activations = self.additional_convs(torch.selu(self.conv_aas(inputs)))
     # Global max pool over sequence positions: one value per kernel.
     pooled, _ = activations.max(dim=-1)
     return pooled
Exemplo n.º 20
0
 def forward(self, observation):
     """Return a Categorical policy distribution over actions."""
     logits = self.l2(torch.selu(self.l1(observation)))
     return Categorical(logits=NN.log_softmax(logits, dim=1))
Exemplo n.º 21
0
 def forward(self, observation):
     """Return per-action log-probabilities for the observation batch."""
     logits = self.l2(torch.selu(self.l1(observation)))
     return NN.log_softmax(logits, dim=1)
Exemplo n.º 22
0
def selu(x: T.Tensor, **kwargs):
    """SELU activation.

    Extra keyword arguments are accepted for interface compatibility
    with other activation wrappers and ignored.
    """
    result = T.selu(x)
    return result
 def act(self, x: Tensor) -> Tensor:
     """SELU activation wrapper."""
     activated = tr.selu(x)
     return activated
 def act2(self, z: Tensor) -> Tensor:
     """Second SELU activation wrapper (same behavior as ``act``)."""
     activated = tr.selu(z)
     return activated
Exemplo n.º 25
0
 def forward(self, input):
     """CNN embedding followed by a SELU linear layer; returns raw
     prediction logits from the final linear head."""
     features = torch.selu(self.cnn(input))
     features = torch.selu(self.linear1(features))
     return self.linear2(features)
Exemplo n.º 26
0
 def __chanel_attention__(ch_input):
     # Channel-attention scores: flatten each sample, then a two-layer
     # bottleneck (w0 -> SELU -> w1).
     # NOTE(review): `self` is not a parameter here — this only works as
     # a closure nested inside a method where `self` is in scope; as a
     # plain method it would raise NameError. Confirm against the caller.
     temp = torch.flatten(ch_input, start_dim=1)
     temp = torch.selu(self.w0(temp))
     temp = self.w1(temp)
     return temp
Exemplo n.º 27
0
 def forward(self, batch):
     """Forward pass of an LSTM + dilated-RNN next-location predictor.

     Encodes the current trajectory with an LSTM and a dilated RNN,
     attends over each user's historical session representations
     (weighted by time-slot similarity and average distance), and
     returns log-probabilities over locations.

     NOTE(review): assumes `batch` behaves like a mapping with keys
     'current_loc', 'current_tim', 'dilated_rnn_input_index',
     'history_loc', 'history_tim', 'history_avg_distance' and exposes
     get_origin_len() — confirm against the batch class.
     """
     batch_size = batch['current_loc'].shape[0]
     origin_len = batch.get_origin_len('current_loc')
     current_loc = batch['current_loc']
     current_tim = batch['current_tim']
     items = self.loc_emb(current_loc).permute(
         1, 0, 2)  # sequence * batch_size * embedding
     current_loc = current_loc.tolist()
     # pack x and history_x
     pack_items = pack_padded_sequence(items,
                                       lengths=origin_len,
                                       enforce_sorted=False)
     h1 = torch.zeros(1, batch_size, self.hidden_size).to(self.device)
     c1 = torch.zeros(1, batch_size, self.hidden_size).to(self.device)
     out, (h1, c1) = self.lstmcell(pack_items, (h1, c1))
     # batch_size * sequence_length * hidden_size
     out, _ = pad_packed_sequence(out, batch_first=True)
     items = items.permute(1, 0, 2)  # batch_size * sequence * embedding
     y_list = []
     out_hie = []  # batch_size * hidden_size
     dilated_rnn_input_index = batch['dilated_rnn_input_index']
     for ii in range(batch_size):
         current_session_input_dilated_rnn_index = dilated_rnn_input_index[
             ii].tolist()  # origin_cur_len
         hiddens_current = items[ii]
         dilated_lstm_outs_h = []
         dilated_lstm_outs_c = []
         # Dilated RNN: each step reads the (h, c) of the step named by
         # the dilation index rather than simply the previous step.
         for index_dilated in range(
                 len(current_session_input_dilated_rnn_index)):
             index_dilated_explicit = current_session_input_dilated_rnn_index[
                 index_dilated]
             hidden_current = hiddens_current[index_dilated].unsqueeze(0)
             if index_dilated == 0:
                 h = torch.zeros(1, self.hidden_size).to(self.device)
                 c = torch.zeros(1, self.hidden_size).to(self.device)
                 (h, c) = self.dilated_rnn(hidden_current, (h, c))
                 dilated_lstm_outs_h.append(h)
                 dilated_lstm_outs_c.append(c)
             else:
                 (h, c) = self.dilated_rnn(
                     hidden_current,
                     (dilated_lstm_outs_h[index_dilated_explicit],
                      dilated_lstm_outs_c[index_dilated_explicit]))
                 dilated_lstm_outs_h.append(h)
                 dilated_lstm_outs_c.append(c)
         out_hie.append(dilated_lstm_outs_h[-1])
         current_session_timid = current_tim[ii].tolist()[
             origin_len[ii] - 1]  # excludes the padded point
         current_session_embed = out[ii]  # sequence_len * hidden_size
         # FloatTensor sequence_len * 1
         current_session_mask = self._pad_batch_of_lists_masks(
             current_loc[ii], origin_len[ii]).unsqueeze(1)
         # mask_batch_ix_non_local[ii].unsqueeze(1)
         sequence_length = origin_len[ii]
         # do average pooling for current_session
         # 1 * hidden_size
         current_session_represent = torch.sum(
             current_session_embed * current_session_mask,
             dim=0).unsqueeze(0) / sum(current_session_mask)
         list_for_sessions = []  # his_cnt * hidden_size
         h2 = torch.zeros(1, 1, self.hidden_size).to(self.device)
         c2 = torch.zeros(1, 1, self.hidden_size).to(self.device)
         # process the user's historical trajectories
         for jj in range(len(batch['history_loc'][ii])):
             sequence = batch['history_loc'][ii][jj]
             sequence_emb = self.loc_emb(sequence).unsqueeze(
                 1)  # his_seq_len * 1 * embedding_size
             sequence_emb, (h2, c2) = self.lstmcell_history(
                 sequence_emb, (h2, c2))
             sequence_tim_id = batch['history_tim'][ii][jj].tolist()
             # reweight the history representation by time-slot
             # similarity to the current session's time slot
             # tim_size
             jaccard_sim_row = torch.FloatTensor(
                 self.tim_sim_matrix[current_session_timid]).to(self.device)
             jaccard_sim_expicit = jaccard_sim_row[
                 sequence_tim_id]  # his_seq_len
             jaccard_sim_expicit_last = F.softmax(jaccard_sim_expicit,
                                                  dim=0).unsqueeze(
                                                      0)  # 1 * his_seq_len
             # 1 * hidden_size
             hidden_sequence_for_current = torch.mm(
                 jaccard_sim_expicit_last, sequence_emb.squeeze(1))
             list_for_sessions.append(hidden_sequence_for_current)
         # 1 * his_cnt
         avg_distance = batch['history_avg_distance'][ii].unsqueeze(0)
         # 1 * his_cnt * hidden_size
         sessions_represent = torch.cat(list_for_sessions,
                                        dim=0).unsqueeze(0)
         # 1 * hidden_size * 1
         current_session_represent = current_session_represent.unsqueeze(2)
         # 1 * 1 * his_cnt
         sim_between_cur_his = F.softmax(
             sessions_represent.bmm(current_session_represent).squeeze(2),
             dim=1).unsqueeze(1)
         # TODO: why do linear1 and selu?
         # 1 * hidden_size
         out_y_current = torch.selu(
             self.linear1(
                 sim_between_cur_his.bmm(sessions_represent).squeeze(1)))
         # 1 * hidden_size * 1
         layer_2_current = (
             0.5 * out_y_current +
             0.5 * current_session_embed[sequence_length - 1]).unsqueeze(2)
         # 1 * 1 * his_cnt
         layer_2_sims = F.softmax(
             sessions_represent.bmm(layer_2_current).squeeze(2) * 1.0 /
             avg_distance,
             dim=1).unsqueeze(1)
         # 1 * hidden_size
         out_layer_2 = layer_2_sims.bmm(sessions_represent).squeeze(1)
         y_list.append(out_layer_2)
     # batch_size * hidden_size
     y = torch.selu(torch.cat(y_list, dim=0))
     # obtain the short-term output
     # batch_size * hidden_size
     out_hie = F.selu(torch.cat(out_hie, dim=0))
     # get the final lstm out
     final_out_index = torch.tensor(origin_len) - 1
     final_out_index = final_out_index.reshape(final_out_index.shape[0], 1,
                                               -1)
     final_out_index = final_out_index.repeat(1, 1, self.hidden_size).to(
         self.device)
     final_out = torch.gather(out, 1, final_out_index).squeeze(1)
     final_out = F.selu(final_out)
     final_out = (final_out + out_hie) * 0.5
     out_put_emb_v1 = torch.cat([y, final_out], dim=1)
     output_ln = self.linear(out_put_emb_v1)
     # batch_size * loc_size
     output = F.log_softmax(output_ln, dim=1)
     return output
Exemplo n.º 28
0
    def forward(self, user_vectors, item_vectors, mask_batch_ix_non_local, session_id_batch, sequence_tim_batch, is_train, poi_distance_matrix, sequence_dilated_rnn_index_batch):
        """Forward pass of the session-based next-POI recommender.

        Runs an LSTM plus a dilated RNN over the current session, attends
        over each user's historical sessions (weighted by time-slot
        similarity and POI distance), and returns per-position
        log-probabilities over items.

        NOTE(review): reads self.data_neural / self.tim_sim_matrix and
        calls .cuda() throughout, so a CUDA device is required — confirm
        against the runtime environment.
        """
        batch_size = item_vectors.size()[0]
        sequence_size = item_vectors.size()[1]
        items = self.item_emb(item_vectors)
        item_vectors = item_vectors.cpu()
        x = items
        x = x.transpose(0, 1)
        h1 = Variable(torch.zeros(1, batch_size, self.hidden_units)).cuda()
        c1 = Variable(torch.zeros(1, batch_size, self.hidden_units)).cuda()
        out, (h1, c1) = self.lstmcell(x, (h1, c1))
        out = out.transpose(0, 1)#batch_size * sequence_length * embedding_dim
        x1 = items
        # ###########################################################
        user_batch = np.array(user_vectors.cpu())
        y_list = []
        out_hie = []
        for ii in range(batch_size):
            ##########################################
            # Dilated RNN: each step reads the (h, c) of the step named
            # by the dilation index rather than the previous step.
            current_session_input_dilated_rnn_index = sequence_dilated_rnn_index_batch[ii]
            hiddens_current = x1[ii]
            dilated_lstm_outs_h = []
            dilated_lstm_outs_c = []
            for index_dilated in range(len(current_session_input_dilated_rnn_index)):
                index_dilated_explicit = current_session_input_dilated_rnn_index[index_dilated]
                hidden_current = hiddens_current[index_dilated].unsqueeze(0)
                if index_dilated == 0:
                    h = Variable(torch.zeros(1, self.hidden_units)).cuda()
                    c = Variable(torch.zeros(1, self.hidden_units)).cuda()
                    (h, c) = self.dilated_rnn(hidden_current, (h, c))
                    dilated_lstm_outs_h.append(h)
                    dilated_lstm_outs_c.append(c)
                else:
                    (h, c) = self.dilated_rnn(hidden_current, (dilated_lstm_outs_h[index_dilated_explicit], dilated_lstm_outs_c[index_dilated_explicit]))
                    dilated_lstm_outs_h.append(h)
                    dilated_lstm_outs_c.append(c)
            dilated_lstm_outs_h.append(hiddens_current[len(current_session_input_dilated_rnn_index):])
            dilated_out = torch.cat(dilated_lstm_outs_h, dim = 0).unsqueeze(0)
            out_hie.append(dilated_out)
            user_id_current = user_batch[ii]
            current_session_timid = sequence_tim_batch[ii][:-1]
            current_session_poiid = item_vectors[ii][:len(current_session_timid)]
            session_id_current = session_id_batch[ii]
            current_session_embed = out[ii]
            current_session_mask = mask_batch_ix_non_local[ii].unsqueeze(1)
            sequence_length = int(sum(np.array(current_session_mask.cpu()))[0])
            current_session_represent_list = []
            # Current-session representation: masked average pooling when
            # training, running prefix average at evaluation time.
            if is_train:
                for iii in range(sequence_length-1):
                    current_session_represent = torch.sum(current_session_embed * current_session_mask, dim=0).unsqueeze(0)/sum(current_session_mask)
                    current_session_represent_list.append(current_session_represent)
            else:
                for iii in range(sequence_length-1):
                    current_session_represent_rep_item = current_session_embed[0:iii+1]
                    current_session_represent_rep_item = torch.sum(current_session_represent_rep_item, dim = 0).unsqueeze(0)/(iii + 1)
                    current_session_represent_list.append(current_session_represent_rep_item)

            current_session_represent = torch.cat(current_session_represent_list, dim = 0)
            list_for_sessions = []
            list_for_avg_distance = []
            h2 = Variable(torch.zeros(1, 1, self.hidden_units)).cuda()###whole sequence
            c2 = Variable(torch.zeros(1, 1, self.hidden_units)).cuda()
            # Encode each of the user's past sessions, reweighted by
            # time-slot similarity to the current session.
            for jj in range(session_id_current):
                sequence = [s[0] for s in self.data_neural[user_id_current]['sessions'][jj]]
                sequence = Variable(torch.LongTensor(np.array(sequence))).cuda()
                sequence_emb = self.item_emb(sequence).unsqueeze(1)
                sequence = sequence.cpu()
                sequence_emb, (h2, c2) = self.lstmcell_history(sequence_emb, (h2, c2))
                sequence_tim_id = [s[1] for s in self.data_neural[user_id_current]['sessions'][jj]]
                jaccard_sim_row = Variable(torch.FloatTensor(self.tim_sim_matrix[current_session_timid]),requires_grad=False).cuda()
                jaccard_sim_expicit = jaccard_sim_row[:,sequence_tim_id]
                distance_row = poi_distance_matrix[current_session_poiid]
                distance_row_expicit = Variable(torch.FloatTensor(distance_row[:,sequence]),requires_grad=False).cuda()
                distance_row_expicit_avg = torch.mean(distance_row_expicit, dim = 1)
                jaccard_sim_expicit_last = F.softmax(jaccard_sim_expicit)
                hidden_sequence_for_current1 = torch.mm(jaccard_sim_expicit_last, sequence_emb.squeeze(1))
                hidden_sequence_for_current =  hidden_sequence_for_current1
                list_for_sessions.append(hidden_sequence_for_current.unsqueeze(0))
                list_for_avg_distance.append(distance_row_expicit_avg.unsqueeze(0))
            avg_distance = torch.cat(list_for_avg_distance, dim = 0).transpose(0,1)
            sessions_represent = torch.cat(list_for_sessions, dim=0).transpose(0,1) ##current_items * history_session_length * embedding_size
            current_session_represent = current_session_represent.unsqueeze(2) ### current_items * embedding_size * 1
            sims = F.softmax(sessions_represent.bmm(current_session_represent).squeeze(2), dim = 1).unsqueeze(1) ##==> current_items * 1 * history_session_length
            #out_y_current = sims.bmm(sessions_represent).squeeze(1)
            out_y_current =torch.selu(self.linear1(sims.bmm(sessions_represent).squeeze(1)))
            ##############layer_2
            #layer_2_current = (lambda*out_y_current + (1-lambda)*current_session_embed[:sequence_length-1]).unsqueeze(2) #lambda from [0.1-0.9] better performance
            # layer_2_current = (out_y_current + current_session_embed[:sequence_length-1]).unsqueeze(2)##==>current_items * embedding_size * 1
            layer_2_current = (0.5 *out_y_current + 0.5 * current_session_embed[:sequence_length - 1]).unsqueeze(2)
            layer_2_sims =  F.softmax(sessions_represent.bmm(layer_2_current).squeeze(2) * 1.0/avg_distance, dim = 1).unsqueeze(1)##==>>current_items * 1 * history_session_length
            out_layer_2 = layer_2_sims.bmm(sessions_represent).squeeze(1)
            out_y_current_padd = Variable(torch.FloatTensor(sequence_size - sequence_length + 1, self.emb_size).zero_(),requires_grad=False).cuda()
            out_layer_2_list = []
            out_layer_2_list.append(out_layer_2)
            out_layer_2_list.append(out_y_current_padd)
            out_layer_2 = torch.cat(out_layer_2_list,dim = 0).unsqueeze(0)
            y_list.append(out_layer_2)
        y = torch.selu(torch.cat(y_list,dim=0))
        out_hie = F.selu(torch.cat(out_hie, dim = 0))
        out = F.selu(out)
        # Blend the LSTM outputs with the dilated-RNN outputs.
        out = (out + out_hie) * 0.5
        out_put_emb_v1 = torch.cat([y, out], dim=2)
        output_ln = self.linear(out_put_emb_v1)
        output = F.log_softmax(output_ln, dim=-1)
        return output
 def forward(self, sample_a, sample_b):
     """Siamese scoring: embed both samples with the shared 1-D CNN and
     score the absolute difference of their embeddings."""
     emb_a = torch.selu(self.cnn1d(sample_a))
     emb_b = torch.selu(self.cnn1d(sample_b))
     return self.linear(torch.abs(emb_a - emb_b)).squeeze(-1)
 def update(self, nodes, messages):
     """Node update: SELU over the aggregated messages (``nodes`` is
     unused, kept for the message-passing interface)."""
     return torch.selu(messages)