Esempio n. 1
0
    def _next_batch_data(self):
        """Return the next batch and advance ``self.pr`` by ``self.step``.

        When ``self.neg_sample_args['strategy'] == 'by'`` the batch is built
        per user: for every uid in ``self.uid_list[self.pr:self.pr + self.step]``
        the rows ``self.dataset[self.uid2index[uid]]`` are passed through
        ``self._neg_sampling`` and all results are joined with
        ``cat_interactions``.  Otherwise the plain slice
        ``self.dataset[self.pr:self.pr + self.step]`` is negatively sampled.

        Returns:
            tuple: ``(cur_data, idx_list, positive_u, positive_i)``.  With the
            'by' strategy:

            * ``idx_list`` -- int64 tensor holding each user's position in the
              window repeated ``uid2items_num[uid] * self.times`` times;
            * ``positive_u`` -- int64 tensor holding each user's position
              repeated ``uid2items_num[uid]`` times;
            * ``positive_i`` -- concatenation of every user's
              ``self.iid_field`` column (starting from an empty int64 tensor).

            For any other strategy the last three entries are ``None``.
        """
        if self.neg_sample_args['strategy'] == 'by':
            uid_list = self.uid_list[self.pr:self.pr + self.step]
            data_list = []
            idx_list = []
            positive_u = []
            # Seed with an empty int64 tensor (as the original did) so the
            # result keeps that dtype and is well defined for an empty window.
            positive_i_parts = [torch.tensor([], dtype=torch.int64)]

            for idx, uid in enumerate(uid_list):
                index = self.uid2index[uid]
                data_list.append(self._neg_sampling(self.dataset[index]))
                # int() so list repetition is not hijacked by a numpy scalar's
                # __rmul__ (which would broadcast instead of repeating).
                pos_num = int(self.uid2items_num[uid])
                idx_list.extend([idx] * (pos_num * self.times))
                positive_u.extend([idx] * pos_num)
                positive_i_parts.append(self.dataset[index][self.iid_field])

            cur_data = cat_interactions(data_list)
            # One concatenation instead of one per user avoids quadratic copying.
            positive_i = torch.cat(positive_i_parts, 0)
            # Build int64 tensors directly: np.array([]) is float64 for an empty
            # window and numpy's default int dtype is platform dependent.
            idx_list = torch.tensor(idx_list, dtype=torch.int64)
            positive_u = torch.tensor(positive_u, dtype=torch.int64)

            self.pr += self.step
            return cur_data, idx_list, positive_u, positive_i

        cur_data = self._neg_sampling(self.dataset[self.pr:self.pr + self.step])
        self.pr += self.step
        return cur_data, None, None, None
Esempio n. 2
0
 def _next_batch_data(self):
     """Gather the next ``self.step`` users' interactions and advance ``self.pr``.

     Each uid in the current window of ``self.uid_list`` is mapped through
     ``self.uid2index`` to select that user's rows from ``self.dataset``; the
     selections are joined with ``cat_interactions``.  The per-user counts
     ``self.uid2items_num_pos[window]`` and ``self.uid2items_num[window]`` are
     attached to the batch (as lists) via ``set_additional_info``.

     Returns:
         The concatenated batch produced by ``cat_interactions``.
     """
     window = self.uid_list[self.pr:self.pr + self.step]
     batch = cat_interactions(
         [self.dataset[self.uid2index[user]] for user in window])
     batch.set_additional_info(
         list(self.uid2items_num_pos[window]),
         list(self.uid2items_num[window]),
     )
     self.pr += self.step
     return batch
Esempio n. 3
0
 def _neg_sampling(self, data):
     """Sample negative item ids for ``data`` and apply ``self.sampling_func``.

     When ``self.user_inter_in_one_batch`` is falsy, the sampler is called
     once with the whole ``data[self.uid_field]`` column and
     ``self.sampling_func`` is applied to all of ``data``.

     Otherwise every row is handled on its own: the sampler receives that
     row's one-element slice of the uid column, ``self.sampling_func`` gets the
     matching one-row slice of ``data``, and the per-row results are joined
     with ``cat_interactions``.

     ``self.neg_sample_by`` is forwarded unchanged to
     ``self.sampler.sample_by_user_ids``.
     """
     if not self.user_inter_in_one_batch:
         every_uid = data[self.uid_field]
         sampled = self.sampler.sample_by_user_ids(every_uid, self.neg_sample_by)
         return self.sampling_func(data, sampled)

     uid_column = data[self.uid_field]
     per_row = []
     for row in range(len(uid_column)):
         sampled = self.sampler.sample_by_user_ids(
             uid_column[row:row + 1], self.neg_sample_by)
         per_row.append(self.sampling_func(data[row:row + 1], sampled))
     return cat_interactions(per_row)
