def train_twostage(model_ts, train_instances, test_instances, features, algoname):
    """Train a link-prediction model over several graph instances.

    Each of the 150 iterations picks one training instance at random, applies
    edge dropout to that instance's training edges, builds a normalized
    adjacency from the surviving edges, and takes one Adam step on a
    BCE-with-logits loss over the edge pairs returned by ``negative_sample``.
    Every 10 iterations the model is evaluated on ``test_instances``, and the
    iteration, training loss, mean test cross-entropy and mean test ROC-AUC
    are printed.

    Relies on module-level globals: ``args`` (``lr``, ``weight_decay``,
    ``edge_dropout``, ``negsamplerate``), ``adj_train``, ``adj_all``,
    ``bin_adj_train``, ``bin_adj_all`` and ``aucs``. ``aucs[algoname]`` is an
    indexable container that receives each test instance's AUC, at that
    instance's position in ``test_instances``, as a side effect of
    evaluation.

    Args:
        model_ts: model called as ``model_ts(features, adj, edge_pairs)``;
            its output is fed to BCE-with-logits, so it is presumably one
            logit per edge pair.
        train_instances: list of instance keys used for training.
        test_instances: list of instance keys used for evaluation.
        features: mapping from instance key to that instance's feature input.
        algoname: key into ``aucs`` under which test AUCs are recorded.
    """
    optimizer_ts = optim.Adam(model_ts.parameters(),
                              lr=args.lr, weight_decay=args.weight_decay)
    bce = torch.nn.BCEWithLogitsLoss()

    # Per instance: the training edge list, plus a fixed evaluation set drawn
    # once from all of the instance's edges, so metrics are comparable
    # across iterations.
    edges = {}
    edges_eval = {}
    labels_eval = {}
    for i in train_instances + test_instances:
        edges[i] = adj_train[i].indices().t()
        edges_eval[i], labels_eval[i] = negative_sample(
            adj_all[i].indices().t(), 1, bin_adj_all[i])

    def get_evaluation(instances):
        """Return (mean BCE, mean ROC-AUC) of ``model_ts`` over ``instances``.

        Also records each instance's own AUC in ``aucs[algoname]``; every
        element of ``instances`` must therefore appear in ``test_instances``.
        """
        test_ce = 0
        test_auc = 0
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for i in instances:
                preds_test_eval = model_ts(features[i], adj_train[i], edges_eval[i])
                test_ce += bce(preds_test_eval, labels_eval[i])
                test_auc_i = sklearn.metrics.roc_auc_score(
                    labels_eval[i].long().detach().numpy(),
                    torch.sigmoid(preds_test_eval).detach().numpy())
                # Record this instance's own AUC, not the running total.
                aucs[algoname][test_instances.index(i)] = test_auc_i
                test_auc += test_auc_i
        return test_ce / len(instances), test_auc / len(instances)

    for t in range(150):
        i = np.random.choice(train_instances)
        adj_input = make_normalized_adj(edge_dropout(edges[i], args.edge_dropout),
                                        bin_adj_train[i].shape[0])
        edges_eval_i, labels_i = negative_sample(edges[i], args.negsamplerate,
                                                 bin_adj_train[i])
        preds = model_ts(features[i], adj_input, edges_eval_i)
        loss = bce(preds, labels_i)
        optimizer_ts.zero_grad()
        loss.backward()
        if t % 10 == 0:
            test_ce, test_auc = get_evaluation(test_instances)
            print(t, loss.item(), test_ce.item(), test_auc)
        optimizer_ts.step()
