Example No. 1
def train(args, data, show_loss, show_topk):

    var_to_restore = [
        "user_emb_matrix", "item_emb_matrix", "relation_emb_matrix",
        "entity_emb_matrix"
    ]
    n_user, n_item, n_entity, n_relation = data[0], data[1], data[2], data[3]
    train_data, eval_data, test_data = data[4], data[5], data[6]
    kg = data[7]

    # top-K evaluation settings
    user_num = 100
    k_list = [1, 2, 5, 10, 20, 50, 100]
    train_record = get_user_record(train_data, True)
    test_record = get_user_record(test_data, False)
    user_list = list(set(train_record.keys()) & set(test_record.keys()))
    if len(user_list) > user_num:
        user_list = np.random.choice(user_list, size=user_num, replace=False)
    item_set = set(list(range(n_item)))

    export_version = int(time.time())

    try:
        # Pick the most recent checkpoint directory (names are timestamps)
        restore_path = max(
            os.listdir('./model/' + args.dataset + '/' + args.restore))
    except (FileNotFoundError, ValueError):
        # No model directory or no checkpoints yet: train from scratch
        restore_path = None

    model = MKR(args, n_user, n_item, n_entity, n_relation, restore_path)

    with tf.Session() as sess:
        if restore_path is None:
            sess.run(tf.global_variables_initializer())
        else:
            # Grow the saved embedding matrices if new users or items joined
            sess.run(tf.global_variables_initializer())
            user_emb = np.loadtxt('./model/' + args.dataset +
                                  '/vocab/user_emb_matrix.txt',
                                  dtype=np.float32)
            item_emb = np.loadtxt('./model/' + args.dataset +
                                  '/vocab/item_emb_matrix.txt',
                                  dtype=np.float32)
            entity_emb = np.loadtxt('./model/' + args.dataset +
                                    '/vocab/entity_emb_matrix.txt',
                                    dtype=np.float32)
            relation_emb = np.loadtxt('./model/' + args.dataset +
                                      '/vocab/relation_emb_matrix.txt',
                                      dtype=np.float32)
            print('users: %d, new since last save: %d' %
                  (n_user, n_user - len(user_emb)))
            user_emb = np.vstack([
                user_emb,
                np.random.normal(size=[n_user - len(user_emb), args.dim])
            ])

            item_emb = np.vstack([
                item_emb,
                np.random.normal(size=[n_item - len(item_emb), args.dim])
            ])
            entity_emb = np.vstack([
                entity_emb,
                np.random.normal(size=[n_entity - len(entity_emb), args.dim])
            ])
            relation_emb = np.vstack([
                relation_emb,
                np.random.normal(
                    size=[n_relation - len(relation_emb), args.dim])
            ])

            # Restore everything except the embedding matrices, which were
            # resized above and are fed in via init_embeding below
            var_to_restore = slim.get_variables_to_restore(
                exclude=var_to_restore)
            saver = tf.train.Saver(var_to_restore)
            saver.restore(
                sess,
                tf.train.latest_checkpoint('./model/' + args.dataset + '/' +
                                           args.restore + '/' + restore_path))
            model.init_embeding(
                sess, {
                    model.user_emb: user_emb,
                    model.item_emb: item_emb,
                    model.entity_emb: entity_emb,
                    model.relation_emb: relation_emb
                })

        for step in range(args.n_epochs):
            # RS training
            np.random.shuffle(train_data)
            start = 0
            while start < train_data.shape[0]:
                _, loss = model.train_rs(
                    sess,
                    get_feed_dict_for_rs(model, train_data, start,
                                         start + args.batch_size))
                start += args.batch_size
                if show_loss:
                    print(loss)

            # KGE training
            if step % args.kge_interval == 0:
                np.random.shuffle(kg)
                start = 0
                while start < kg.shape[0]:
                    _, rmse = model.train_kge(
                        sess,
                        get_feed_dict_for_kge(model, kg, start,
                                              start + args.batch_size))
                    start += args.batch_size
                    if show_loss:
                        print(rmse)

            # CTR evaluation
            train_auc, train_acc = model.eval(
                sess,
                get_feed_dict_for_rs(model, train_data, 0,
                                     train_data.shape[0]))
            eval_auc, eval_acc = model.eval(
                sess,
                get_feed_dict_for_rs(model, eval_data, 0, eval_data.shape[0]))
            test_auc, test_acc = model.eval(
                sess,
                get_feed_dict_for_rs(model, test_data, 0, test_data.shape[0]))

            print(
                'epoch %d    train auc: %.4f  acc: %.4f    eval auc: %.4f  acc: %.4f     test auc: %.4f  acc: %.4f'
                % (step, train_auc, train_acc, eval_auc, eval_acc, test_auc,
                   test_acc))

            # top-K evaluation
            if show_topk:
                precision, recall, f1 = topk_eval(sess, model, user_list,
                                                  train_record, test_record,
                                                  item_set, k_list)
                print('precision: ', end='')
                for i in precision:
                    print('%.4f\t' % i, end='')
                print()
                print('recall: ', end='')
                for i in recall:
                    print('%.4f\t' % i, end='')
                print()
                print('f1: ', end='')
                for i in f1:
                    print('%.4f\t' % i, end='')
                print('\n')

        # save embedding
        np.savetxt('./model/' + args.dataset + '/vocab/user_emb_matrix.txt',
                   model.user_emb_matrix.eval())
        np.savetxt('./model/' + args.dataset + '/vocab/item_emb_matrix.txt',
                   model.item_emb_matrix.eval())
        np.savetxt('./model/' + args.dataset + '/vocab/entity_emb_matrix.txt',
                   model.entity_emb_matrix.eval())
        np.savetxt(
            './model/' + args.dataset + '/vocab/relation_emb_matrix.txt',
            model.relation_emb_matrix.eval())

        # Save a checkpoint so training can be restored later
        saver = tf.train.Saver()
        wts_name = './model/' + args.dataset + '/restore' + "/{}/mkr.ckpt".format(
            export_version)
        saver.save(sess, wts_name)

        # Export a SavedModel (.pb) for deployment with TensorFlow Serving
        inputs = {
            "user_id": model.user_indices,
            "item_id": model.item_indices,
            "head_id": model.head_indices,
            "is_dropout": model.dropout_param
        }

        outputs = {"ctr_predict": model.scores_normalized}

        export_path = './model/' + args.dataset + '/result'
        signature = tf.saved_model.signature_def_utils.predict_signature_def(
            inputs=inputs, outputs=outputs)

        export_path = os.path.join(tf.compat.as_bytes(export_path),
                                   tf.compat.as_bytes(str(export_version)))
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)
        legacy_init_op = tf.group(tf.tables_initializer(),
                                  name='legacy_init_op')
        builder.add_meta_graph_and_variables(
            sess=sess,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'crt_scores': signature,
            },
            legacy_init_op=legacy_init_op)
        builder.save()
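All of these examples call get_user_record without showing it. A minimal sketch, assuming each data row is [user, item, label] and that the test record should keep only positive interactions:

def get_user_record(data, is_train):
    # Map each user to the set of items they interacted with.
    # Training records keep every interaction; test records keep
    # only positive ones (label == 1).
    user_history_dict = dict()
    for interaction in data:
        user, item, label = int(interaction[0]), int(interaction[1]), int(interaction[2])
        if is_train or label == 1:
            user_history_dict.setdefault(user, set()).add(item)
    return user_history_dict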
Example No. 2
def train(args, data, show_loss, show_topk):
    n_user, n_item, n_entity, n_relation = data[0], data[1], data[2], data[3]
    train_data, eval_data, test_data = data[4], data[5], data[6]
    kg = data[7]
    adj_entity, adj_relation = data[8], data[9]

    model = MKR(args, n_user, n_item, n_entity, n_relation, adj_entity,
                adj_relation)
    # model.load_weights('model_weights')
    # top-K evaluation settings
    user_num = 100
    k_list = [1, 2, 5, 10, 20, 50, 100]
    train_record = get_user_record(train_data, True)
    test_record = get_user_record(test_data, False)
    user_list = list(set(train_record.keys()) & set(test_record.keys()))
    if len(user_list) > user_num:
        user_list = np.random.choice(user_list, size=user_num, replace=False)
    item_set = set(list(range(n_item)))

    # Create the optimizers once, outside the epoch loop, so Adam's moment
    # estimates persist across epochs instead of being reset each time
    rs_optimizer = tf.keras.optimizers.Adam(learning_rate=model.args.lr_rs)
    kge_optimizer = tf.keras.optimizers.Adam(learning_rate=model.args.lr_kge)

    for step in range(args.n_epochs):
        # RS training
        np.random.shuffle(train_data)
        start = 0
        while start < train_data.shape[0]:
            with tf.GradientTape() as tape:
                _, loss = model.train_rs(
                    get_feed_dict_for_rs(train_data, start,
                                         start + args.batch_size))
            # Compute and apply gradients outside the tape context
            g = tape.gradient(loss, model.trainable_variables)
            rs_optimizer.apply_gradients(
                grads_and_vars=zip(g, model.trainable_variables))
            start += args.batch_size
            if show_loss:
                print(loss)

        # KGE training
        if step % args.kge_interval == 0:
            np.random.shuffle(kg)
            start = 0
            while start < kg.shape[0]:
                with tf.GradientTape() as tape:
                    loss, rmse = model.train_kge(
                        get_feed_dict_for_kge(kg, start,
                                              start + args.batch_size))
                g = tape.gradient(loss, model.trainable_variables)
                kge_optimizer.apply_gradients(
                    zip(g, model.trainable_variables))
                start += args.batch_size
                if show_loss:
                    print(rmse)

        # CTR evaluation
        train_auc, train_acc = batch_eval(model, train_data, args.batch_size)
        eval_auc, eval_acc = batch_eval(model, eval_data, args.batch_size)
        test_auc, test_acc = batch_eval(model, test_data, args.batch_size)

        print(
            'epoch %d    train auc: %.4f  acc: %.4f    eval auc: %.4f  acc: %.4f    test auc: %.4f  acc: %.4f'
            % (step, train_auc, train_acc, eval_auc, eval_acc, test_auc,
               test_acc))

    model.save_weights('model_weights')
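batch_eval is used above but not shown. A minimal sketch, assuming model.eval accepts the same feed dict as training and returns (auc, acc) per batch; note that a size-weighted average of per-batch AUC only approximates the AUC over the full split:

def batch_eval(model, data, batch_size):
    # Evaluate in batches, averaging AUC/accuracy weighted by batch size
    auc_sum, acc_sum, total = 0.0, 0.0, 0
    start = 0
    while start < data.shape[0]:
        end = min(start + batch_size, data.shape[0])
        auc, acc = model.eval(get_feed_dict_for_rs(data, start, end))
        auc_sum += auc * (end - start)
        acc_sum += acc * (end - start)
        total += end - start
        start = end
    return auc_sum / total, acc_sum / total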
Example No. 3
def train(args, data, show_loss, show_topk):
    n_user, n_item, n_entity, n_relation = data[0], data[1], data[2], data[3]
    train_data, eval_data, test_data = data[4], data[5], data[6]
    kg = data[7]

    model = MKR(args, n_user, n_item, n_entity, n_relation)

    # top-K evaluation settings
    user_num = 100
    k_list = [1, 2, 5, 10, 20, 50, 100]
    train_record = get_user_record(train_data, True)
    test_record = get_user_record(test_data, False)
    user_list = list(set(train_record.keys()) & set(test_record.keys()))
    if len(user_list) > user_num:
        user_list = np.random.choice(user_list, size=user_num, replace=False)
    item_set = set(list(range(n_item)))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(args.n_epochs):
            # RS training
            np.random.shuffle(train_data)
            start = 0
            while start < train_data.shape[0]:
                _, loss = model.train_rs(sess, get_feed_dict_for_rs(model, train_data, start, start + args.batch_size))
                start += args.batch_size
                if show_loss:
                    print(loss)

            # KGE training
            if step % args.kge_interval == 0:
                np.random.shuffle(kg)
                start = 0
                while start < kg.shape[0]:
                    _, rmse = model.train_kge(sess, get_feed_dict_for_kge(model, kg, start, start + args.batch_size))
                    start += args.batch_size
                    if show_loss:
                        print(rmse)

            # CTR evaluation
            train_auc, train_acc = model.eval(sess, get_feed_dict_for_rs(model, train_data, 0, train_data.shape[0]))
            eval_auc, eval_acc = model.eval(sess, get_feed_dict_for_rs(model, eval_data, 0, eval_data.shape[0]))
            test_auc, test_acc = model.eval(sess, get_feed_dict_for_rs(model, test_data, 0, test_data.shape[0]))

            print('epoch %d    train auc: %.4f  acc: %.4f    eval auc: %.4f  acc: %.4f    test auc: %.4f  acc: %.4f'
                  % (step, train_auc, train_acc, eval_auc, eval_acc, test_auc, test_acc))

            # top-K evaluation
            if show_topk:
                precision, recall, f1 = topk_eval(
                    sess, model, user_list, train_record, test_record, item_set, k_list)
                print('precision: ', end='')
                for i in precision:
                    print('%.4f\t' % i, end='')
                print()
                print('recall: ', end='')
                for i in recall:
                    print('%.4f\t' % i, end='')
                print()
                print('f1: ', end='')
                for i in f1:
                    print('%.4f\t' % i, end='')
                print('\n')
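The TF1 examples (1, 3, 5, 6) all build their feed dicts through get_feed_dict_for_rs, which is never shown. A minimal sketch following the placeholder names that Example 1's serving signature exposes (user_indices, item_indices, head_indices); the labels placeholder name is an assumption:

def get_feed_dict_for_rs(model, data, start, end):
    # Slice a batch of [user, item, label] rows into the RS placeholders.
    # MKR feeds the item index as the head entity index too, since items
    # are aligned with knowledge-graph entities.
    return {model.user_indices: data[start:end, 0],
            model.item_indices: data[start:end, 1],
            model.labels: data[start:end, 2],
            model.head_indices: data[start:end, 1]}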
Example No. 4
import argparse

import torch
import torch.nn as nn
import torch.optim as optim

# MKR, load_data, and get_data_for_kge are assumed to be importable from
# the surrounding project; this fragment does not show those definitions.

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
parser.add_argument('--l2_weight', type=float, default=1e-6, help='weight of l2 regularization')
parser.add_argument('--lr_rs', type=float, default=1e-3, help='learning rate of RS task')
parser.add_argument('--lr_kge', type=float, default=2e-4, help='learning rate of KGE task')
parser.add_argument('--kge_interval', type=int, default=2, help='training interval of KGE task')
parser.add_argument('--cuda', action="store_true", default=False, help='set this to use cuda')
args = parser.parse_args()
data = load_data(args)

n_user, n_item, n_entity, n_relation = data[0], data[1], data[2], data[3]
# train_data: [user item score]
train_data, test_data = data[4], data[5]
# kg: [head relation tail]
kg = data[6]

mkr = MKR(args, n_user, n_item, n_entity, n_relation)
loss_func = nn.BCELoss()
optimizer_kge = optim.Adam(mkr.parameters(), lr=args.lr_kge)

feed, tail_indices = get_data_for_kge(kg, 0, args.batch_size)
feed, tail_indices = torch.Tensor(feed).long(), torch.Tensor(tail_indices).long()

print("feed", feed)
print("labels", tail_indices)

for _ in range(10):
    tail_pred = mkr("kge", feed)

    tail_embeddings = mkr.entity_emb_matrix(tail_indices)
    for i in range(mkr.L):
        tail_embeddings = mkr.tail_mlps[i](tail_embeddings)
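    # --- Sketch of the missing training step (not in the original) ---
    # The fragment stops before applying a loss; Example 8 below completes
    # the same computation with an RMSE objective, assuming args.dim holds
    # the embedding size:
    rmse = (((tail_embeddings - tail_pred) ** 2).sum(1) / args.dim).sqrt().sum()
    optimizer_kge.zero_grad()
    rmse.backward()
    optimizer_kge.step()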
Example No. 5
def train(args, data, show_loss, show_topk):
    n_user, n_item, n_entity, n_relation = data[0], data[1], data[2], data[3]
    train_data, eval_data, test_data = data[4], data[5], data[6]
    kg = data[7]

    model = MKR(args, n_user, n_item, n_entity, n_relation)

    # top-K evaluation settings
    user_num = 100
    k_list = [1, 2, 5, 10, 20, 50, 100]
    train_record = get_user_record(train_data, True)
    test_record = get_user_record(test_data, False)
    user_list = list(set(train_record.keys()) & set(test_record.keys()))
    if len(user_list) > user_num:
        user_list = np.random.choice(user_list, size=user_num, replace=False)
    item_set = set(list(range(n_item)))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(args.n_epochs):
            # RS training
            np.random.shuffle(train_data)
            start = 0
            while start < train_data.shape[0]:
                # Build the feed dict once and reuse it for the training step
                feed_dict = get_feed_dict_for_rs(model, train_data, start, start + args.batch_size)
                _, loss = model.train_rs(sess, feed_dict)
                start += args.batch_size
                if show_loss:
                    print(loss)

            # KGE training
            if step % args.kge_interval == 0:
                np.random.shuffle(kg)
                start = 0
                while start < kg.shape[0]:
                    _, rmse = model.train_kge(sess, get_feed_dict_for_kge(model, kg, start, start + args.batch_size))
                    start += args.batch_size
                    if show_loss:
                        print(rmse)

            # CTR evaluation
            train_auc, train_acc = model.eval(sess, get_feed_dict_for_rs(model, train_data, 0, train_data.shape[0]))
            eval_auc, eval_acc = model.eval(sess, get_feed_dict_for_rs(model, eval_data, 0, eval_data.shape[0]))
            test_auc, test_acc = model.eval(sess, get_feed_dict_for_rs(model, test_data, 0, test_data.shape[0]))

            print('epoch %d    train auc: %.4f  acc: %.4f    eval auc: %.4f  acc: %.4f    test auc: %.4f  acc: %.4f'
                  % (step, train_auc, train_acc, eval_auc, eval_acc, test_auc, test_acc))

            # top-K evaluation
            if show_topk:
                precision, recall, f1 = topk_eval(
                    sess, model, user_list, train_record, test_record, item_set, k_list)
                print('precision: ', end='')
                for i in precision:
                    print('%.4f\t' % i, end='')
                print()
                print('recall: ', end='')
                for i in recall:
                    print('%.4f\t' % i, end='')
                print()
                print('f1: ', end='')
                for i in f1:
                    print('%.4f\t' % i, end='')
                print('\n')
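The companion helper for the KGE batches is also not shown. A minimal sketch, assuming kg rows are [head, relation, tail] and that the relation_indices and tail_indices placeholder names mirror those used for the RS feed dict:

def get_feed_dict_for_kge(model, kg, start, end):
    # Slice a batch of [head, relation, tail] triples into the KGE
    # placeholders. Heads are fed to item_indices as well, so the shared
    # cross&compress units receive consistent item/entity inputs.
    return {model.item_indices: kg[start:end, 0],
            model.head_indices: kg[start:end, 0],
            model.relation_indices: kg[start:end, 1],
            model.tail_indices: kg[start:end, 2]}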
Example No. 6
def train(args, data, show_loss, show_topk):
    n_user, n_item, n_entity, n_relation = data[0], data[1], data[2], data[3]
    train_data, eval_data, test_data = data[4], data[5], data[6]
    kg = data[7]
    user_set = data[8]
    user_item_dict = data[9]

    BASELINE_OUTPUT_FILE = 'baseline_output.txt'
    OUTPUT_PATH = os.path.join('..', 'data', args.dataset,
                               BASELINE_OUTPUT_FILE)
    model = MKR(args, n_user, n_item, n_entity, n_relation)

    # top-K evaluation settings
    user_num = 100
    k_list = [1, 2, 5, 10, 20, 50, 100]
    train_record = get_user_record(train_data, True)
    test_record = get_user_record(test_data, False)
    user_list = list(set(train_record.keys()) & set(test_record.keys()))
    if len(user_list) > user_num:
        user_list = np.random.choice(user_list, size=user_num, replace=False)
    item_set = set(list(range(n_item)))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(args.n_epochs):
            # RS training
            np.random.shuffle(train_data)
            start = 0
            while start < train_data.shape[0]:
                _, loss = model.train_rs(
                    sess,
                    get_feed_dict_for_rs(model, train_data, start,
                                         start + args.batch_size))
                start += args.batch_size
                if show_loss:
                    print(loss)

            # KGE training
            if step % args.kge_interval == 0:
                np.random.shuffle(kg)
                start = 0
                while start < kg.shape[0]:
                    _, rmse = model.train_kge(
                        sess,
                        get_feed_dict_for_kge(model, kg, start,
                                              start + args.batch_size))
                    start += args.batch_size
                    if show_loss:
                        print(rmse)

            # CTR evaluation
            train_auc, train_acc = model.eval(
                sess,
                get_feed_dict_for_rs(model, train_data, 0,
                                     train_data.shape[0]))
            eval_auc, eval_acc = model.eval(
                sess,
                get_feed_dict_for_rs(model, eval_data, 0, eval_data.shape[0]))
            test_auc, test_acc = model.eval(
                sess,
                get_feed_dict_for_rs(model, test_data, 0, test_data.shape[0]))

            print(
                'epoch %d    train auc: %.4f  acc: %.4f    eval auc: %.4f  acc: %.4f    test auc: %.4f  acc: %.4f'
                % (step, train_auc, train_acc, eval_auc, eval_acc, test_auc,
                   test_acc))

            # top-K evaluation
            if show_topk:
                precision, recall, f1 = topk_eval(sess, model, user_list,
                                                  train_record, test_record,
                                                  item_set, k_list)
                print('precision: ', end='')
                for i in precision:
                    print('%.4f\t' % i, end='')
                print()
                print('recall: ', end='')
                for i in recall:
                    print('%.4f\t' % i, end='')
                print()
                print('f1: ', end='')
                for i in f1:
                    print('%.4f\t' % i, end='')
                print('\n')

        user_list = list(user_set)

        with open(OUTPUT_PATH, 'w') as writer:
            for user in user_list:

                test_item_list = list(user_item_dict[user])
                items, scores = model.get_scores(
                    sess, {
                        model.user_indices: [user] * len(test_item_list),
                        model.item_indices: test_item_list,
                        model.head_indices: test_item_list
                    })
                for item, score in zip(items, scores):
                    writer.write('%d\t%d\t%f\n' % (user, item, score))

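topk_eval is another undefined helper. A minimal sketch built on model.get_scores, whose feed-dict keys this example demonstrates; per-user metrics are averaged over user_list:

import numpy as np

def topk_eval(sess, model, user_list, train_record, test_record, item_set, k_list):
    # For each user, score every item unseen in training, rank by predicted
    # CTR, and compute precision/recall/F1 at each cutoff in k_list.
    precision_list = {k: [] for k in k_list}
    recall_list = {k: [] for k in k_list}
    for user in user_list:
        candidate_items = list(item_set - train_record[user])
        items, scores = model.get_scores(sess, {
            model.user_indices: [user] * len(candidate_items),
            model.item_indices: candidate_items,
            model.head_indices: candidate_items})
        ranked = [i for i, _ in sorted(zip(items, scores),
                                       key=lambda x: x[1], reverse=True)]
        for k in k_list:
            hit_num = len(set(ranked[:k]) & test_record[user])
            precision_list[k].append(hit_num / k)
            recall_list[k].append(hit_num / len(test_record[user]))
    precision = [np.mean(precision_list[k]) for k in k_list]
    recall = [np.mean(recall_list[k]) for k in k_list]
    f1 = [2 * p * r / (p + r) if p + r > 0 else 0.0
          for p, r in zip(precision, recall)]
    return precision, recall, f1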
Example No. 7
def train(args, rs_dataset, kg_dataset):

    show_loss = args.show_loss
    show_topk = args.show_topk

    # Get RS data
    n_user = rs_dataset.n_user
    n_item = rs_dataset.n_item
    train_data, eval_data, test_data = rs_dataset.data
    train_indices, eval_indices, test_indices = rs_dataset.indices

    # Get KG data
    n_entity = kg_dataset.n_entity
    n_relation = kg_dataset.n_relation
    kg = kg_dataset.kg

    # Init train sampler
    train_sampler = SubsetRandomSampler(train_indices)

    # Init MKR model
    model = MKR(args, n_user, n_item, n_entity, n_relation)

    # Init TensorBoard summary writer
    writer = SummaryWriter(args.summary_path)

    # Top-K evaluation settings
    user_num = 100
    k_list = [1, 2, 5, 10, 20, 50, 100]
    train_record = get_user_record(train_data, True)
    test_record = get_user_record(test_data, False)
    user_list = list(set(train_record.keys()) & set(test_record.keys()))
    if len(user_list) > user_num:
        user_list = np.random.choice(user_list, size=user_num, replace=False)
    item_set = set(list(range(n_item)))
    step = 0
    for epoch in range(args.n_epochs):
        print("Train RS")
        train_loader = DataLoader(rs_dataset, batch_size=args.batch_size,
                                  num_workers=args.workers, sampler=train_sampler)
        for i, rs_batch_data in enumerate(train_loader):
            loss, base_loss_rs, l2_loss_rs = model.train_rs(rs_batch_data)
            writer.add_scalar("rs_loss", loss.cpu().detach().numpy(), global_step=step)
            writer.add_scalar("rs_base_loss", base_loss_rs.cpu().detach().numpy(), global_step=step)
            writer.add_scalar("rs_l2_loss", l2_loss_rs.cpu().detach().numpy(), global_step=step)
            step += 1
            if show_loss:
                print(loss)

        if epoch % args.kge_interval == 0:
            print("Train KGE")
            kg_train_loader = DataLoader(kg_dataset, batch_size=args.batch_size,
                                         num_workers=args.workers, shuffle=True)
            for i, kg_batch_data in enumerate(kg_train_loader):
                rmse, loss_kge, base_loss_kge, l2_loss_kge = model.train_kge(kg_batch_data)
                writer.add_scalar("kge_rmse_loss", rmse.cpu().detach().numpy(), global_step=step)
                writer.add_scalar("kge_loss", loss_kge.cpu().detach().numpy(), global_step=step)
                writer.add_scalar("kge_base_loss", base_loss_kge.cpu().detach().numpy(), global_step=step)
                writer.add_scalar("kge_l2_loss", l2_loss_kge.cpu().detach().numpy(), global_step=step)
                step += 1
                if show_loss:
                    print(rmse)

        # CTR evaluation
        train_auc, train_acc = model.eval(train_data)
        eval_auc, eval_acc = model.eval(eval_data)
        test_auc, test_acc = model.eval(test_data)

        print('epoch %d    train auc: %.4f  acc: %.4f    eval auc: %.4f  acc: %.4f    test auc: %.4f  acc: %.4f'
              % (epoch, train_auc, train_acc, eval_auc, eval_acc, test_auc, test_acc))

        # top-K evaluation
        if show_topk:
            precision, recall, f1 = model.topk_eval(user_list, train_record, test_record, item_set, k_list)
            print('precision: ', end='')
            for i in precision:
                print('%.4f\t' % i, end='')
            print()
            print('recall: ', end='')
            for i in recall:
                print('%.4f\t' % i, end='')
            print()
            print('f1: ', end='')
            for i in f1:
                print('%.4f\t' % i, end='')
            print('\n')
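Example 7 expects rs_dataset to expose n_user, n_item, a three-way data split, matching index lists, and to serve rows to a DataLoader. A minimal sketch of such a wrapper; every name here is an assumption about the project's dataset class:

import numpy as np
from torch.utils.data import Dataset

class RSDataset(Dataset):
    def __init__(self, n_user, n_item, train_data, eval_data, test_data):
        self.n_user, self.n_item = n_user, n_item
        self.data = (train_data, eval_data, test_data)
        rows = np.concatenate([train_data, eval_data, test_data])
        n_train, n_eval = len(train_data), len(eval_data)
        # Index lists into the concatenated rows, one per split
        self.indices = (list(range(n_train)),
                        list(range(n_train, n_train + n_eval)),
                        list(range(n_train + n_eval, len(rows))))
        self._rows = rows

    def __len__(self):
        return len(self._rows)

    def __getitem__(self, idx):
        user, item, label = self._rows[idx]
        return int(user), int(item), float(label)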
Example No. 8
def train(args, data):
    n_user, n_item, n_entity, n_relation = data[0], data[1], data[2], data[3]
    # train_data: [user item score]
    train_data, test_data = data[4], data[5]
    # kg: [head relation tail]
    kg = data[6]

    mkr = MKR(args, n_user, n_item, n_entity, n_relation)
    if args.cuda:
        mkr.cuda()

    loss_func = nn.BCELoss()
    # Two optimizers over the same parameters, with task-specific learning rates
    optimizer_rs = optim.Adam(mkr.parameters(), lr=args.lr_rs)
    optimizer_kge = optim.Adam(mkr.parameters(), lr=args.lr_kge)

    # store best state
    best_test_acc = 0.0
    best_state_dict = None

    for epoch in range(args.n_epochs):
        # RS training
        np.random.shuffle(train_data)
        start = 0
        while start < train_data.shape[0]:
            feed, labels = get_data_for_rs(train_data, start,
                                           start + args.batch_size)
            feed, labels = torch.Tensor(feed).long(), torch.Tensor(
                labels).float()
            if args.cuda:
                feed, labels = feed.cuda(), labels.cuda()
            scores_normalized = mkr("rs", feed)

            # build loss for RS
            base_loss_rs = loss_func(scores_normalized, labels)
            l2_loss_rs = (mkr.user_embeddings**
                          2).sum() / 2 + (mkr.item_embeddings**2).sum() / 2
            loss_rs = base_loss_rs + l2_loss_rs * args.l2_weight

            optimizer_rs.zero_grad()
            loss_rs.backward()
            optimizer_rs.step()

            start += args.batch_size

        if epoch % args.kge_interval == 0:
            # KGE training
            np.random.shuffle(kg)
            start = 0
            while start < kg.shape[0]:
                feed, tail_indices = get_data_for_kge(kg, start,
                                                      start + args.batch_size)
                feed, tail_indices = torch.Tensor(feed).long(), torch.Tensor(
                    tail_indices).long()
                if args.cuda:
                    feed, tail_indices = feed.cuda(), tail_indices.cuda()
                tail_pred = mkr("kge", feed)

                # build loss for KGE
                tail_embeddings = mkr.entity_emb_matrix(tail_indices)
                for i in range(mkr.L):
                    tail_embeddings = mkr.tail_mlps[i](tail_embeddings)
                # Alternative sigmoid-based KGE loss, kept for reference:
                # scores_kge = mkr.sigmoid((tail_embeddings * tail_pred).sum(1))
                # l2_loss_kge = (mkr.head_embeddings ** 2).sum() / 2 + (tail_embeddings ** 2).sum() / 2
                # loss_kge = -scores_kge + l2_loss_kge * args.l2_weight
                rmse = (((tail_embeddings - tail_pred)**2).sum(1) /
                        args.dim).sqrt().sum()

                optimizer_kge.zero_grad()
                rmse.backward()
                optimizer_kge.step()

                start += args.batch_size

        # Evaluate on the training data
        inputs, y_true = get_data_for_rs(train_data, 0, train_data.shape[0])
        inputs, y_true = torch.Tensor(inputs).long(), torch.Tensor(
            y_true).byte()
        if args.cuda:
            inputs, y_true = inputs.cuda(), y_true.cuda()
        train_auc, train_acc = mkr.evaluate(inputs, y_true)

        # Evaluate on the test data
        inputs, y_true = get_data_for_rs(test_data, 0, test_data.shape[0])
        inputs, y_true = torch.Tensor(inputs).long(), torch.Tensor(
            y_true).byte()
        if args.cuda:
            inputs, y_true = inputs.cuda(), y_true.cuda()
        test_auc, test_acc = mkr.evaluate(inputs, y_true)

        if test_acc > best_test_acc:
            best_test_acc = test_acc
            best_state_dict = deepcopy(mkr.state_dict())

        print(
            "epoch {:3d} | train auc: {:.4f} acc: {:.4f} | test auc: {:.4f} acc:{:.4f}"
            .format(epoch, train_auc, train_acc, test_auc, test_acc))

    # Save model
    wts_name = "../model/MKR_{}_{:.4f}.pth".format(args.dataset, best_test_acc)
    torch.save(best_state_dict, wts_name)
    print("Saved model to {}".format(wts_name))

    return mkr
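get_data_for_rs is not shown here either. Given the comment at the top of this example (train_data rows are [user item score]), a minimal sketch:

def get_data_for_rs(data, start, end):
    # Split a batch of [user, item, score] rows into inputs and labels
    batch = data[start:end]
    feed = batch[:, 0:2]   # user and item indices
    labels = batch[:, 2]   # interaction labels
    return feed, labels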