def _next_batch_data(self):
    """Build the next batch and advance the read pointer ``self.pr``.

    Returns:
        tuple: ``(cur_data, idx_list, positive_u, positive_i)``.

        When ``self.neg_sample_args['strategy'] == 'by'`` the batch covers
        the next ``self.step`` entries of ``self.uid_list``:

        * ``cur_data`` -- the per-user results of ``self._neg_sampling``
          joined with ``cat_interactions``.
        * ``idx_list`` -- tensor holding each user's position inside this
          batch, repeated ``self.uid2items_num[uid] * self.times`` times.
        * ``positive_u`` -- tensor holding each user's position inside
          this batch, repeated ``self.uid2items_num[uid]`` times.
        * ``positive_i`` -- concatenation of every user's
          ``self.dataset[index][self.iid_field]`` values.

        For any other strategy the next ``self.step`` rows of
        ``self.dataset`` are passed through ``self._neg_sampling`` and the
        three trailing entries are ``None``.
    """
    if self.neg_sample_args['strategy'] != 'by':
        cur_data = self._neg_sampling(self.dataset[self.pr:self.pr + self.step])
        self.pr += self.step
        return cur_data, None, None, None

    uid_list = self.uid_list[self.pr:self.pr + self.step]
    data_list = []
    idx_list = []
    positive_u = []
    # Collect the per-user item tensors and concatenate once after the loop;
    # calling torch.cat per user re-copies everything gathered so far.
    # The empty int64 seed keeps the result identical to the incremental
    # version, including when no user is visited.
    positive_i_parts = [torch.tensor([], dtype=torch.int64)]
    for idx, uid in enumerate(uid_list):
        index = self.uid2index[uid]
        data_list.append(self._neg_sampling(self.dataset[index]))
        # int() so list repetition also works when the count is a numpy or
        # tensor scalar rather than a plain Python int.
        item_num = int(self.uid2items_num[uid])
        idx_list.extend([idx] * (item_num * self.times))
        positive_u.extend([idx] * item_num)
        positive_i_parts.append(self.dataset[index][self.iid_field])
    positive_i = torch.cat(positive_i_parts, 0)

    cur_data = cat_interactions(data_list)
    idx_list = torch.from_numpy(np.array(idx_list))
    positive_u = torch.from_numpy(np.array(positive_u))
    self.pr += self.step
    return cur_data, idx_list, positive_u, positive_i
def _next_batch_data(self):
    """Assemble the batch for the next ``self.step`` users and advance ``self.pr``.

    Each user id in the current window of ``self.uid_list`` is mapped through
    ``self.uid2index`` to select that user's rows of ``self.dataset``; the
    per-user pieces are joined with ``cat_interactions``. The matching entries
    of ``self.uid2items_num_pos`` and ``self.uid2items_num`` are attached to
    the batch via ``set_additional_info`` (as plain lists) before it is
    returned.
    """
    batch_uids = self.uid_list[self.pr:self.pr + self.step]
    per_user = [self.dataset[self.uid2index[uid]] for uid in batch_uids]
    cur_data = cat_interactions(per_user)
    # NOTE(review): indexing with the whole uid window presumably relies on
    # these being array-like (fancy indexing) -- confirm against __init__.
    cur_data.set_additional_info(
        list(self.uid2items_num_pos[batch_uids]),
        list(self.uid2items_num[batch_uids]),
    )
    self.pr += self.step
    return cur_data
def _neg_sampling(self, data):
    """Draw negative items for ``data`` and combine them via ``self.sampling_func``.

    Args:
        data: interaction data indexable by ``self.uid_field`` and by slices.

    Returns:
        The result of ``self.sampling_func``. When
        ``self.user_inter_in_one_batch`` is set, sampling is done one row at
        a time and the per-row results are joined with ``cat_interactions``;
        otherwise all rows are sampled in a single call.
    """
    if not self.user_inter_in_one_batch:
        user_ids = data[self.uid_field]
        sampled = self.sampler.sample_by_user_ids(user_ids, self.neg_sample_by)
        return self.sampling_func(data, sampled)

    row_results = []
    for row in range(len(data[self.uid_field])):
        # A length-1 slice (not a bare index) is used for both the user ids
        # and the row itself, exactly as a ``row:row + 1`` subscript would be.
        single = slice(row, row + 1)
        row_uids = data[self.uid_field][single]
        row_negatives = self.sampler.sample_by_user_ids(row_uids, self.neg_sample_by)
        row_results.append(self.sampling_func(data[single], row_negatives))
    return cat_interactions(row_results)
