예제 #1
0
파일: eval.py 프로젝트: yrpang/mindspore
def evaluation():
    """Evaluate saved BGCF checkpoints on the test graph and report metrics.

    On Ascend, every checkpoint saved at a multiple of ``parser.eval_interval``
    (up to and including ``parser.num_epoch``) is evaluated; on any other
    device target only the checkpoint of epoch ``parser.num_epoch`` is.
    Checkpoints are read from ``<parser.ckptpath>/bgcf_epoch<N>.ckpt``.

    The metric line for each epoch is written to ``log`` when
    ``parser.log_name`` is set, and printed to stdout otherwise.

    Relies on module-level globals defined elsewhere in the script:
    ``parser``, ``train_graph``, ``test_graph``, ``test_graph_dataset`` and,
    when ``parser.log_name`` is set, ``log``.
    """
    graph_info = train_graph.graph_info()
    num_user = graph_info["node_num"][0]
    num_item = graph_info["node_num"][1]

    eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks)

    if parser.device_target == "Ascend":
        epochs = range(parser.eval_interval, parser.num_epoch + 1, parser.eval_interval)
    else:
        epochs = range(parser.num_epoch, parser.num_epoch + 1)

    for _epoch in epochs:
        # Fresh inference network: the list slot that training fills with
        # parser.neighbor_dropout is all zeros here; weights come from the
        # checkpoint saved at this epoch.
        bgcfnet_test = BGCF([parser.input_dim, num_user, num_item],
                            parser.embedded_dimension, parser.activation,
                            [0.0, 0.0, 0.0], num_user, num_item,
                            parser.input_dim)

        load_checkpoint(parser.ckptpath + "/bgcf_epoch{}.ckpt".format(_epoch),
                        net=bgcfnet_test)

        forward_net = ForwardBGCF(bgcfnet_test)
        user_reps, item_reps = TestBGCF(forward_net, num_user, num_item,
                                        parser.input_dim, test_graph_dataset)

        test_recall_bgcf, test_ndcg_bgcf, test_sedp, test_nov = (
            eval_class.eval_with_rep(user_reps, item_reps, parser))

        # Format the report once so the log file and stdout can never drift apart.
        # NOTE(review): sedp is read at indices [0]/[1] while recall/ndcg/nov use
        # [1]/[2]; presumably eval_with_rep returns sedp for fewer cut-offs — confirm.
        message = (
            'epoch:%03d,      recall_@10:%.5f,     recall_@20:%.5f,     ndcg_@10:%.5f,    ndcg_@20:%.5f,   '
            'sedp_@10:%.5f,     sedp_@20:%.5f,    nov_@10:%.5f,    nov_@20:%.5f\n'
            % (_epoch, test_recall_bgcf[1], test_recall_bgcf[2],
               test_ndcg_bgcf[1], test_ndcg_bgcf[2], test_sedp[0],
               test_sedp[1], test_nov[1], test_nov[2]))

        if parser.log_name:
            log.write(message)
        else:
            print(message)
예제 #2
0
# Activation function handed to BGCF below; argparse restricts it to "relu" or "tanh".
parser.add_argument("--activation",
                    type=str,
                    default="tanh",
                    choices=["relu", "tanh"],
                    help="activation function")
args = parser.parse_args()

# Put MindSpore into graph mode on the requested device target.
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    # A specific device id is only selected when running on Ascend.
    context.set_context(device_id=args.device_id)

if __name__ == "__main__":
    # NOTE(review): user/item counts are hard-coded here, whereas training reads
    # them from train_graph.graph_info(); presumably they must match the graph
    # the checkpoint was trained on — confirm.
    num_user, num_item = 7068, 3570

    # The list argument is [0.0, 0.0, 0.0] where training passes
    # parser.neighbor_dropout, i.e. no neighbour dropout for this restored model.
    network = BGCF([args.input_dim, num_user, num_item],
                   args.embedded_dimension, args.activation, [0.0, 0.0, 0.0],
                   num_user, num_item, args.input_dim)

    # Restore trained weights from the checkpoint named on the command line.
    load_checkpoint(args.ckpt_file, net=network)

    forward_net = ForwardBGCF(network)

    # Zero-filled int32 placeholder inputs. NOTE(review): presumably only their
    # shapes/dtypes matter (e.g. for graph export later in the script) — confirm.
    users = Tensor(np.zeros([
        num_user,
    ]).astype(np.int32))
    items = Tensor(np.zeros([
        num_item,
    ]).astype(np.int32))
    neg_items = Tensor(np.zeros([num_item, 1]).astype(np.int32))
    u_test_neighs = Tensor(
        np.zeros([num_user, args.row_neighs]).astype(np.int32))