Esempio n. 4
0
 def _neg_sample_by_point_wise_sampling(self, inter_feat, pos_idx, neg_iids,
                                        neg_idx, pos_iids):
     """Build a point-wise labelled batch from positive and negative rows.

     Positive part: ``inter_feat[pos_idx]`` is repeated
     ``self.neg_sample_by + 1`` times.  The first ``len(pos_idx[0])`` rows
     keep their items and get label 1.  All remaining rows have their
     ``self.iid_field`` overwritten with ``neg_iids`` and get label 0.

     Negative part (only when ``neg_idx`` selects at least one row):
     ``inter_feat[neg_idx]`` is repeated ``round(self.neg_sample_by / 3) + 1``
     times.  The first ``len(neg_idx[0])`` rows keep their items and get
     label 0.  The rest get ``pos_iids`` as items and label 1.

     Each part goes through ``self.dataset.join`` before its label column
     (``self.label_field``) is attached.  The parts are then concatenated
     with ``cat_interactions``.

     Args:
         inter_feat: indexable interaction container.
         pos_idx, neg_idx: index objects for ``inter_feat``; only
             ``len(idx[0])`` is used as the row count -- presumably the output
             of a ``nonzero``/``where``-style call (verify against caller).
         neg_iids: item ids written into the repeated positive rows.
         pos_iids: item ids written into the repeated negative rows.

     Returns:
         The object produced by ``cat_interactions``.
     """
     pos_count = len(pos_idx[0])
     neg_count = len(neg_idx[0])

     pos_copies = self.neg_sample_by + 1
     pos_part = inter_feat[pos_idx].repeat(pos_copies)
     pos_part[self.iid_field][pos_count:] = neg_iids
     pos_part = self.dataset.join(pos_part)
     pos_labels = torch.zeros(pos_count * pos_copies)
     pos_labels[:pos_count] = 1
     pos_part.update(Interaction({self.label_field: pos_labels}))

     parts = [pos_part]
     if neg_count > 0:
         neg_copies = round(self.neg_sample_by / 3) + 1
         neg_part = inter_feat[neg_idx].repeat(neg_copies)
         neg_part[self.iid_field][neg_count:] = pos_iids
         neg_part = self.dataset.join(neg_part)
         neg_labels = torch.ones(neg_count * neg_copies)
         neg_labels[:neg_count] = 0
         neg_part.update(Interaction({self.label_field: neg_labels}))
         parts.append(neg_part)
     return cat_interactions(parts)
Esempio n. 5
0
 def _next_batch_data(self):
     """Negatively sample the next batch and advance ``self.pr`` by ``self.step``.

     If ``self.user_inter_in_one_batch`` is falsy, the dataset slice
     ``self.dataset[self.pr:self.pr + self.step]`` is sampled in one call.

     Otherwise each uid in the current window of ``self.uid_list`` has its
     rows (``self.dataset[self.uid2index[uid]]``) sampled separately.  The
     results are joined with ``cat_interactions``.  The batch is then annotated
     via ``set_additional_info`` with ``self.uid2items_num[window]`` and that
     same value multiplied by ``self.times``, both converted to lists.

     Returns:
         The sampled batch.
     """
     start, stop = self.pr, self.pr + self.step
     if not self.user_inter_in_one_batch:
         batch = self._neg_sampling(self.dataset[start:stop])
         self.pr += self.step
         return batch

     window = self.uid_list[start:stop]
     batch = cat_interactions([
         self._neg_sampling(self.dataset[self.uid2index[user]])
         for user in window
     ])
     positives = self.uid2items_num[window]
     totals = positives * self.times
     batch.set_additional_info(list(positives), list(totals))
     self.pr += self.step
     return batch