示例#1
0
def TestBGCF(forward_net, num_user, num_item, input_dim, test_graph_dataset,
             num_samples=50):
    """Average user/item representations over several sampled graphs.

    For each of ``num_samples`` rounds, a sampled neighbourhood graph is
    selected in ``test_graph_dataset``. All users and items are then run
    through ``forward_net``, and the resulting representations are summed.
    The per-round mean is returned.

    Args:
        forward_net: Callable taking ``(users, items, neg_items, u_neighs,
            u_gnew_neighs, i_neighs, i_gnew_neighs)``. It returns a
            ``(user_rep, item_rep)`` pair whose elements expose ``asnumpy()``.
        num_user: Number of user nodes.
        num_item: Number of item nodes.
        input_dim: Base dimension. The accumulators have ``input_dim * 3``
            columns, so the representations are expected to match that width.
        test_graph_dataset: Object providing ``random_select_sampled_graph()``,
            ``get_user_sapmled_neighbor()`` and ``get_item_sampled_neighbor()``.
        num_samples: Number of sampled graphs to average over. Defaults to
            50, the value previously hard-coded.

    Returns:
        tuple: ``(user_reps, item_reps)`` NumPy arrays of shape
        ``(num_user, input_dim * 3)`` and ``(num_item, input_dim * 3)``.

    Raises:
        ValueError: If ``num_samples`` is less than 1 (the average would
            otherwise be a division by zero).
    """
    if num_samples < 1:
        raise ValueError(
            "num_samples must be a positive integer, got {}".format(num_samples))

    user_reps = np.zeros([num_user, input_dim * 3])
    item_reps = np.zeros([num_item, input_dim * 3])

    # The node-id inputs do not depend on the sampled graph, so build them
    # once instead of once per sampling round.
    users = Tensor(np.arange(num_user), mstype.int32)
    items = Tensor(np.arange(num_item), mstype.int32)
    # neg_items is passed as a column vector (one row per item).
    neg_items = Tensor(np.arange(num_item).reshape(-1, 1), mstype.int32)

    for _ in range(num_samples):
        test_graph_dataset.random_select_sampled_graph()
        # NOTE: "sapmled" is the dataset API's own spelling; keep it.
        u_test_neighs, u_test_gnew_neighs = \
            test_graph_dataset.get_user_sapmled_neighbor()
        i_test_neighs, i_test_gnew_neighs = \
            test_graph_dataset.get_item_sampled_neighbor()

        # Only the user-side neighbour ids go through convert_item_id
        # (presumably because they are item nodes); item-side ids are
        # used unchanged.
        u_test_neighs = Tensor(convert_item_id(u_test_neighs, num_user),
                               mstype.int32)
        u_test_gnew_neighs = Tensor(
            convert_item_id(u_test_gnew_neighs, num_user), mstype.int32)
        i_test_neighs = Tensor(i_test_neighs, mstype.int32)
        i_test_gnew_neighs = Tensor(i_test_gnew_neighs, mstype.int32)

        user_rep, item_rep = forward_net(users, items, neg_items,
                                         u_test_neighs, u_test_gnew_neighs,
                                         i_test_neighs, i_test_gnew_neighs)

        user_reps += user_rep.asnumpy()
        item_reps += item_rep.asnumpy()

    user_reps /= num_samples
    item_reps /= num_samples
    return user_reps, item_reps
