return kfold.split(X, Y)
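# The helper above presumably wraps scikit-learn's k-fold splitter, e.g.
#   kfold = StratifiedKFold(n_splits=10, shuffle=True)
# (split() is given Y as well, which only the stratified variant uses for
# balancing); it yields (train_idx, test_idx) index arrays for the loop below.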


data_1, data_2, labels_all, num_labels = load_data_for_cross_test()
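
# 10-fold cross-validation: a fresh end-to-end MCTN model is trained and
# evaluated on every fold, so no weights are shared across folds.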
flag = 0  # fold counter
for train, test in get_10_fold_data(data_1, num_labels):
    flag += 1
    data_1_train = data_1[train]
    data_2_train = data_2[train]
    y_train = labels_all[train]

    data_1_test = data_1[test]
    data_2_test = data_2[test]
    y_test = labels_all[test]
    end2end_model = E2E_MCTN_Model(configs, data_1_train, data_2_train)

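    # Manual epoch loop: the translation targets x2_* (mean-pooled embeddings of
    # the second modality) are recomputed every epoch, presumably so they keep
    # tracking the embedding weights that the joint model is updating.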
    for i in range(10):
        print('Epoch %d' % i)
        x2_train = end2end_model.embeding_model.predict(data_2_train).mean(1)
        x2_test = end2end_model.embeding_model.predict(data_2_test).mean(1)
        end2end_model.model.fit(
            x=[data_1_train],
            y=[x2_train, y_train],
            epochs=1,
            validation_data=(data_1_test, [x2_test, y_test]),
            batch_size=256,
            verbose=2,
        )
        predictions = end2end_model.model.predict(data_1_test)[-1]  # last output = prediction head (y)
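
        # Minimal per-fold evaluation sketch (not part of the original script);
        # it assumes a continuous sentiment-style target, as in CMU-MOSI, so
        # MAE is a reasonable summary of the prediction head's output.
        mae = np.mean(np.abs(predictions.squeeze() - y_test.squeeze()))
        print('fold %d, epoch %d: validation MAE %.4f' % (flag, i, mae))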

np.random.seed(123)
tf.set_random_seed(456)  # TF 1.x API (tf.random.set_seed in TF 2.x)

# arguments
args, configs = parse_args()
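# args carries the CLI options (e.g. args.feature below); configs is the parsed
# experiment configuration, indexed as configs['translation'][...] and
# configs['general'][...].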

# data load
# presumably toggles MCTN's cyclic (back-)translation objective
is_cycled = configs['translation']['is_cycled']
feats_dict = load_search_data()

print("FORMING SEQ2SEQ MODEL...")
features = args.feature  # e.g. ['a', 't']
assert len(features) == 2, 'Wrong number of features'
end2end_model = E2E_MCTN_Model(configs, features, feats_dict)

print("PREP FOR TRAINING...")
filename = '_'.join(args.feature) + "_attention_seq2seq_" + \
           ("bi_directional" if configs['translation']['is_bidirectional']
            else '') + \
           "_bimodal.h5"

output_dir = configs['general']['output_dir']
weights_path = os.path.join(output_dir, filename)
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

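# Save model weights to weights_path during training; the EarlyStopping callback
# on validation loss is left disabled.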
callbacks = [
    # EarlyStopping(monitor='val_loss', patience=args.train_patience, verbose=0),
    ModelCheckpoint(weights_path,