# 10-fold cross-validation over the multimodal data. Each iteration of
# get_10_fold_data yields (train, test) index arrays used to slice the
# two modality arrays and the label array.
# NOTE(review): `flag`, `data_1`, `data_2`, `labels_all`, `num_labels`,
# `configs`, and `weights_path` must be defined earlier in the file.
for train, test in get_10_fold_data(data_1, num_labels):
    flag += 1  # running fold counter — presumably initialized to 0 before this loop; verify
    data_1_train = data_1[train]
    data_2_train = data_2[train]
    y_train = labels_all[train]

    data_1_test = data_1[test]
    data_2_test = data_2[test]
    y_test = labels_all[test]
    # Fresh model per fold so no weights leak across folds.
    end2end_model = E2E_MCTN_Model(configs, data_1_train, data_2_train)

    # Manual epoch loop: the modality-2 embedding targets are recomputed
    # from the current embedding model before every single-epoch fit.
    for i in range(10):
        print('now it is epoch %d' % i)
        # .mean(1) presumably averages over the sequence/time axis of the
        # embedding output — TODO confirm the expected tensor rank.
        x2_train = end2end_model.embeding_model.predict(data_2_train).mean(1)
        x2_test = end2end_model.embeding_model.predict(data_2_test).mean(1)
        end2end_model.model.fit(
            x=[data_1_train],
            # Two targets: the modality-2 embedding (translation target)
            # and the labels — the model has (at least) two outputs.
            y=[x2_train, y_train],
            epochs=1,
            validation_data=[data_1_test, [x2_test, y_test]],
            # self.input_test,
            batch_size=256,
            verbose=2,
        )
        # Last model output is the label prediction head (see y above).
        predictions = end2end_model.model.predict(data_1_test)[-1]
        get_preds_statistics(predictions, y_test)

    print(
        '--------------------------finish %d cross validation------------------------------------------------'
        % flag)
    # NOTE(review): the expression below builds a one-element tuple
    # (ModelCheckpoint(...),) and discards it every fold — it looks like
    # residue of a `callbacks = [...]` list (note the stray `]` on the
    # next line) that was misplaced by an edit. Confirm against the
    # original file and move it back into the callbacks definition.
    ModelCheckpoint(weights_path,
                    monitor='val_loss',
                    save_best_only=True,
                    verbose=1),
]

# Warm start: restore previously saved weights if a usable checkpoint
# exists; otherwise fall back to training from scratch.
try:
    end2end_model.model.load_weights(weights_path)
    print("\nWeights loaded from {}\n".format(weights_path))
except Exception:
    # Narrowed from a bare `except:` — a bare except also swallows
    # KeyboardInterrupt/SystemExit and masks genuine bugs such as
    # NameError. Loading can legitimately fail with OSError (missing
    # file) or ValueError (architecture mismatch); we treat any
    # Exception as "no usable checkpoint" and train from scratch.
    print("\nCannot load weight. Training from scratch\n")

# Training phase. `train` acts as a module-level switch: set it to 0 to
# skip training and go straight to prediction with the loaded weights.
print("TRAINING NOW...")
train = 1

if train:
    history = end2end_model.train(
        weights_path=weights_path,
        n_epochs=args.train_epoch,
        val_split=args.val_split,
        batch_size=args.batch_size,
        callbacks=callbacks,
    )

    # Persist the Keras training history for later analysis, using the
    # highest pickle protocol available.
    with open('history_params.sav', 'wb') as f:
        pickle.dump(history.history, f, pickle.HIGHEST_PROTOCOL)

# Evaluation phase: run inference on the held-out test inputs and report
# summary statistics against the ground-truth test labels.
print("PREDICTING...")
predictions = end2end_model.predict()
# predictions = predictions.reshape(-1, )
# NOTE(review): feats_dict is defined outside this view — presumably a
# dict of preprocessed splits keyed by strings like 'test_labels'; verify.
get_preds_statistics(predictions, feats_dict['test_labels'])