示例#2
0
    def __init__(self, parser, train_graph, test_graph, Ks):
        """Precompute the neighbour sets used later for evaluation.

        Args:
            parser: Run configuration; not read by this constructor.
            train_graph: Graph of training interactions.
            test_graph: Graph of test interactions.
            Ks: Cut-off values, stored unchanged as ``self.Ks``.

        Attributes set:
            num_user, num_item: Taken from ``train_graph.graph_info()["node_num"]``.
            train_set, test_set: One list per user. Each holds that user's
                ``neighbor_type=1`` neighbours without the first returned
                entry, after ``convert_item_id(..., num_user)``.
            item_deg_dict: Maps item index (node id minus ``num_user``) to its
                number of training neighbours.
            item_full_set: One list per item with its train neighbours
                followed by its test neighbours (``neighbor_type=0``), each
                without the first returned entry.
        """
        node_counts = train_graph.graph_info()["node_num"]
        self.num_user = node_counts[0]
        self.num_item = node_counts[1]
        self.Ks = Ks

        user_ids = range(self.num_user)
        # Query every user in the training graph first, then in the test
        # graph, dropping the first entry of each neighbour result.
        raw_train = [
            train_graph.get_all_neighbors(node_list=[uid], neighbor_type=1)[1:]
            for uid in user_ids
        ]
        raw_test = [
            test_graph.get_all_neighbors(node_list=[uid], neighbor_type=1)[1:]
            for uid in user_ids
        ]
        self.train_set = convert_item_id(raw_train, self.num_user).tolist()
        self.test_set = convert_item_id(raw_test, self.num_user).tolist()

        def _strip_leading(neighbours):
            # If tolist() yields a bare int (presumably a 0-d result), treat
            # it as "no neighbours"; otherwise drop the first entry.
            as_list = neighbours.tolist()
            if isinstance(as_list, int):
                return []
            return as_list[1:]

        self.item_deg_dict = {}
        self.item_full_set = []
        # Item nodes occupy ids num_user .. num_user + num_item - 1.
        for node_id in range(self.num_user, self.num_user + self.num_item):
            seen_in_train = _strip_leading(
                train_graph.get_all_neighbors(node_list=[node_id],
                                              neighbor_type=0))
            self.item_deg_dict[node_id - self.num_user] = len(seen_in_train)
            seen_in_test = _strip_leading(
                test_graph.get_all_neighbors(node_list=[node_id],
                                             neighbor_type=0))
            self.item_full_set.append(seen_in_train + seen_in_test)
示例#3
0
def train():
    """Train BGCF and write periodic checkpoints.

    Reads the module-level ``train_graph`` (node/edge counts), ``train_ds``
    (batches) and ``parser`` (hyper-parameters). Once per epoch the loss is
    printed on the batch whose 1-based index equals
    ``int(num_pairs / parser.batch_pairs)``. Every ``parser.eval_interval``
    epochs the network is saved as ``<parser.ckptpath>/bgcf_epoch<N>.ckpt``.
    """
    graph_meta = train_graph.graph_info()
    num_user = graph_meta["node_num"][0]
    num_item = graph_meta["node_num"][1]
    num_pairs = graph_meta['edge_num'][0]

    bgcfnet = BGCF([parser.input_dim, num_user, num_item],
                   parser.embedded_dimension, parser.activation,
                   parser.neighbor_dropout, num_user, num_item,
                   parser.input_dim)

    train_net = TrainBGCF(bgcfnet, parser.num_neg, parser.l2,
                          parser.learning_rate, parser.epsilon,
                          parser.dist_reg)
    train_net.set_train(True)

    # Batch keys in the order TrainBGCF takes them. The flag marks entries
    # that must pass through convert_item_id(..., num_user) first.
    batch_layout = (
        ("users", False),
        ("items", True),
        ("neg_item_id", True),
        ("pos_users", False),
        ("pos_items", True),
        ("u_group_nodes", False),
        ("u_neighs", True),
        ("u_gnew_neighs", True),
        ("i_group_nodes", True),
        ("i_neighs", False),
        ("i_gnew_neighs", False),
        ("neg_group_nodes", True),
        ("neg_neighs", False),
        ("neg_gnew_neighs", False),
    )

    itr = train_ds.create_dict_iterator(parser.num_epoch, output_numpy=True)
    num_iter = int(num_pairs / parser.batch_pairs)

    for epoch in range(1, parser.num_epoch + 1):
        epoch_start = time.time()

        # Presumably each pass over ``itr`` yields one epoch of batches.
        for step, data in enumerate(itr, start=1):
            net_inputs = [
                Tensor(convert_item_id(data[key], num_user) if convert else data[key],
                       mstype.int32)
                for key, convert in batch_layout
            ]
            train_loss = train_net(*net_inputs)

            if step == num_iter:
                elapsed = time.time() - epoch_start
                print('Epoch', '%03d' % epoch, 'iter', '%02d' % step,
                      'loss', '{}, cost:{:.4f}'.format(train_loss, elapsed))

        if epoch % parser.eval_interval == 0:
            ckpt_name = parser.ckptpath + "/bgcf_epoch{}.ckpt".format(epoch)
            save_checkpoint(bgcfnet, ckpt_name)