# Example no. 2
# 0
def train_twostage(model_ts):
    """Train ``model_ts`` for link prediction on the single training graph.

    Runs 300 Adam steps. Each step applies edge dropout (rate
    ``args.edge_dropout``) to the training edges, builds a normalized
    adjacency of size ``n`` from the surviving edges, and minimises
    BCE-with-logits over the edge pairs returned by ``negative_sample`` (rate
    ``args.negsamplerate``). Every 10 steps it prints the step, the training
    loss, and the cross-entropy and ROC-AUC on a fixed set of held-out test
    edge pairs.

    Relies on module-level globals: ``args``, ``adj_train``, ``adj_test``,
    ``bin_adj_train``, ``features_train`` and ``n``.
    """
    optimizer_ts = optim.Adam(model_ts.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
    bce = torch.nn.BCEWithLogitsLoss()
    edges = adj_train.indices().t()
    edges_test = adj_test.indices().t()
    # Evaluation pairs are drawn once so test metrics are comparable across
    # steps.
    edges_test_eval, labels_test_eval = negative_sample(
        edges_test, 1, bin_adj_train)
    # Adjacency used for evaluation: same construction as the training input,
    # but without the random edge dropout, so test scores are not noisy.
    adj_eval = make_normalized_adj(edges, n)
    for t in range(300):
        adj_input = make_normalized_adj(edge_dropout(edges, args.edge_dropout),
                                        n)
        edges_eval, labels = negative_sample(edges, args.negsamplerate,
                                             bin_adj_train)
        preds = model_ts(features_train, adj_input, edges_eval)
        loss = bce(preds, labels)
        optimizer_ts.zero_grad()
        loss.backward()
        if t % 10 == 0:
            # Evaluation only: no autograd graph is needed.
            with torch.no_grad():
                preds_test_eval = model_ts(features_train, adj_eval,
                                           edges_test_eval)
                test_ce = bce(preds_test_eval, labels_test_eval)
                test_auc = sklearn.metrics.roc_auc_score(
                    labels_test_eval.long().detach().numpy(),
                    torch.sigmoid(preds_test_eval).detach().numpy())
            print(t, loss.item(), test_ce.item(), test_auc)
        optimizer_ts.step()
# Example no. 3
# 0
    def generate_train_data(self, df):
        """Build (positive, three negatives, history) training inputs by city.

        ``df`` is processed city by city (``groupby("city")``). For every row
        the clicked item (``reference``) is the positive, and three negatives
        are drawn with ``negative_sample``:

        * ``negInt`` from the row's interacted items,
        * ``negCity`` from all items referenced in the same city,
        * ``neg`` from the row's impressions.

        A pool with at most one candidate falls back as follows: interactions
        and the city pool fall back to the impression pool; the impression
        pool falls back to the whole item catalogue.

        Returns:
            ``([pos, negInt, negCity, neg, seq], [labels])``, where ``seq`` is
            each row's interaction history padded to ``self.maxlen`` and
            ``labels`` is an all-ones array.
        """
        pos, neg, negInt, negCity, seq = [], [], [], [], []
        # Whole-catalogue fallback pool, built once instead of once per row.
        # NOTE(review): assumes negative_sample does not mutate its pool; the
        # original code already shared `interactions` between pool and `seq`.
        all_items = np.arange(len(self.item_index)).tolist()

        for _, rows in tqdm(df.groupby("city")):
            # Items clicked anywhere in this city. This is identical for every
            # row of the group, so it is computed once. int() matches how
            # 'reference' is converted before every other item_index lookup.
            city_items = None
            if rows['reference'].nunique() > 1:
                city_items = [self.item_index[int(i)]
                              for i in rows['reference'].unique().tolist()]

            for _, row in rows.iterrows():
                impressions = [
                    self.item_index[int(i)]
                    for i in row['impressions'].split("|")
                ]

                gtItem = self.item_index[int(row['reference'])]
                pos.append(gtItem)

                # Missing interactions are non-strings (e.g. NaN): empty history.
                interactions = [
                    self.item_index[int(i)]
                    for i in row['interactions'].split("|")
                ] if isinstance(row['interactions'], str) else []

                impPool = impressions if len(impressions) > 1 else all_items
                intPool = interactions if len(interactions) > 1 else impPool
                cityPool = city_items if city_items is not None else impPool

                # Same sampling order as before: interactions, city, impressions.
                negInt.append(negative_sample(intPool, gtItem))
                negCity.append(negative_sample(cityPool, gtItem))
                neg.append(negative_sample(impPool, gtItem))

                seq.append(interactions)

        pos = np.array(pos)
        negInt = np.array(negInt)
        negCity = np.array(negCity)
        neg = np.array(neg)
        seq = pad_sequences(seq, maxlen=self.maxlen)
        labels = np.ones(len(pos))

        return [pos, negInt, negCity, neg, seq], [labels]
def train_twostage(model_ts, *, lr=0.005, weight_decay=5e-4,
                   edge_dropout_rate=0.2, neg_sample_rate=1, num_iters=300):
    """Train ``model_ts`` for link prediction on the training graph.

    Each iteration applies edge dropout to the training edges, builds a
    normalized adjacency of size ``n`` from the surviving edges, and takes one
    Adam step on BCE-with-logits over the edge pairs returned by
    ``negative_sample``. The model's parameters are updated in place; nothing
    is returned.

    Relies on module-level globals ``adj_train``, ``bin_adj_train``,
    ``features_train`` and ``n``.

    Args:
        model_ts: model called as ``model_ts(features, adj, edge_pairs)``.
        lr: Adam learning rate (default 0.005).
        weight_decay: Adam weight decay (default 5e-4).
        edge_dropout_rate: rate passed to ``edge_dropout`` every iteration
            (default 0.2). It is deliberately not named ``edge_dropout``, so
            the module-level function of that name is not shadowed.
        neg_sample_rate: rate passed to ``negative_sample`` (default 1).
        num_iters: number of optimisation steps (default 300).
    """
    optimizer_ts = optim.Adam(model_ts.parameters(),
                              lr=lr,
                              weight_decay=weight_decay)
    bce = torch.nn.BCEWithLogitsLoss()
    edges = adj_train.indices().t()
    for _ in range(num_iters):
        adj_input = make_normalized_adj(edge_dropout(edges, edge_dropout_rate), n)
        edges_eval, labels = negative_sample(edges, neg_sample_rate, bin_adj_train)
        preds = model_ts(features_train, adj_input, edges_eval)
        loss = bce(preds, labels)
        optimizer_ts.zero_grad()
        loss.backward()
        optimizer_ts.step()
# Example no. 5
# 0
    def generate_dynamic_data(self, df):
        """Yield one ``(x, y)`` training batch per eligible session, forever.

        Cycles over ``df`` grouped by ``session_id`` indefinitely. Sessions
        whose last row is not a ``"clickout item"`` are skipped. For every
        other session the rows are walked in order, and consecutive rows that
        reference the same item are collapsed. Each kept row extends a running
        sequence of (item, action, position, price[, time, step]) features,
        and every prefix is emitted as a positive example. For every prefix a
        negative twin is also emitted: its last item, position and price are
        replaced by an item drawn by ``negative_sample`` from the last row's
        impressions (or from the whole catalogue when only one impression was
        shown), with features computed against the last row.

        Yields:
            ``(x, y)`` where ``x`` is
            ``[seq_items, seq_actions, seq_positions, seq_prices, seq_times,
            seq_steps, nseq_items, nseq_actions, nseq_positions, nseq_prices,
            nseq_times, nseq_steps]``, each padded to ``self.maxlen``. Prices,
            times and steps get a trailing unit axis, and positions do too when
            ``self.positionMode == 1``. ``y`` is all ones, with one entry per
            prefix. Times and steps are only filled when
            ``self.timeMode == 2``.

        Raises:
            ValueError: if a full pass over ``df`` finds no usable session,
                which would otherwise make the generator loop forever without
                yielding.
        """
        def _pad(seqs):
            return pad_sequences(seqs, maxlen=self.maxlen)

        while True:
            produced = False
            for _, rows in df.groupby("session_id"):
                lastRow = rows.iloc[-1]

                # Only sessions ending in a clickout have a ground-truth item.
                if lastRow["action_type"] != "clickout item":
                    continue

                seq_items, seq_actions, seq_prices, seq_positions = [], [], [], []
                nseq_items, nseq_actions, nseq_prices, nseq_positions = [], [], [], []
                seq_times, seq_steps = [], []
                nseq_times, nseq_steps = [], []
                seq_item, seq_action, seq_price, seq_position, seq_time, seq_step = [], [], [], [], [], []

                impressions = [
                    self.item_index[int(i)]
                    for i in lastRow['impressions'].split("|")
                ]
                prices = [int(i) for i in lastRow['prices'].split("|")]

                # Session start; must stay a timestamp for the time feature.
                firstTime = rows.iloc[0]['timestamp']
                gtItem = self.item_index[int(lastRow['reference'])]

                # Negatives come from the shown impressions, or from the whole
                # catalogue when only one impression was shown.
                pool = impressions if len(impressions) > 1 else np.arange(
                    len(self.item_index)).tolist()

                # Last kept item, used to collapse consecutive repeats. This is
                # a separate variable so firstTime is never overwritten.
                prevItem = -1
                for _, _r in rows.iterrows():
                    _item = self.item_index[int(_r['reference'])]
                    if _item == prevItem:
                        continue
                    prevItem = _item

                    _action, _position, _price = self.get_features(
                        _item, impressions, prices, _r)

                    if self.timeMode == 2:
                        # Elapsed time since session start divided by 60
                        # (presumably seconds -> minutes; confirm timestamp unit).
                        seq_time.append(int(
                            (_r['timestamp'] - firstTime) / 60))
                        seq_step.append(int(_r['step']))

                    seq_item.append(_item)
                    seq_position.append(_position)
                    seq_action.append(_action)
                    seq_price.append(_price)

                    # Snapshot the current prefix as a positive example.
                    seq_items.append(seq_item[:])
                    seq_positions.append(seq_position[:])
                    seq_actions.append(seq_action[:])
                    seq_prices.append(seq_price[:])
                    seq_times.append(seq_time[:])
                    seq_steps.append(seq_step[:])

                    sample = negative_sample(pool, gtItem)
                    _, position, price = self.get_features(
                        sample, impressions, prices, lastRow)

                    # Negative twin: same prefix, with the last item's
                    # item/position/price replaced by the sampled negative's.
                    nseq_item = seq_item[:]
                    nseq_item[-1] = sample
                    nseq_position = seq_position[:]
                    nseq_position[-1] = position
                    nseq_price = seq_price[:]
                    nseq_price[-1] = price

                    nseq_items.append(nseq_item)
                    nseq_positions.append(nseq_position)
                    nseq_actions.append(seq_action[:])
                    nseq_prices.append(nseq_price)
                    nseq_times.append(seq_time[:])
                    nseq_steps.append(seq_step[:])

                seq_items = _pad(seq_items)
                seq_positions = _pad(seq_positions)
                seq_actions = _pad(seq_actions)
                seq_prices = _pad(seq_prices)
                seq_times = _pad(seq_times)
                seq_steps = _pad(seq_steps)

                nseq_items = _pad(nseq_items)
                nseq_positions = _pad(nseq_positions)
                nseq_actions = _pad(nseq_actions)
                nseq_prices = _pad(nseq_prices)
                nseq_times = _pad(nseq_times)
                nseq_steps = _pad(nseq_steps)

                if self.positionMode == 1:
                    seq_positions = np.expand_dims(seq_positions, axis=-1)
                    nseq_positions = np.expand_dims(nseq_positions, axis=-1)

                seq_prices = np.expand_dims(seq_prices, axis=-1)
                nseq_prices = np.expand_dims(nseq_prices, axis=-1)
                seq_times = np.expand_dims(seq_times, axis=-1)
                nseq_times = np.expand_dims(nseq_times, axis=-1)
                seq_steps = np.expand_dims(seq_steps, axis=-1)
                nseq_steps = np.expand_dims(nseq_steps, axis=-1)

                x = [
                    seq_items, seq_actions, seq_positions, seq_prices,
                    seq_times, seq_steps, nseq_items, nseq_actions,
                    nseq_positions, nseq_prices, nseq_times, nseq_steps
                ]
                y = np.ones(seq_items.shape[0])
                produced = True
                yield (x, y)

            if not produced:
                raise ValueError(
                    "generate_dynamic_data: no session in df ends with a "
                    "'clickout item' action")
# Example no. 6
# 0
    def generate_data(self, df, mode="train"):
        """Build padded sequence model inputs from session logs.

        Sessions are ``df`` grouped by ``session_id``. A session is used only
        if its last row's ``action_type`` is ``"clickout item"``. In ``"val"``
        mode, sessions whose last ``reference`` is a Python float (presumably
        NaN, i.e. a hidden click) are skipped; in ``"test"`` mode only those
        sessions are kept.

        ``"train"``: every row of the session extends a running sequence of
        (item, action, position, price[, time, step]) features, and each
        prefix is emitted as one positive example. For each prefix a negative
        twin is emitted whose last item, position and price are replaced by an
        item drawn by ``negative_sample`` from the last row's impressions (or
        from the whole catalogue when only one impression was shown), with
        features computed against the last row.

        Any other mode: the full session sequence is built once. Then, for
        each impression of the last row, a copy is emitted whose last item,
        position (1-based) and price are replaced by that impression's. In
        ``"val"`` the label is 1 for the clicked impression, else 0.

        All sequences are padded to ``self.maxlen`` with ``pad_sequences``.
        Prices, times and steps get a trailing unit axis, and positions do too
        when ``self.positionMode == 1``. Times and steps are filled only when
        ``self.timeMode == 2``, and they are part of ``x`` whenever
        ``self.timeMode != 1``.

        Returns:
            ``"train"``: ``(x, y)``. ``x`` lists the positive sequence arrays
            followed by the negative ones; ``y`` is all ones.
            ``"val"``: ``(sessions, x, y)``, with one session id and one label
            per candidate row.
            otherwise: ``(sessions, itemIds, x)``, where ``itemIds`` are the
            raw impression id strings of the last row.
        """

        sessions, itemIds, seq_items, seq_actions, seq_prices, seq_positions, labels = [], [], [], [], [], [], []
        # Negative-twin sequences; only filled in "train" mode.
        nseq_items, nseq_actions, nseq_prices, nseq_positions = [], [], [], []
        seq_times, seq_steps = [], []
        nseq_times, nseq_steps = [], []

        for idx, rows in tqdm(df.groupby("session_id")):
            # Running per-session feature sequences.
            seq_item, seq_action, seq_price, seq_position, seq_time, seq_step = [], [], [], [], [], []
            lastRow = rows.iloc[-1]

            # Only sessions that end in a clickout have a target to rank.
            if lastRow["action_type"] != "clickout item":
                continue

            # A float reference (presumably NaN) marks a hidden click:
            # "val" keeps sessions with a known click, "test" only hidden ones.
            if mode == "val":
                if type(lastRow['reference']) == float:
                    continue
            elif mode == "test":
                if type(lastRow['reference']) != float:
                    continue

            impressions = [
                self.item_index[int(i)]
                for i in lastRow['impressions'].split("|")
            ]
            prices = [int(i) for i in lastRow['prices'].split("|")]

            if mode == "train":

                firstTime = rows.iloc[0]['timestamp']
                # print((row['timestamp'] - firstTime).total_seconds())

                gtItem = self.item_index[int(lastRow['reference'])]
                # NOTE(review): assumes every row of a training session has a
                # numeric item reference (int() is applied unconditionally).
                for _i, _r in rows.iterrows():
                    _item = self.item_index[int(_r['reference'])]
                    _action, _position, _price = self.get_features(
                        _item, impressions, prices, _r)

                    seq_item.append(_item)
                    seq_position.append(_position)
                    seq_action.append(_action)
                    seq_price.append(_price)
                    if self.timeMode == 2:
                        # Elapsed time since session start divided by 60
                        # (presumably seconds -> minutes).
                        seq_time.append(int(
                            (_r['timestamp'] - firstTime) / 60))
                        seq_step.append(int(_r['step']))

                    # Snapshot the current prefix as one positive example.
                    seq_items.append(seq_item[:])
                    seq_positions.append(seq_position[:])
                    seq_actions.append(seq_action[:])
                    seq_prices.append(seq_price[:])
                    seq_times.append(seq_time[:])
                    seq_steps.append(seq_step[:])

                    # sample negative instance from impressions
                    pool = impressions if len(impressions) > 1 else np.arange(
                        len(self.item_index)).tolist()
                    sample = negative_sample(pool, gtItem)

                    # Features of the negative are taken relative to the last row.
                    action, position, price = self.get_features(
                        sample, impressions, prices, lastRow)

                    # clone seq inputs and change last element to from negative instance
                    nseq_item = seq_item[:]
                    nseq_item[-1] = sample
                    nseq_position = seq_position[:]
                    nseq_position[-1] = position
                    nseq_action = seq_action[:]
                    nseq_price = seq_price[:]
                    nseq_price[-1] = price
                    nseq_time = seq_time[:]
                    nseq_step = seq_step[:]

                    nseq_items.append(nseq_item[:])
                    nseq_positions.append(nseq_position[:])
                    nseq_actions.append(nseq_action[:])
                    nseq_prices.append(nseq_price[:])
                    nseq_times.append(nseq_time[:])
                    nseq_steps.append(nseq_step[:])

            else:

                # Build the whole session sequence once (val/test).
                firstTime = rows.iloc[0]['timestamp']
                for _i, _r in rows.iterrows():
                    # Non-string references (e.g. NaN on the hidden clickout)
                    # map to item 0 (presumably the padding index - confirm).
                    _item = self.item_index[int(_r['reference'])] if type(
                        _r['reference']) == str else 0
                    _action, _position, _price = self.get_features(
                        _item, impressions, prices, _r)
                    seq_item.append(_item)
                    seq_position.append(_position)
                    seq_action.append(_action)
                    seq_price.append(_price)
                    if self.timeMode == 2:
                        seq_time.append(int(
                            (_r['timestamp'] - firstTime) / 60))
                        seq_step.append(int(_r['step']))

                if mode == "val":
                    gtItem = self.item_index[int(lastRow['reference'])]

                # One candidate row per impression: swap the final step's
                # item, 1-based position and price for the impression's.
                for position, (item,
                               price) in enumerate(zip(impressions, prices)):

                    _seq_item = seq_item[:]
                    _seq_item[-1] = item
                    _seq_position = seq_position[:]
                    _seq_position[-1] = position + 1
                    _seq_action = seq_action[:]
                    _seq_price = seq_price[:]
                    _seq_price[-1] = price

                    seq_items.append(_seq_item)
                    seq_positions.append(_seq_position)
                    seq_actions.append(_seq_action)
                    seq_prices.append(_seq_price)

                    if self.timeMode == 2:
                        # Same (uncopied) list shared by every candidate row;
                        # it is not modified afterwards.
                        seq_times.append(seq_time)
                        seq_steps.append(seq_step)

                    if mode == "val":
                        labels.append(1 if item == gtItem else 0)

                sessions.extend([lastRow['session_id']] * len(impressions))
                itemIds.extend([i for i in lastRow['impressions'].split("|")])

        # Pad every sequence collection to self.maxlen (negatives stay empty
        # outside "train").
        seq_items = pad_sequences(seq_items, maxlen=self.maxlen)
        seq_positions = pad_sequences(seq_positions, maxlen=self.maxlen)
        seq_actions = pad_sequences(seq_actions, maxlen=self.maxlen)
        seq_prices = pad_sequences(seq_prices, maxlen=self.maxlen)
        seq_times = pad_sequences(seq_times, maxlen=self.maxlen)
        seq_steps = pad_sequences(seq_steps, maxlen=self.maxlen)

        nseq_items = pad_sequences(nseq_items, maxlen=self.maxlen)
        nseq_positions = pad_sequences(nseq_positions, maxlen=self.maxlen)
        nseq_actions = pad_sequences(nseq_actions, maxlen=self.maxlen)
        nseq_prices = pad_sequences(nseq_prices, maxlen=self.maxlen)
        nseq_times = pad_sequences(nseq_times, maxlen=self.maxlen)
        nseq_steps = pad_sequences(nseq_steps, maxlen=self.maxlen)

        # positionMode 1 feeds positions as a numeric feature with a unit axis.
        if self.positionMode == 1:
            seq_positions = np.expand_dims(seq_positions, axis=-1)
            nseq_positions = np.expand_dims(nseq_positions, axis=-1)

        seq_prices = np.expand_dims(seq_prices, axis=-1)
        nseq_prices = np.expand_dims(nseq_prices, axis=-1)
        seq_times = np.expand_dims(seq_times, axis=-1)
        nseq_times = np.expand_dims(nseq_times, axis=-1)
        seq_steps = np.expand_dims(seq_steps, axis=-1)
        nseq_steps = np.expand_dims(nseq_steps, axis=-1)

        if mode == "train":
            if self.timeMode == 1:
                x = [
                    seq_items, seq_actions, seq_positions, seq_prices,
                    nseq_items, nseq_actions, nseq_positions, nseq_prices
                ]
            else:
                x = [
                    seq_items, seq_actions, seq_positions, seq_prices,
                    seq_times, seq_steps, nseq_items, nseq_actions,
                    nseq_positions, nseq_prices, nseq_times, nseq_steps
                ]
            # for i in x:
            #     print(i)
            y = np.ones(seq_items.shape[0])
            return x, y
        elif mode == "val":
            if self.timeMode == 1:
                x = [seq_items, seq_actions, seq_positions, seq_prices]
            else:
                x = [
                    seq_items, seq_actions, seq_positions, seq_prices,
                    seq_times, seq_steps
                ]
            y = np.array(labels)
            return sessions, x, y
        else:
            if self.timeMode == 1:
                x = [seq_items, seq_actions, seq_positions, seq_prices]
            else:
                x = [
                    seq_items, seq_actions, seq_positions, seq_prices,
                    seq_times, seq_steps
                ]
            return sessions, itemIds, x
# Example no. 7
# 0
    def generate_train_data(self, df):
        """Build (positive, negative, history) training inputs grouped by city.

        ``df`` is processed city by city (``groupby("city")``). For every row
        the clicked item (``reference``) is the positive, and one negative is
        drawn with ``negative_sample`` from a pool chosen by
        ``self.negSampleMode``:

        * ``"city"``: all distinct items referenced in the same city when
          there are at least two, otherwise the row's impressions;
        * ``"imp"``: the row's impressions;
        * ``"nce"``: the single impression, other than the clicked item, that
          ``self.predict`` scores highest given the row's interaction history.

        An impression pool with exactly one entry falls back to the whole item
        catalogue.

        Returns:
            ``([pos, neg, seq], [labels])``, where ``seq`` is each row's
            interaction history padded to ``self.maxlen`` and ``labels`` is an
            all-ones array.

        Raises:
            ValueError: if ``self.negSampleMode`` is not ``"city"``, ``"imp"``
                or ``"nce"``.
        """
        if self.negSampleMode not in ("city", "imp", "nce"):
            raise ValueError("unknown negSampleMode: %r" % (self.negSampleMode,))

        pos, neg, seq = [], [], []
        # Whole-catalogue fallback pool, built once instead of once per row.
        # NOTE(review): assumes negative_sample does not mutate its pool.
        all_items = np.arange(len(self.item_index)).tolist()

        for _, rows in tqdm(df.groupby("city")):
            # Items clicked anywhere in this city; identical for every row of
            # the group. int() matches every other item_index lookup.
            city_items = None
            if self.negSampleMode == "city" and rows['reference'].nunique() > 1:
                city_items = [self.item_index[int(i)]
                              for i in rows['reference'].unique().tolist()]

            for _, row in rows.iterrows():
                impressions = [self.item_index[int(i)] for i in row['impressions'].split("|")]

                gtItem = self.item_index[int(row['reference'])]
                pos.append(gtItem)

                # Missing interactions are non-strings (e.g. NaN): empty history.
                interactions = [self.item_index[int(i)] for i in row['interactions'].split("|")] \
                    if isinstance(row['interactions'], str) else []

                imp_pool = all_items if len(impressions) == 1 else impressions

                if self.negSampleMode == "city":
                    pool = city_items if city_items is not None else imp_pool
                elif self.negSampleMode == "imp" or len(impressions) == 1:
                    pool = imp_pool
                else:
                    # "nce": score every other impression against this
                    # history and keep only the hardest (highest-scoring) one.
                    candidates = list(impressions)
                    if gtItem in candidates:
                        candidates.remove(gtItem)
                    tmpSeq = pad_sequences([interactions] * len(candidates),
                                           maxlen=self.maxlen)
                    pred = self.predict([np.expand_dims(candidates, axis=-1), tmpSeq])
                    # NOTE(review): assumes self.predict returns one score per
                    # candidate. Ties go to the larger item index, as when
                    # sorting (score, item) pairs. The pool element stays a
                    # plain scalar item, not a length-1 array.
                    scores = np.ravel(pred).tolist()
                    pool = [max(zip(scores, candidates))[1]]

                neg.append(negative_sample(pool, gtItem))
                seq.append(interactions)

        pos = np.array(pos)
        neg = np.array(neg)
        seq = pad_sequences(seq, maxlen=self.maxlen)
        labels = np.ones(len(pos))

        return [pos, neg, seq], [labels]
    def generate_data(self, df, mode="train"):
        """Build candidate-feature plus history-sequence model inputs.

        Sessions are ``df`` grouped by ``session_id``. A session is used only
        when its last row is a ``"clickout item"``. In ``"val"`` mode, sessions
        whose last ``reference`` is a Python float (presumably NaN, i.e. an
        unknown click) are skipped; in ``"test"`` mode only those sessions are
        kept. All rows before the last one form the history sequence.

        * ``"train"``: two examples per session that share the same history.
          One is the clicked item (label 1). The other is an item drawn by
          ``negative_sample`` from the impressions, or from all items when only
          one impression was shown (label 0).
        * ``"val"`` / ``"test"``: one example per impression of the last row.
          In ``"val"`` the label is 1 for the clicked impression, else 0.

        ``x`` is assembled in the order given by
        ``self.feature_dim_dict["sparse"]``, then
        ``self.feature_dim_dict["dense"]``, then the ``seq_``-prefixed history
        features named in ``self.behavior_feature_list``. A final array of
        example indices ``0..len-1`` is appended.

        Returns:
            ``(x, y)`` for ``"train"``, ``(sessions, x, y)`` for ``"val"``,
            and ``(sessions, itemIds, x)`` otherwise.
        """
        sessions, itemIds, labels = [], [], []
        cand_items, cand_actions, cand_cities, cand_positions, cand_prices = [], [], [], [], []
        seq_items, seq_cities, seq_actions, seq_prices, seq_positions = [], [], [], [], []

        for _, rows in df.groupby("session_id"):
            lastRow = rows.iloc[-1]

            if lastRow["action_type"] != "clickout item":
                continue

            # A float reference (presumably NaN) marks an unknown click:
            # "val" needs a known click, "test" keeps only unknown ones.
            if mode == "val":
                if type(lastRow['reference']) == float:
                    continue
            elif mode == "test":
                if type(lastRow['reference']) != float:
                    continue

            impressions = [self.item_index[int(i)] for i in lastRow['impressions'].split("|")]
            prices = [int(i) for i in lastRow['prices'].split("|")]
            if mode == "train":
                gtItem = self.item_index[int(lastRow['reference'])]
                action, city, position, price = self.get_features(gtItem, impressions, prices, lastRow)

            # History: every row before the final clickout. These lists are
            # shared (not copied) by all examples of this session and are not
            # modified after this loop.
            seq_item, seq_city, seq_action, seq_price, seq_position = [], [], [], [], []
            for _, _r in rows.iloc[:-1].iterrows():
                _item = self.item_index[int(_r['reference'])]

                _action, _city, _position, _price = self.get_features(_item, impressions, prices, _r)

                seq_item.append(_item)
                seq_position.append(_position)
                seq_city.append(_city)
                seq_action.append(_action)
                seq_price.append(_price)

            if mode == "train":
                # Positive example: the clicked item with its own features,
                # so that cand_* stays aligned with seq_* and labels.
                cand_items.append(gtItem)
                cand_actions.append(action)
                cand_cities.append(city)
                cand_positions.append(position)
                cand_prices.append(price)
                seq_items.append(seq_item)
                seq_positions.append(seq_position)
                seq_cities.append(seq_city)
                seq_actions.append(seq_action)
                seq_prices.append(seq_price)
                labels.append(1)

                # Negative example sampled from the impressions.
                pool = impressions if len(impressions) > 1 else np.arange(len(self.item_index)).tolist()
                sample = negative_sample(pool, gtItem)
                action, city, position, price = self.get_features(sample, impressions, prices, lastRow)

                cand_items.append(sample)
                cand_actions.append(action)
                cand_cities.append(city)
                cand_positions.append(position)
                cand_prices.append(price)
                seq_items.append(seq_item)
                seq_positions.append(seq_position)
                seq_cities.append(seq_city)
                seq_actions.append(seq_action)
                seq_prices.append(seq_price)
                labels.append(0)

            else:
                _action = self.action_index[lastRow['action_type']]
                _city = self.city_index[lastRow['city']]
                if mode == "val":
                    gtItem = self.item_index[int(lastRow['reference'])]

                # One candidate per impression, with a 1-based display position.
                for _position, (_item, _price) in enumerate(zip(impressions, prices)):
                    cand_items.append(_item)
                    cand_actions.append(_action)
                    cand_cities.append(_city)
                    cand_positions.append(_position + 1)
                    cand_prices.append(_price)

                    seq_items.append(seq_item)
                    seq_positions.append(seq_position)
                    seq_cities.append(seq_city)
                    seq_actions.append(seq_action)
                    seq_prices.append(seq_price)

                    if mode == "val":
                        labels.append(1 if _item == gtItem else 0)

                sessions.extend([lastRow['session_id']] * len(impressions))
                itemIds.extend(lastRow['impressions'].split("|"))

        cand_items = np.array(cand_items)
        cand_positions = np.array(cand_positions)
        cand_cities = np.array(cand_cities)
        cand_actions = np.array(cand_actions)
        cand_prices = np.array(cand_prices)
        seq_items = pad_sequences(seq_items, maxlen=self.maxlen)
        seq_positions = pad_sequences(seq_positions, maxlen=self.maxlen)
        seq_cities = pad_sequences(seq_cities, maxlen=self.maxlen)
        seq_actions = pad_sequences(seq_actions, maxlen=self.maxlen)
        seq_prices = pad_sequences(seq_prices, maxlen=self.maxlen)
        labels = np.array(labels)

        feature_dict = {'item': cand_items, 'position': cand_positions, 'city': cand_cities,
                        'action': cand_actions, 'price': cand_prices,
                        'seq_item': seq_items, 'seq_position': seq_positions, 'seq_city': seq_cities,
                        'seq_action': seq_actions, 'seq_price': seq_prices}

        x = ([feature_dict[feat.name] for feat in self.feature_dim_dict["sparse"]]
             + [feature_dict[feat.name] for feat in self.feature_dim_dict["dense"]]
             + [feature_dict['seq_' + feat] for feat in self.behavior_feature_list])

        # Trailing input: the row index of every example.
        x += [np.arange(len(cand_items))]

        y = labels

        if mode == "train":
            return x, y
        elif mode == "val":
            return sessions, x, y
        else:
            return sessions, itemIds, x