예제 #3
0
def train():
    """Train BGCF on ``train_ds`` and checkpoint it periodically.

    Runs ``parser.num_epoch`` epochs. Within an epoch, the loss and the time
    elapsed since the epoch started are printed when the batch counter reaches
    ``num_pairs // parser.batch_pairs``. Every ``parser.eval_interval`` epochs
    the weights are saved to ``<parser.ckptpath>/bgcf_epoch<N>.ckpt``.

    Relies on module-level globals defined elsewhere in the script:
    ``parser``, ``train_graph`` and ``train_ds``.
    """
    graph_info = train_graph.graph_info()
    num_user = graph_info["node_num"][0]
    num_item = graph_info["node_num"][1]
    num_pairs = graph_info['edge_num'][0]

    bgcfnet = BGCF([parser.input_dim, num_user, num_item],
                   parser.embedded_dimension, parser.activation,
                   parser.neighbor_dropout, num_user, num_item,
                   parser.input_dim)

    train_net = TrainBGCF(bgcfnet, parser.num_neg, parser.l2,
                          parser.learning_rate, parser.epsilon,
                          parser.dist_reg)
    train_net.set_train(True)

    # train_net's positional inputs, in call order. The flag marks columns that
    # hold item ids and are passed through convert_item_id(..., num_user) first
    # (presumably to offset them past the user-id range — confirm).
    input_columns = (
        ("users", False),
        ("items", True),
        ("neg_item_id", True),
        ("pos_users", False),
        ("pos_items", True),
        ("u_group_nodes", False),
        ("u_neighs", True),
        ("u_gnew_neighs", True),
        ("i_group_nodes", True),
        ("i_neighs", False),
        ("i_gnew_neighs", False),
        ("neg_group_nodes", True),
        ("neg_neighs", False),
        ("neg_gnew_neighs", False),
    )

    itr = train_ds.create_dict_iterator(parser.num_epoch, output_numpy=True)
    # Floor division is exact for any integer size, unlike int(a / b).
    num_iter = num_pairs // parser.batch_pairs

    for _epoch in range(1, parser.num_epoch + 1):
        # perf_counter is monotonic, so the reported cost is immune to clock changes.
        epoch_start = time.perf_counter()

        # The same iterator is reused every epoch; each pass over it is
        # expected to yield one epoch of batches.
        for iter_num, data in enumerate(itr, start=1):
            inputs = [
                Tensor(convert_item_id(data[key], num_user) if is_item_id
                       else data[key], mstype.int32)
                for key, is_item_id in input_columns
            ]
            train_loss = train_net(*inputs)

            if iter_num == num_iter:
                print(
                    'Epoch', '%03d' % _epoch, 'iter', '%02d' % iter_num,
                    'loss',
                    '{}, cost:{:.4f}'.format(train_loss,
                                             time.perf_counter() - epoch_start))

        if _epoch % parser.eval_interval == 0:
            save_checkpoint(
                bgcfnet, parser.ckptpath + "/bgcf_epoch{}.ckpt".format(_epoch))
