Ejemplo n.º 1
0
def test_DIN_sum():
    """Smoke-test the ``use_din=False`` DIN configuration end to end.

    Builds the model from the fixture returned by ``get_xy_fd()``, fits it
    with half of the samples held out for validation, then checks that both
    the weights and the whole model can be saved and loaded again.

    All ``.h5`` artefacts are written to a private temporary directory that
    is removed when the test finishes, so the test leaves nothing behind in
    the current working directory and parallel runs cannot clobber each
    other's files.
    """
    import os
    import shutil
    import tempfile

    model_name = "DIN_sum"
    x, y, feature_dim_dict, behavior_feature_list = get_xy_fd()

    model = DIN(feature_dim_dict,
                behavior_feature_list,
                hist_len_max=4,
                embedding_size=8,
                use_din=False,
                hidden_size=[4, 4, 4],
                keep_prob=0.6,
                activation="sigmoid")

    model.compile('adam',
                  'binary_crossentropy',
                  metrics=['binary_crossentropy'])
    model.fit(x, y, verbose=1, validation_split=0.5)
    print(model_name + " test train valid pass!")

    # mkdtemp + rmtree (rather than TemporaryDirectory) keeps this usable on
    # older interpreters as well.
    tmp_dir = tempfile.mkdtemp()
    try:
        weights_path = os.path.join(tmp_dir, model_name + '_weights.h5')
        model.save_weights(weights_path)
        model.load_weights(weights_path)
        print(model_name + " test save load weight pass!")

        model_path = os.path.join(tmp_dir, model_name + '.h5')
        save_model(model, model_path)
        # custom_objects lets the loader resolve the project's own layers.
        model = load_model(model_path, custom_objects)
        print(model_name + " test save load model pass!")
    finally:
        # Best-effort cleanup: a failure to delete scratch files must not
        # mask the real outcome of the test.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    print(model_name + " test pass!")
Ejemplo n.º 2
0
def test_DIN_att():
    """Smoke-test the ``use_din=True`` DIN configuration.

    Builds the model from the fixture returned by ``get_xy_fd()``, fits it
    with half of the samples held out for validation, then checks that the
    weights can be saved and loaded again.

    The weights file is written to a private temporary directory that is
    removed when the test finishes, so nothing is left behind in the current
    working directory.
    """
    import os
    import shutil
    import tempfile

    model_name = "DIN_att"

    x, y, feature_dim_dict, behavior_feature_list = get_xy_fd()

    model = DIN(
        feature_dim_dict,
        behavior_feature_list,
        hist_len_max=4,
        embedding_size=8,
        use_din=True,
        hidden_size=[4, 4, 4],
        keep_prob=0.6,
    )

    model.compile('adam',
                  'binary_crossentropy',
                  metrics=['binary_crossentropy'])
    model.fit(x, y, verbose=1, validation_split=0.5)
    print(model_name + " test train valid pass!")

    # mkdtemp + rmtree (rather than TemporaryDirectory) keeps this usable on
    # older interpreters as well.
    tmp_dir = tempfile.mkdtemp()
    try:
        weights_path = os.path.join(tmp_dir, model_name + '_weights.h5')
        model.save_weights(weights_path)
        model.load_weights(weights_path)
        print(model_name + " test save load weight pass!")
    finally:
        # Best-effort cleanup: a failure to delete scratch files must not
        # mask the real outcome of the test.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    # NOTE: the save_model/load_model round-trip is deliberately not
    # exercised here; the original author recorded a bug when saving a model
    # that uses Dice.
    # TODO: add the round-trip (using model_name for the file name) once
    # that bug is fixed.

    print(model_name + " test pass!")
Ejemplo n.º 3
0
def test_DIN_att():
    """Smoke-test the ``use_din=True`` DIN configuration.

    Builds the model from the fixture returned by ``get_xy_fd()``, fits it
    with half of the samples held out for validation, then checks that the
    weights can be saved and loaded again.

    The weights file is written to a private temporary directory that is
    removed when the test finishes, so nothing is left behind in the current
    working directory.
    """
    import os
    import shutil
    import tempfile

    model_name = "DIN_att"

    x, y, feature_dim_dict, behavior_feature_list = get_xy_fd()

    model = DIN(feature_dim_dict, behavior_feature_list, hist_len_max=4, embedding_size=8,
                use_din=True, hidden_size=[4, 4, 4], keep_prob=0.6,)

    model.compile('adam', 'binary_crossentropy',
                  metrics=['binary_crossentropy'])
    model.fit(x, y, verbose=1, validation_split=0.5)
    print(model_name+" test train valid pass!")

    # mkdtemp + rmtree (rather than TemporaryDirectory) keeps this usable on
    # older interpreters as well.
    tmp_dir = tempfile.mkdtemp()
    try:
        weights_path = os.path.join(tmp_dir, model_name + '_weights.h5')
        model.save_weights(weights_path)
        model.load_weights(weights_path)
        print(model_name+" test save load weight pass!")
    finally:
        # Best-effort cleanup: a failure to delete scratch files must not
        # mask the real outcome of the test.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    # NOTE: the save_model/load_model round-trip is deliberately not
    # exercised here; the original author recorded a bug when saving a model
    # that uses Dice.
    # TODO: add the round-trip (using model_name for the file name) once
    # that bug is fixed.

    print(model_name + " test pass!")