def _neg_sample_by_point_wise_sampling(self, inter_feat, pos_idx, neg_iids, neg_idx, pos_iids):
    """Build labelled point-wise samples from positive and negative rows.

    Args:
        inter_feat: interaction data indexable by ``pos_idx`` / ``neg_idx``.
        pos_idx: index whose first element has one entry per positive row
            (presumably a tuple as returned by ``torch.where`` -- confirm).
        neg_iids: item ids written over the repeated positive rows.
        neg_idx: same layout as ``pos_idx`` for the negative rows.
        pos_iids: item ids written over the repeated negative rows.

    Returns:
        The ``cat_interactions`` result of the positive block and, when at
        least one negative row exists, the negative block.

    The positive rows are repeated ``self.neg_sample_by + 1`` times; every
    row after the first ``len(pos_idx[0])`` gets its item id replaced by
    ``neg_iids`` and label 0, while the leading rows keep label 1. The
    negative rows are handled symmetrically with
    ``round(self.neg_sample_by / 3) + 1`` repeats, ``pos_iids`` as the
    replacement ids, and the labels inverted (0 for the leading rows, 1 after).
    """
    n_pos = len(pos_idx[0])
    n_neg = len(neg_idx[0])

    pos_repeats = self.neg_sample_by + 1
    pos_block = inter_feat[pos_idx].repeat(pos_repeats)
    # Rows past the first copy are overwritten in place with sampled items.
    pos_block[self.iid_field][n_pos:] = neg_iids
    pos_block = self.dataset.join(pos_block)
    pos_labels = torch.zeros(n_pos * pos_repeats)
    pos_labels[:n_pos] = 1
    pos_block.update(Interaction({self.label_field: pos_labels}))

    blocks = [pos_block]
    if n_neg > 0:
        neg_repeats = round(self.neg_sample_by / 3) + 1
        neg_block = inter_feat[neg_idx].repeat(neg_repeats)
        neg_block[self.iid_field][n_neg:] = pos_iids
        neg_block = self.dataset.join(neg_block)
        neg_labels = torch.ones(n_neg * neg_repeats)
        neg_labels[:n_neg] = 0
        neg_block.update(Interaction({self.label_field: neg_labels}))
        blocks.append(neg_block)

    return cat_interactions(blocks)
def _next_batch_data(self):
    """Produce the next negative-sampled batch and advance ``self.pr``.

    With ``self.user_inter_in_one_batch`` set, the batch is built user by
    user: each of the next ``self.step`` ids in ``self.uid_list`` has its
    rows of ``self.dataset`` passed through ``self._neg_sampling``, the
    results are joined with ``cat_interactions``, and the per-user counts
    from ``self.uid2items_num`` (plus those counts multiplied by
    ``self.times``) are attached through ``set_additional_info``.

    Otherwise the next ``self.step`` rows of ``self.dataset`` are sampled in
    one call.
    """
    if self.user_inter_in_one_batch:
        batch_uids = self.uid_list[self.pr:self.pr + self.step]
        sampled = [
            self._neg_sampling(self.dataset[self.uid2index[uid]])
            for uid in batch_uids
        ]
        cur_data = cat_interactions(sampled)
        pos_lens = self.uid2items_num[batch_uids]
        cur_data.set_additional_info(list(pos_lens), list(pos_lens * self.times))
    else:
        cur_data = self._neg_sampling(self.dataset[self.pr:self.pr + self.step])

    # The pointer moves only after the batch was built successfully.
    self.pr += self.step
    return cur_data