示例#4
0
def train_and_eval():
    """Train BGCF and evaluate it every ``parser.eval_interval`` epochs.

    Reads the module-level ``train_graph``, ``test_graph``, ``train_ds``,
    ``test_graph_dataset``, ``parser`` and, when ``parser.log_name`` is set,
    ``log``.

    At each evaluation epoch:
      1. The training network is saved to ``ckpts/bgcf.ckpt``, replacing any
         previous file.
      2. The checkpoint is reloaded into a BGCF built with zero neighbour
         dropout.
      3. Representations are computed with ``TestBGCF``.
      4. The metrics are written to ``log`` or printed.
    """
    graph_meta = train_graph.graph_info()
    num_user = graph_meta["node_num"][0]
    num_item = graph_meta["node_num"][1]
    num_pairs = graph_meta['edge_num'][0]

    bgcfnet = BGCF([parser.input_dim, num_user, num_item],
                   parser.embedded_dimension, parser.activation,
                   parser.neighbor_dropout, num_user, num_item,
                   parser.input_dim)

    train_net = TrainBGCF(bgcfnet, parser.num_neg, parser.l2,
                          parser.learning_rate, parser.epsilon,
                          parser.dist_reg)
    train_net.set_train(True)

    eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks)

    # Batch keys in the order TrainBGCF takes them. The flag marks entries
    # that must pass through convert_item_id(..., num_user) first.
    batch_layout = (
        ("users", False),
        ("items", True),
        ("neg_item_id", True),
        ("pos_users", False),
        ("pos_items", True),
        ("u_group_nodes", False),
        ("u_neighs", True),
        ("u_gnew_neighs", True),
        ("i_group_nodes", True),
        ("i_neighs", False),
        ("i_gnew_neighs", False),
        ("neg_group_nodes", True),
        ("neg_neighs", False),
        ("neg_gnew_neighs", False),
    )
    report_fmt = (
        'epoch:%03d,      recall_@10:%.5f,     recall_@20:%.5f,     ndcg_@10:%.5f,    ndcg_@20:%.5f,   '
        'sedp_@10:%.5f,     sedp_@20:%.5f,    nov_@10:%.5f,    nov_@20:%.5f\n')
    ckpt_file = "ckpts/bgcf.ckpt"

    itr = train_ds.create_dict_iterator(parser.num_epoch, output_numpy=True)
    num_iter = int(num_pairs / parser.batch_pairs)

    for epoch in range(1, parser.num_epoch + 1):
        epoch_start = time.time()

        # Presumably each pass over ``itr`` yields one epoch of batches.
        for step, data in enumerate(itr, start=1):
            net_inputs = [
                Tensor(convert_item_id(data[key], num_user) if convert else data[key],
                       mstype.int32)
                for key, convert in batch_layout
            ]
            train_loss = train_net(*net_inputs)

            if step == num_iter:
                elapsed = time.time() - epoch_start
                print('Epoch', '%03d' % epoch, 'iter', '%02d' % step,
                      'loss', '{}, cost:{:.4f}'.format(train_loss, elapsed))

        if epoch % parser.eval_interval != 0:
            continue

        # Overwrite the single rolling checkpoint, then reload it into a
        # fresh network built with neighbour dropout disabled.
        if os.path.exists(ckpt_file):
            os.remove(ckpt_file)
        save_checkpoint(bgcfnet, ckpt_file)

        eval_net = BGCF([parser.input_dim, num_user, num_item],
                        parser.embedded_dimension, parser.activation,
                        [0.0, 0.0, 0.0], num_user, num_item,
                        parser.input_dim)
        load_checkpoint(ckpt_file, net=eval_net)

        forward_net = ForwardBGCF(eval_net)
        user_reps, item_reps = TestBGCF(forward_net, num_user, num_item,
                                        parser.input_dim,
                                        test_graph_dataset)

        recall, ndcg, sedp, nov = eval_class.eval_with_rep(user_reps, item_reps,
                                                            parser)

        # NOTE(review): sedp is read at [0]/[1], while recall/ndcg/nov use
        # [1]/[2] for the same @10/@20 labels. Presumably eval_with_rep
        # returns sedp with one fewer entry — confirm before relying on it.
        report = report_fmt % (epoch, recall[1], recall[2],
                               ndcg[1], ndcg[2], sedp[0],
                               sedp[1], nov[1], nov[2])
        if parser.log_name:
            log.write(report)
        else:
            print(report)