Пример #1
0
    def test_compile_fit(self):
        """Smoke-test compiling and fitting a NeuralCF model on random data.

        Builds 50 random (user, item, rating) samples with ids/ratings drawn
        from 1..N (hence ``zero_based_label=False``). It then compiles the
        model with Adam and sparse categorical cross-entropy and trains for
        one epoch with TensorBoard logging enabled. The same RDD is used for
        validation. Finally it prints the recorded train and validation
        "Loss" summaries. No assertions are made on the printed values.
        """
        user_count, item_count, rating_classes = 200, 80, 5

        def make_feature(n_users, n_items, n_classes):
            # Draw order (user, item, rating) matters for the random stream.
            uid = random.randint(1, n_users)
            iid = random.randint(1, n_items)
            label = random.randint(1, n_classes)
            features = np.array([uid, iid])
            labels = np.array([label])
            return UserItemFeature(uid, iid,
                                   Sample.from_ndarray(features, labels))

        model = NeuralCF(user_count, item_count, rating_classes)
        model.summary()

        # NOTE(review): the RDD is not cached, so the validation pass presumably
        # re-runs the random generator and sees different samples than training.
        indices = self.sc.parallelize(range(50))
        features_rdd = indices.map(
            lambda _: make_feature(user_count, item_count, rating_classes))
        samples = features_rdd.map(lambda feat: feat.sample)

        model.compile(optimizer="adam",
                      loss=SparseCategoricalCrossEntropy(zero_based_label=False),
                      metrics=['accuracy'])

        log_dir = create_tmp_path()
        model.set_tensorboard(log_dir, "training_test")
        model.fit(samples, nb_epoch=1, batch_size=32, validation_data=samples)

        train_summary = model.get_train_summary("Loss")
        validation_summary = model.get_validation_summary("Loss")
        for summary in (train_summary, validation_summary):
            print(np.array(summary))
Пример #2
0
# Train a NeuralCF rating model on the prepared pair-feature RDDs, save it,
# print the shapes of its weights, then reload it from disk.
# NOTE(review): trainPairFeatureRdds, valPairFeatureRdds, max_user_id and
# max_movie_id are defined earlier in the script (outside this view).

# Single source of truth for the save/load location so they cannot diverge.
MODEL_PATH = "../save_model/movie_ncf.zoomodel"

train_rdd = trainPairFeatureRdds.map(lambda pair_feature: pair_feature.sample)
val_rdd = valPairFeatureRdds.map(lambda pair_feature: pair_feature.sample)
# Cache the validation samples; presumably they are re-read during fit.
val_rdd.persist()

ncf = NeuralCF(user_count=max_user_id,
               item_count=max_movie_id,
               class_num=5,
               hidden_layers=[20, 10],
               include_mf=False)

ncf.compile(optimizer="adam",
            loss="sparse_categorical_crossentropy",
            metrics=['accuracy'])

ncf.fit(train_rdd,
        nb_epoch=10,
        batch_size=8000,
        validation_data=val_rdd)
# Training is finished; release the cached validation partitions.
val_rdd.unpersist()

ncf.save_model(MODEL_PATH, over_write=True)

# Report how many weight tensors the model has and the shape of each one.
weights = ncf.get_weights()
print(len(weights))
for i, weight in enumerate(weights):
    print(i)
    print(weight.shape)

# load_model is a static factory: call it on the class, not on the instance,
# so it is not mistaken for reloading weights into `ncf` itself.
loaded = NeuralCF.load_model(MODEL_PATH)
# NOTE(review): index 0 is assumed to be the user embedding table; confirm
# against the weight listing printed above.
user_embed = loaded.get_weights()[0]
print(user_embed.shape)