예제 #4
0
def train_and_eval():
    """Train BGCF and evaluate it on the test graph every ``eval_interval`` epochs.

    Within an epoch, the loss and the time elapsed since the epoch started are
    printed when the batch counter reaches ``num_pairs // parser.batch_pairs``.
    Every ``parser.eval_interval`` epochs the current weights are written to
    ``ckpts/bgcf.ckpt`` (replacing the previous one) and loaded into a fresh
    BGCF with all-zero neighbour dropout. That network is then scored with
    ``BGCFEvaluate.eval_with_rep``. The metric line goes to ``log`` when
    ``parser.log_name`` is set, and to stdout otherwise.

    Relies on module-level globals defined elsewhere in the script:
    ``parser``, ``train_graph``, ``test_graph``, ``train_ds``,
    ``test_graph_dataset`` and, when ``parser.log_name`` is set, ``log``.
    """
    ckpt_path = "ckpts/bgcf.ckpt"

    graph_info = train_graph.graph_info()
    num_user = graph_info["node_num"][0]
    num_item = graph_info["node_num"][1]
    num_pairs = graph_info['edge_num'][0]

    bgcfnet = BGCF([parser.input_dim, num_user, num_item],
                   parser.embedded_dimension, parser.activation,
                   parser.neighbor_dropout, num_user, num_item,
                   parser.input_dim)

    train_net = TrainBGCF(bgcfnet, parser.num_neg, parser.l2,
                          parser.learning_rate, parser.epsilon,
                          parser.dist_reg)
    train_net.set_train(True)

    eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks)

    # train_net's positional inputs, in call order. The flag marks columns that
    # hold item ids and are passed through convert_item_id(..., num_user) first
    # (presumably to offset them past the user-id range — confirm).
    input_columns = (
        ("users", False),
        ("items", True),
        ("neg_item_id", True),
        ("pos_users", False),
        ("pos_items", True),
        ("u_group_nodes", False),
        ("u_neighs", True),
        ("u_gnew_neighs", True),
        ("i_group_nodes", True),
        ("i_neighs", False),
        ("i_gnew_neighs", False),
        ("neg_group_nodes", True),
        ("neg_neighs", False),
        ("neg_gnew_neighs", False),
    )

    itr = train_ds.create_dict_iterator(parser.num_epoch, output_numpy=True)
    # Floor division is exact for any integer size, unlike int(a / b).
    num_iter = num_pairs // parser.batch_pairs

    for _epoch in range(1, parser.num_epoch + 1):
        # perf_counter is monotonic, so the reported cost is immune to clock changes.
        epoch_start = time.perf_counter()

        # The same iterator is reused every epoch; each pass over it is
        # expected to yield one epoch of batches.
        for iter_num, data in enumerate(itr, start=1):
            inputs = [
                Tensor(convert_item_id(data[key], num_user) if is_item_id
                       else data[key], mstype.int32)
                for key, is_item_id in input_columns
            ]
            train_loss = train_net(*inputs)

            if iter_num == num_iter:
                print(
                    'Epoch', '%03d' % _epoch, 'iter', '%02d' % iter_num,
                    'loss',
                    '{}, cost:{:.4f}'.format(train_loss,
                                             time.perf_counter() - epoch_start))

        if _epoch % parser.eval_interval == 0:
            # Ensure the target directory exists, then drop any checkpoint left
            # from the previous evaluation (EAFP: no exists/remove race).
            os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
            try:
                os.remove(ckpt_path)
            except FileNotFoundError:
                pass
            save_checkpoint(bgcfnet, ckpt_path)

            # Separate inference network: the list slot that training fills
            # with parser.neighbor_dropout is all zeros, and the weights come
            # from the checkpoint just written.
            bgcfnet_test = BGCF([parser.input_dim, num_user, num_item],
                                parser.embedded_dimension, parser.activation,
                                [0.0, 0.0, 0.0], num_user, num_item,
                                parser.input_dim)

            load_checkpoint(ckpt_path, net=bgcfnet_test)

            forward_net = ForwardBGCF(bgcfnet_test)
            user_reps, item_reps = TestBGCF(forward_net, num_user, num_item,
                                            parser.input_dim,
                                            test_graph_dataset)

            test_recall_bgcf, test_ndcg_bgcf, test_sedp, test_nov = (
                eval_class.eval_with_rep(user_reps, item_reps, parser))

            # Format the report once so the log file and stdout can never drift apart.
            # NOTE(review): sedp is read at indices [0]/[1] while recall/ndcg/nov
            # use [1]/[2]; presumably eval_with_rep returns sedp for fewer
            # cut-offs — confirm.
            message = (
                'epoch:%03d,      recall_@10:%.5f,     recall_@20:%.5f,     ndcg_@10:%.5f,    ndcg_@20:%.5f,   '
                'sedp_@10:%.5f,     sedp_@20:%.5f,    nov_@10:%.5f,    nov_@20:%.5f\n'
                % (_epoch, test_recall_bgcf[1], test_recall_bgcf[2],
                   test_ndcg_bgcf[1], test_ndcg_bgcf[2], test_sedp[0],
                   test_sedp[1], test_nov[1], test_nov[2]))

            if parser.log_name:
                log.write(message)
            else:
                print(message)
예제 #5
0
def bgcf(*args, **kwargs):
    """Build and return a ``BGCF`` network.

    A thin factory alias: every positional and keyword argument is handed
    unchanged to the ``BGCF`` constructor.
    """
    model = BGCF(*args, **kwargs)
    return model