def test_DIN_sum():
    """Train DIN with the attention unit off and round-trip it through disk.

    Builds a ``DIN`` model with ``use_din=False`` and a sigmoid activation,
    fits it on the ``get_xy_fd()`` fixture, then checks that:

    * the weights survive ``save_weights`` / ``load_weights``, and
    * the whole model survives ``save_model`` / ``load_model``.

    The ``.h5`` files are written to the current working directory and are
    not cleaned up afterwards.
    """
    model_name = "DIN_sum"
    weights_path = model_name + '_weights.h5'
    model_path = model_name + '.h5'

    x, y, feature_dim_dict, behavior_feature_list = get_xy_fd()
    model = DIN(
        feature_dim_dict,
        behavior_feature_list,
        hist_len_max=4,
        embedding_size=8,
        use_din=False,  # attention unit switched off for this variant
        hidden_size=[4, 4, 4],
        keep_prob=0.6,
        activation="sigmoid",
    )
    model.compile('adam', 'binary_crossentropy',
                  metrics=['binary_crossentropy'])
    model.fit(x, y, verbose=1, validation_split=0.5)
    print(model_name + " test train valid pass!")

    # Weights-only round trip.
    model.save_weights(weights_path)
    model.load_weights(weights_path)
    print(model_name + " test save load weight pass!")

    # Full-model round trip; custom_objects (defined elsewhere in this file)
    # is passed positionally as load_model's custom_objects argument,
    # presumably so the project's custom layers can be deserialized.
    save_model(model, model_path)
    model = load_model(model_path, custom_objects)
    print(model_name + " test save load model pass!")

    print(model_name + " test pass!")
# test_DIN_att is defined exactly once, further down in this file; a second
# identical definition here was always shadowed by it (pyflakes/ruff F811).
def test_DIN_att():
    """Train DIN with the attention unit enabled and round-trip its weights.

    Builds a ``DIN`` model with ``use_din=True`` (activation left at the
    ``DIN`` default), fits it on the ``get_xy_fd()`` fixture, then saves and
    reloads the weights with ``save_weights`` / ``load_weights``. The weights
    file is written to the current working directory and not cleaned up.

    Unlike ``test_DIN_sum``, the full-model ``save_model`` / ``load_model``
    round trip is deliberately not exercised: the original author recorded a
    bug when saving a model that uses Dice (presumably the default ``DIN``
    activation - confirm against ``DIN``'s signature).
    """
    model_name = "DIN_att"
    weights_path = model_name + '_weights.h5'

    x, y, feature_dim_dict, behavior_feature_list = get_xy_fd()
    model = DIN(
        feature_dim_dict,
        behavior_feature_list,
        hist_len_max=4,
        embedding_size=8,
        use_din=True,
        hidden_size=[4, 4, 4],
        keep_prob=0.6,
    )
    model.compile('adam', 'binary_crossentropy',
                  metrics=['binary_crossentropy'])
    model.fit(x, y, verbose=1, validation_split=0.5)
    print(model_name + " test train valid pass!")

    # Weights-only round trip.
    model.save_weights(weights_path)
    model.load_weights(weights_path)
    print(model_name + " test save load weight pass!")

    # TODO: add the save_model/load_model round trip back (same pattern as
    # test_DIN_sum) once saving a Dice-based model works. The earlier
    # commented-out attempt referenced an undefined `name` variable and hid
    # every failure behind a bare `except:`, so it was not usable as-is.
    print(model_name + " test pass!")