def build_model(lr=0.0, lr_d=0.0):
    """Build, train and reload a two-tower BERT model with an extra type input.

    Loads up to 100000 training rows, wires two BERT towers plus a
    ``type_len``-wide input through ``get_layer``, trains with checkpointing
    (best ``val_loss`` written to ``model_path``), early stopping and
    TensorBoard logging, then clears the session and reloads the checkpoint.

    Args:
        lr: Learning rate for ``Adam``. NOTE(review): the 0.0 default means
            the weights never move; callers presumably always pass a value.
        lr_d: Learning-rate decay for ``Adam``.

    Returns:
        The model reloaded from ``model_path``.
    """
    # Data
    build_vocab()
    X_query_text_pad, X_doc_text_pad, X_type_pad, y_ohe = get_data_all("train", 100000)

    # Layers: one BERT tower per text field, plus the type input.
    inp_type = Input(shape=(type_len,))
    embedding = BERTEmbedding(emn_path, sequence_length=50)
    base_model1 = embedding.model
    embedding2 = BERTEmbedding(emn_path, sequence_length=50)
    base_model2 = embedding2.model
    # Previously base_model1 was passed twice, leaving base_model2's inputs
    # declared on the Model below but disconnected from its output.
    out = get_layer(inp_type, base_model1, base_model2)

    # Callbacks
    check_point = ModelCheckpoint(model_path, monitor="val_loss", verbose=1,
                                  save_best_only=True, mode="min")
    early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=5)
    tb_cb = TensorBoard(log_dir=log_filepath)

    # Model: inputs are every BERT tower input followed by the type input.
    model = Model(inputs=[*base_model1.inputs, *base_model2.inputs, inp_type],
                  outputs=out)
    model.summary()

    # Train
    model.compile(loss="binary_crossentropy",
                  optimizer=Adam(lr=lr, decay=lr_d),
                  metrics=["accuracy"])
    # NOTE(review): `x` supplies 3 arrays while the model declares
    # len(base_model1.inputs) + len(base_model2.inputs) + 1 inputs; confirm the
    # arrays line up with the BERT token/segment inputs.
    # NOTE(review): class_weight="auto" is not a dict; confirm it is honoured
    # by the installed Keras version.
    model.fit(x=[X_query_text_pad, X_doc_text_pad, X_type_pad], y=y_ohe,
              batch_size=1, epochs=20, validation_split=0.5, verbose=1,
              class_weight="auto",
              callbacks=[check_point, early_stop, tb_cb])

    # Drop the training graph before reloading the best checkpoint.
    K.clear_session()
    tf.reset_default_graph()
    model = load_model(model_path)
    return model
def train_BERT_BiLSTM_CRF(
        train_test_devide=0.9,
        epoch=20,
        path='/home/peitian_zhang/data/corpus/labeled_train.txt'):
    """Train a BERT + BiLSTM-CRF sequence labeler and report train/test scores.

    The first ``int(len(corpus) * train_test_devide) + 1`` samples are used for
    training. The remainder is held out for validation during training and for
    the final "test" evaluation.

    Args:
        train_test_devide: Fraction of the corpus used for training.
        epoch: Number of epochs; also embedded in the save directory name.
        path: Labeled corpus file read by ``getTrain``.

    Returns:
        The trained ``BiLSTM_CRF_Model``.
    """
    train_x, train_y = getTrain(path)
    split = int(len(train_x) * train_test_devide) + 1
    x, y = train_x[:split], train_y[:split]
    test_x, test_y = train_x[split:], train_y[split:]

    bert = BERTEmbedding(
        model_folder='/home/peitian_zhang/data/chinese_L-12_H-768_A-12',
        sequence_length=400,
        task=kashgari.LABELING)
    model = BiLSTM_CRF_Model(bert)

    # Validate on the held-out split. The original passed the training data
    # as validation data, which made the validation metrics meaningless.
    if len(test_x) > 0:
        model.fit(x, y, test_x, test_y, epochs=epoch, batch_size=64)
    else:
        model.fit(x, y, epochs=epoch, batch_size=64)

    print('---------evaluate on train---------\n{}'.format(
        model.evaluate(x, y)))
    if len(test_x) > 0:
        print('---------evaluate on test----------\n{}'.format(
            model.evaluate(test_x, test_y)))

    try:
        model.save('/home/peitian_zhang/models/bert_epoch_{}'.format(epoch))
        print('Success in saving!')
    except Exception as err:  # saving is best-effort, but report failures
        print('Failed to save model: {}'.format(err))
    return model
def main(train=False, eval=False, predict=False, checkpoint=None):
    """Train, evaluate and/or run a BERT + BiGRU-CRF labeling model.

    Uses module-level ``train_x``/``train_y``, ``valid_x``/``valid_y`` and
    ``test_x``/``test_y``, which are defined elsewhere in the file.

    Args:
        train: Train, save to the model directory, then evaluate on the
            validation split.
        eval: Load the model from the model directory and evaluate it on the
            test split. (Shadows the builtin ``eval``; kept so callers work.)
        predict: Load the model and tag one sample sentence.
        checkpoint: Model directory shared by all stages. Defaults to
            ``'checkpoint'``.
    """
    # One directory for every stage. Previously training always saved to the
    # literal 'checkpoint' while eval/predict loaded from the argument.
    model_dir = 'checkpoint' if checkpoint is None else checkpoint

    if train:
        bert_embed = BERTEmbedding(
            '/media/ding/Data/model_weight/nlp_weights/tensorflow/google/chinese_L-12_H-768_A-12',
            task=kashgari.LABELING,
            sequence_length=150)
        # Alternatives: `CNN_LSTM_Model`, `BiLSTM_Model`, `BiGRU_Model`,
        # `BiGRU_CRF_Model` or `BiLSTM_CRF_Model`.
        model = BiGRU_CRF_Model(bert_embed)
        model.fit(train_x, train_y,
                  x_validate=valid_x, y_validate=valid_y,
                  epochs=10, batch_size=64)
        model.save(model_dir)
        print('验证集的测试结果:')
        model.evaluate(valid_x, valid_y)  # validation-set results

    if eval:
        model = kashgari.utils.load_model(model_dir)
        print('测试集的测试结果:')
        model.evaluate(test_x, test_y)  # test-set results

    if predict:
        model = kashgari.utils.load_model(model_dir)
        out = model.predict(
            ['十 一 月 , 我 们 被 称 为 “ 南 京 市 首 届 家 庭 藏 书 状 元 明 星 户 ” 。'.split()])
        print(out)
def train_ner(x_train, y_train, x_valid, y_valid, x_test, y_test,
              sequence_length, epoch, batch_size, bert_model_path,
              model_save_path):
    """Train a BERT-BiLSTM-CRF model that extracts internal symptom features.

    Per-epoch metrics are reported on both the validation and test splits.
    The trained model is saved to ``model_save_path`` and then evaluated on
    the test split.

    Returns:
        The trained ``BiLSTM_CRF_Model``.
    """
    embedding = BERTEmbedding(bert_model_path,
                              task=kashgari.LABELING,
                              sequence_length=sequence_length)
    ner_model = BiLSTM_CRF_Model(embedding)

    # One evaluation callback per split: validation first, then test.
    monitors = [
        EvalCallBack(kash_model=ner_model, valid_x=split_x, valid_y=split_y, step=1)
        for split_x, split_y in ((x_valid, y_valid), (x_test, y_test))
    ]
    ner_model.fit(x_train, y_train,
                  x_validate=x_valid, y_validate=y_valid,
                  epochs=epoch, batch_size=batch_size,
                  callbacks=monitors)

    ner_model.save(model_save_path)
    ner_model.evaluate(x_test, y_test)
    return ner_model
def main():
    """Train, save, evaluate and run a BERT + BiLSTM-CRF model on ChineseDailyNer.

    The model must be built (here, by training) before ``save``, ``evaluate``
    or ``predict`` can run. Previously training was commented out, so the
    function failed at ``save``. A stray ``ChineseDailyNerCorpus.__zip_file__name``
    expression (a no-op at best, an AttributeError at worst) is also removed.
    """
    train_x, train_y = ChineseDailyNerCorpus.load_data("train")
    valid_x, valid_y = ChineseDailyNerCorpus.load_data("validate")
    test_x, test_y = ChineseDailyNerCorpus.load_data("test")
    print(f"train data count: {len(train_x)}")
    print(f"validate data count: {len(valid_x)}")
    print(f"test data count: {len(test_x)}")

    bert_embed = BERTEmbedding("models/chinese_L-12_H-768_A-12",
                               task=kashgari.LABELING,
                               sequence_length=100)
    model = BiLSTM_CRF_Model(bert_embed)
    model.fit(
        train_x,
        train_y,
        x_validate=valid_x,
        y_validate=valid_y,
        epochs=1,
        batch_size=512,
    )
    model.save("models/ner.h5")
    model.evaluate(test_x, test_y)
    # Kashgari models expose `predict`; `predict_classes` is not a model method.
    predictions = model.predict(test_x)
    print(predictions)
def initial_total_model(self):
    """Load both the prosody and the phone BiLSTM-CRF models.

    Both models share one BERT embedding (``sequence_length=50``) and run in a
    dedicated TF session stored on ``self.sess``. Each model is rebuilt from
    its pickled ``(train_data, train_label, test_data, test_label)`` tuple,
    compiled, and then has its saved weights loaded.
    """
    print('=============init bert model=========================')
    print("bert model path:", self._bert_model_path)
    print("phone model path:", self._phone_model_path)
    print("prosody model path:", self._prosody_model_path)
    self.sess = tf.Session()
    set_session(self.sess)
    self._embed_model = BERTEmbedding(self._bert_model_path,
                                      task=kashgari.LABELING,
                                      sequence_length=50)

    def _load_crf_model(data_path, weights_path):
        """Build a BiLSTM-CRF from pickled data, compile it and load weights."""
        model = BiLSTM_CRF_Model(self._embed_model)
        # `with` closes the file (the original leaked the handle).
        # pickle: only point these paths at trusted files.
        with open(data_path, 'rb') as fh:
            train_data, train_label, test_data, test_label = pickle.load(fh)
        model.build_model(x_train=train_data, y_train=train_label,
                          x_validate=test_data, y_validate=test_label)
        model.compile_model()
        model.tf_model.load_weights(weights_path)
        return model

    self._prosody_model = _load_crf_model(self._psd_data_path,
                                          self._prosody_model_path)
    self._phone_model = _load_crf_model(self._data_path,
                                        self._phone_model_path)
    print('============= model loaded=========================')
    return
def build(self):
    """Return a new, untrained BiLSTM-CRF labeler on a BERT embedding.

    The embedding is configured from ``self.folder``, ``self.fine_tune``
    (whether BERT weights are trainable) and ``self.seq_len``.
    """
    return BiLSTM_CRF_Model(
        BERTEmbedding(
            model_folder=self.folder,
            task=kashgari.LABELING,
            trainable=self.fine_tune,
            sequence_length=self.seq_len,
        )
    )
def get_bert(cls):
    """Return the shared ``bert-base-chinese`` embedding, creating it lazily."""
    if cls.bert_embedding is not None:
        return cls.bert_embedding
    cls.bert_embedding = BERTEmbedding('bert-base-chinese', sequence_length=15)
    seq_len = cls.bert_embedding.sequence_length
    logging.info('bert_embedding seq len: {}'.format(seq_len))
    return cls.bert_embedding
def main():
    """Train a BERT + BiGRU-CRF Chinese word segmenter, evaluate it, save it.

    Precision/recall/F1 are reported on the dev split every 5 epochs, and
    TensorBoard logs go to ``./logs``.
    """
    train_path = '/home/qianlang/WordSeg-master/Data/train/data_generate_train.utf8'
    # NOTE(review): dev_path is the training file, so dev metrics measure fit,
    # not generalisation. A dedicated dev file (data_generate_dev.utf8) was
    # used previously.
    dev_path = '/home/qianlang/WordSeg-master/Data/train/data_generate_train.utf8'
    test_path = '/home/qianlang/WordSeg-master/Data/test/data_generate_test.utf8'

    (train_x, train_y), (dev_x, dev_y), (test_x, test_y) = (
        load_dataset(p) for p in (train_path, dev_path, test_path))

    segmenter = BiGRU_CRF_Model(
        BERTEmbedding('chinese_wwm_ext_L-12_H-768_A-12',
                      task=kashgari.LABELING,
                      sequence_length=128))

    board = keras.callbacks.TensorBoard(log_dir='./logs', update_freq=1000)
    scorer = EvalCallBack(kash_model=segmenter, valid_x=dev_x, valid_y=dev_y, step=5)

    segmenter.fit(train_x, train_y, dev_x, dev_y,
                  batch_size=256, callbacks=[scorer, board])
    segmenter.evaluate(test_x, test_y)
    segmenter.save('cws.h5')
def __init__(self, component_config=None, model=None):
    """Set up a BERT embedding for intent classification from component config.

    Reads ``bert_model_path``, ``sequence_length``, ``layer_nums``,
    ``trainable``, ``use_cudnn_cell``, ``multi_label``, ``split_symbol`` and
    ``classifier_model`` from ``self.component_config``. ``model`` is an
    optional pre-loaded model kept on ``self.model``.
    """
    super(KashgariIntentClassifier, self).__init__(component_config)
    option = self.component_config.get

    self.multi_label = option('multi_label')
    self.split_symbol = option('split_symbol')
    # Global Kashgari switch; must be set before any layers are created.
    kashgari.config.use_cudnn_cell = option('use_cudnn_cell')
    processor = ClassificationProcessor(multi_label=self.multi_label)
    self.classifier_model = option('classifier_model')

    self.bert_embedding = BERTEmbedding(
        option('bert_model_path'),
        task=kashgari.CLASSIFICATION,
        layer_nums=option('layer_nums'),
        trainable=option('trainable'),
        processor=processor,
        sequence_length=option('sequence_length'),
    )
    self.tokenizer = self.bert_embedding.tokenizer
    self.model = model
def get_bert(cls):
    """Return the shared test BERT embedding, creating it on first use.

    The checkpoint lives in ``data/test_bert_checkpoint`` next to this file.
    """
    if cls.bert_embedding is not None:
        return cls.bert_embedding
    here = os.path.dirname(os.path.realpath(__file__))
    cls.bert_embedding = BERTEmbedding(
        os.path.join(here, 'data', 'test_bert_checkpoint'),
        sequence_length=SEQUENCE_LENGTH)
    return cls.bert_embedding
def test_bert_embedding(self):
    """Stack a BERT text embedding with a numeric feature embedding.

    Checks that embedding 3 samples yields an output of shape ``(3, 12, 24)``.
    """
    text, label = ChineseDailyNerCorpus.load_data()
    # Random per-token numeric feature with values in {1, 2}.
    is_bold = np.random.randint(1, 3, (len(text), 12))

    bert_path = get_file('bert_sample_model',
                         "http://s3.bmio.net/kashgari/bert_sample_model.tar.bz2",
                         cache_dir=DATA_PATH,
                         untar=True)
    stacked = StackedEmbedding([
        BERTEmbedding(bert_path, task=kashgari.LABELING, sequence_length=12),
        NumericFeaturesEmbedding(2, 'is_bold', sequence_length=12),
    ])
    stacked.analyze_corpus((text, is_bold), label)

    tensor = stacked.process_x_dataset((text[:3], is_bold[:3]))
    for part in (tensor[0][0], tensor[0][1], tensor[1]):
        print(part.shape)
    print(stacked.embed_model.input_shape)
    print(stacked.embed_model.summary())

    embedded = stacked.embed((text[:3], is_bold[:3]))
    assert embedded.shape == (3, 12, 24)
def cal_cosine():
    """Print the cosine similarity of two sentences' summed BERT vectors.

    Each sentence is segmented with jieba and embedded with ``embed_one``. It
    is then collapsed into a single ``(1, dim)`` row by summing over the first
    (token) axis, padding positions included.
    """
    bert = BERTEmbedding(chinese_bert_file,
                         task=kashgari.CLASSIFICATION,
                         sequence_length=10)

    seg_list1 = list(jieba.cut("我来到北京清华大学", cut_all=False))
    seg_list2 = list(jieba.cut("天然矿泉水是指从地下深处自然涌出的或钻井采集的", cut_all=False))

    embed_tensor1 = bert.embed_one(seg_list1)
    embed_tensor2 = bert.embed_one(seg_list2)
    print(embed_tensor1.shape)
    print(embed_tensor2.shape)

    # Sum each tensor over its own first axis. The original accumulated into
    # hard-coded (1, 3072) buffers, which only fit a 4 x 768 layer concat, and
    # indexed tensor 2 with tensor 1's length.
    embedding1 = np.asarray(embed_tensor1, dtype=np.float64).sum(axis=0, keepdims=True)
    embedding2 = np.asarray(embed_tensor2, dtype=np.float64).sum(axis=0, keepdims=True)
    print(embedding1)
    print(embedding2)

    cos_value = cosine_similarity(embedding1, embedding2)
    print('cos_value =', str(cos_value[0][0]))
def main():
    """Run a saved NER SavedModel on a few sentences and print the CRF output.

    Inputs are vectorized with a Kashgari ``BERTEmbedding``. Its
    ``process_x_dataset`` returns ``(token_ids, segment_ids)``, and these are
    fed to the ``Input-Token`` / ``Input-Segment`` placeholders respectively.
    """
    examples = [
        "《中国风水十讲》是2007年华夏出版社出版的图书,作者是杨文衡",
        "你是最爱词:许常德李素珍/曲:刘天健你的故事写到你离去后为止",
        "《苏州商会档案丛编第二辑》是2012年华中师范大学出版社出版的图书,作者是马敏、祖苏、肖芃"
    ]
    sess = tf.compat.v1.Session()
    try:
        model_path = "/home/johnsaxon/github.com/oushu1zhangxiangxuan1/HolmesNER/serving/savedmodel_loader/models/ner/m1"
        tf.compat.v1.saved_model.loader.load(sess, [tf.saved_model.SERVING], model_path)
        prediction = sess.graph.get_tensor_by_name("layer_crf/cond/Merge:0")

        bert_embed = BERTEmbedding(
            "/home/johnsaxon/github.com/Entity-Relation-Extraction/pretrained_model/chinese_L-12_H-768_A-12",
            task=kashgari.LABELING,
            sequence_length=100)
        x0, x1 = bert_embed.process_x_dataset(examples)
        print(x0, x1)

        # x0 holds token ids and x1 segment ids; the original feed swapped them.
        predictions_result = sess.run(prediction, feed_dict={
            'Input-Token_1:0': x0,
            'Input-Segment_1:0': x1
        })
    finally:
        # Always release the session, even if loading or inference fails.
        sess.close()
    print(predictions_result)
def train():
    """Train (or resume) a BERT-based labeling model from pickled features.

    Command line: one required positional ``model_dir``. The directory must
    contain ``feature.pkl``. Checkpoints go to ``<model_dir>/hdf5``,
    TensorBoard logs to ``<model_dir>/log`` and eval output to
    ``<model_dir>/acc.txt``. If ``get_model_path`` reports a previous epoch,
    its weights are loaded before training.
    """
    parser = argparse.ArgumentParser()
    # `default` has no effect on a required positional argument; the string
    # was evidently meant as help text.
    parser.add_argument('model_dir', help='model dir')
    args = parser.parse_args()
    model_dir = args.model_dir

    hdf_dir = os.path.join(model_dir, "hdf5")
    os.makedirs(hdf_dir, exist_ok=True)
    bert_model_path = os.path.join(ROOT_DIR, 'BERT-baseline')
    data_path = os.path.join(model_dir, "feature.pkl")
    with open(data_path, 'rb') as fr:
        # pickle: only load feature files you produced yourself.
        train_data, train_label, test_data, test_label = pickle.load(fr)
    print("load {}/{} train/dev items ".format(len(train_data), len(test_data)))

    bert_embed = BERTEmbedding(bert_model_path,
                               task=kashgari.LABELING,
                               sequence_length=50)
    model = KashModel(bert_embed)
    model.build_model(x_train=train_data, y_train=train_label,
                      x_validate=test_data, y_validate=test_label)

    from src.get_model_path import get_model_path
    model_path, init_epoch = get_model_path(hdf_dir)
    if init_epoch > 0:
        print("load epoch from {}".format(model_path))
        model.tf_model.load_weights(model_path)

    optimizer = RAdam(learning_rate=0.0001)
    model.compile_model(optimizer=optimizer)

    hdf5_path = os.path.join(hdf_dir, "crf-{epoch:03d}-{val_accuracy:.3f}.hdf5")
    checkpoint = ModelCheckpoint(hdf5_path, monitor='val_accuracy', verbose=1,
                                 save_best_only=True, save_weights_only=False,
                                 mode='auto', period=1)
    tensorboard = TensorBoard(log_dir=os.path.join(model_dir, "log"))
    eval_callback = EvalCallBack(kash_model=model, valid_x=test_data,
                                 valid_y=test_label, step=1,
                                 log_path=os.path.join(model_dir, "acc.txt"))
    callbacks = [checkpoint, tensorboard, eval_callback]

    model.fit(train_data, train_label,
              x_validate=test_data, y_validate=test_label,
              epochs=100, batch_size=256, callbacks=callbacks)
    return
def get_bert(cls):
    """Return the shared test BERT embedding (seq len 15), creating it lazily.

    The checkpoint lives in ``data/test_bert_checkpoint`` next to this file.
    The resulting sequence length is logged once, on creation.
    """
    if cls.bert_embedding is not None:
        return cls.bert_embedding
    checkpoint_dir = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        'data',
        'test_bert_checkpoint',
    )
    cls.bert_embedding = BERTEmbedding(checkpoint_dir, sequence_length=15)
    logging.info('bert_embedding seq len: {}'.format(
        cls.bert_embedding.sequence_length))
    return cls.bert_embedding
def train(self, datasets):
    """Train the configured Kashgari labeling model on a BERT embedding.

    Hyper-parameters (``epochs``, ``batch_size``, ``patience``, ``factor``,
    ``verbose``) come from ``self.defaults``. ``self.labeling_model`` names a
    class in the ``labeling`` module. Callbacks: best-``val_loss`` checkpoint
    to ``entity_weights.h5``, early stopping, LR reduction on plateau, and
    TensorBoard logging to a freshly emptied ``path.LOG_PATH``.

    Args:
        datasets: Raw data passed to ``self.get_dataset`` to obtain
            ``(x_train, y_train, x_val, y_val)``.
    """
    epochs = self.defaults.get('epochs')
    batch_size = self.defaults.get('batch_size')
    patience = self.defaults.get('patience')
    # Factor by which the learning rate is reduced on a plateau.
    factor = self.defaults.get('factor')
    verbose = self.defaults.get('verbose')

    x_train, y_train, x_val, y_val = self.get_dataset(datasets)
    self.bert_embedding = BERTEmbedding(self.bert_model_path,
                                        task=kashgari.LABELING,
                                        layer_nums=self.layer_nums,
                                        trainable=self.trainable,
                                        sequence_length=self.sequence_length)
    # Resolve the model class by name with getattr rather than eval(), so a
    # config value can never execute arbitrary code.
    labeling_model = getattr(labeling, self.labeling_model)
    self.model = labeling_model(self.bert_embedding)

    # Callbacks
    checkpoint = ModelCheckpoint(
        'entity_weights.h5',
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=False,
        verbose=verbose)
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=patience)
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=factor,
        patience=patience,
        verbose=verbose)

    log_dir = path.LOG_PATH
    if os.path.exists(log_dir):
        # Start each run with an empty TensorBoard log directory.
        shutil.rmtree(log_dir)
    tensor_board = TensorBoard(
        log_dir=log_dir,
        batch_size=batch_size)

    # Train
    self.model.fit(
        x_train,
        y_train,
        x_val,
        y_val,
        epochs=epochs,
        batch_size=batch_size,
        callbacks=[checkpoint, early_stopping, reduce_lr, tensor_board]
    )
def train():
    """Train a BERT + CNN sentiment classifier on the weibo60000 corpus.

    Positive and negative samples are cut into train/val/test at fixed
    indices, then merged per split with ``concate_data``. The trained model is
    evaluated on the test split and saved to ``./model/cnn_bert_model``.
    """
    pos_x, pos_y = read_pos_data('../dataset/weibo60000/pos60000_utf8.txt_updated')
    print(len(pos_x))
    print(len(pos_y))
    neg_x, neg_y = read_neg_data('../dataset/weibo60000/neg60000_utf8.txt_updated')
    print(len(neg_x))
    print(len(neg_y))

    def three_way(seq, train_end, val_end):
        """Return (train, val, test) slices of seq at the given boundaries."""
        return seq[:train_end], seq[train_end:val_end], seq[val_end:]

    pos_parts_x = three_way(pos_x, 41025, 52746)
    pos_parts_y = three_way(pos_y, 41025, 52746)
    neg_parts_x = three_way(neg_x, 41165, 52926)
    neg_parts_y = three_way(neg_y, 41165, 52926)

    # Merge positive and negative parts split by split: train, val, test.
    (train_x, train_y), (val_x, val_y), (test_x, test_y) = [
        concate_data(px, py, nx, ny)
        for px, py, nx, ny in zip(pos_parts_x, pos_parts_y, neg_parts_x, neg_parts_y)
    ]
    print('The number of train-set:', len(train_x))
    print('The number of val-set:', len(val_x))
    print('The number of test-set:', len(test_x))

    embedding = BERTEmbedding('../dataset/chinese_L-12_H-768_A-12', sequence_length=100)
    print('embedding_size', embedding.embedding_size)

    classifier = CNNModel(embedding)
    classifier.fit(train_x, train_y, val_x, val_y,
                   batch_size=128, epochs=20,
                   fit_kwargs={'callbacks': [tf_board_callback]})
    classifier.evaluate(test_x, test_y)
    classifier.save('./model/cnn_bert_model')
def __init__(self, hyper_parameters):
    """Build a single-label BERT classification embedding from hyper-parameters.

    ``hyper_parameters['model']`` must provide ``bert_model_path``,
    ``layer_nums`` and ``trainable``. CuDNN cells are disabled globally and
    the sequence length is left to Kashgari (``'auto'``). The tokenizer
    vocabulary size is printed.
    """
    kashgari.config.use_cudnn_cell = False
    single_label = ClassificationProcessor(multi_label=False)
    model_cfg = hyper_parameters['model']

    self.bert_embedding = BERTEmbedding(
        model_cfg['bert_model_path'],
        task=kashgari.CLASSIFICATION,
        layer_nums=model_cfg['layer_nums'],
        trainable=model_cfg['trainable'],
        processor=single_label,
        sequence_length='auto',
    )
    vocab_size = len(self.bert_embedding._tokenizer._token_dict_inv)
    print(vocab_size)
    self.tokenizer = self.bert_embedding.tokenizer
def test_bert(self):
    """Fit on the fixture data with a BERT embedding and check predict's types.

    A single tokenized sentence should predict to a ``str``; a batch of
    sentences should predict to a ``list``.
    """
    self.prepare_model(BERTEmbedding('chinese_L-12_H-768_A-12', sequence_length=30))
    model = self.model
    model.fit(self.x_data, self.y_data,
              x_validate=self.x_eval, y_validate=self.y_eval)

    sentence = list('语言学包含了几种分支领域。')
    logging.info(model.embedding.tokenize(sentence))
    logging.info(model.predict(sentence))
    self.assertIsInstance(model.predict(sentence), str)
    self.assertIsInstance(model.predict([sentence]), list)
def train(self):
    """Train a BERT + BLSTM classifier on the yingyangshi train/dev files.

    Runs 2 epochs with batch size 64 and saves the model under the parent
    directory.
    """
    train_items, train_labels = self.read_message('../data/yingyangshi/train.txt')
    dev_items, dev_labels = self.read_message('../data/yingyangshi/dev.txt')

    # Character-level BERT vectors.
    classifier = BLSTMModel(BERTEmbedding('textclassfation/input0/chinese_L-12_H-768_A-12'))
    classifier.fit(train_items, train_labels, dev_items, dev_labels,
                   epochs=2, batch_size=64)

    classifier.save("../健康管理师单选分字BERT-model")
def train(self):
    """Train a BERT + CNN classifier on ``read_message()`` data and save it.

    Runs 200 epochs (batch 32) with TensorBoard logging, saves to
    ``output/classification-model`` and then evaluates.
    """
    samples, labels = read_message()

    # Character-level BERT vectors, sequences capped at 256.
    classifier = CNNModel(BERTEmbedding(self.bert_place, sequence_length=256))
    classifier.fit(samples, labels,
                   epochs=200, batch_size=32,
                   fit_kwargs={'callbacks': [tf_board_callback]})

    classifier.save("output/classification-model")
    # NOTE(review): this evaluates on the training data, so the score is
    # optimistic; there is no held-out split here.
    classifier.evaluate(samples, labels)
def train(train_x, train_y):
    """Train a BERT + BLSTM classifier on a 90/10 ordered train/validation split.

    The split is positional (no shuffling). The model is saved to
    ``BLSTM_model``.
    """
    model = BLSTMModel(BERTEmbedding("bert-base-chinese", sequence_length=512))

    cut = int(len(train_x) * 0.9)
    fit_x, fit_y = train_x[:cut], train_y[:cut]
    print(len(fit_x), len(fit_y))

    model.fit(fit_x, fit_y, train_x[cut:], train_y[cut:])
    model.save('BLSTM_model')
def create_model(self):
    """Return the KMax-CNN classifier, loading it if ``self.model_path`` exists.

    Otherwise a new, untrained model is built on a BERT embedding from
    ``self.bert_path`` with sequence length ``self.max_len``.
    """
    if not os.path.exists(self.model_path):
        print("Creating bert embedding KMax CNN Model...")
        bert = BERTEmbedding(model_folder=self.bert_path,
                             task=kashgari.CLASSIFICATION,
                             sequence_length=self.max_len)
        return KMax_CNN_Model(bert)

    print("Loading bert embedding KMax CNN Model...")
    return kashgari.utils.load_model(self.model_path)
def test_bert_model(self):
    """Train briefly, then check predictions survive a save/load round trip.

    Also checks that one prediction is returned per input sample.
    """
    embedding = BERTEmbedding(bert_path,
                              task=kashgari.CLASSIFICATION,
                              sequence_length=100)
    model = BLSTMModel(embedding=embedding)
    model.fit(valid_x, valid_y, epochs=1)

    sample = valid_x[:20]
    res = model.predict(sample)
    # Replaces a vacuous `assert True`: one prediction per input sample.
    assert len(res) == len(sample)

    model_path = os.path.join(tempfile.gettempdir(), str(time.time()))
    model.save(model_path)

    new_model = kashgari.utils.load_model(model_path)
    new_res = new_model.predict(sample)
    assert np.array_equal(new_res, res)
def __init__(self, ):
    """Configure CRF_BERT hyper-parameters and its BERT embedding.

    The embedding is loaded from a fixed local ``chinese_L-12_H-768_A-12``
    checkpoint with a 512-token sequence length.
    """
    super(CRF_BERT, self).__init__()
    self.agg = True
    self.parms = {
        # NOTE(review): MAXLEN (768) differs from the embedding's
        # sequence_length (512) below; confirm which value is intended.
        'MAXLEN': 768,
        'n_entity': 4,
        'embed_size': 256,
        'n_lstm': 256,
        'epochs': 5,
        'batch_size': 512,
    }
    from kashgari.embeddings import BERTEmbedding
    model_path = "/sdb1/zhangle/2019/zhangle11/bert/chinese_L-12_H-768_A-12"
    # Pass by keyword: in Kashgari 1.x the second positional parameter is
    # `layer_nums`, so a bare 512 would request 512 BERT layers.
    self.embedding = BERTEmbedding(model_path, sequence_length=512)
def train(self):
    """Train a BERT + BLSTM classifier on the 西药执业药师 train/dev files.

    Runs 8 epochs (batch 256) with TensorBoard logging and saves the model
    under the parent directory.
    """
    train_items, train_labels = self.read_message('../data/西药执业药师/train.txt')
    dev_items, dev_labels = self.read_message('../data/西药执业药师/dev.txt')

    # Character-level BERT vectors, sequences capped at 200.
    classifier = BLSTMModel(BERTEmbedding('bert-base-chinese', sequence_length=200))
    classifier.fit(train_items, train_labels, dev_items, dev_labels,
                   epochs=8, batch_size=256,
                   fit_kwargs={'callbacks': [tf_board_callback]})

    classifier.save("../西药执业药师-model")
def __init__(self, component_config=None, model=None):
    """Set up a BERT labeling embedding for entity extraction from config.

    Reads ``bert_model_path``, ``sequence_length``, ``layer_nums``,
    ``trainable`` and ``labeling_model`` from ``self.component_config``.
    ``model`` is an optional pre-loaded model kept on ``self.model``.
    """
    super(KashgariEntityExtractor, self).__init__(component_config)
    option = self.component_config.get

    self.labeling_model = option('labeling_model')
    self.bert_embedding = BERTEmbedding(
        option('bert_model_path'),
        task=kashgari.LABELING,
        layer_nums=option('layer_nums'),
        trainable=option('trainable'),
        sequence_length=option('sequence_length'),
    )
    self.model = model
def __init__(self, component_config=None, model=None):
    """Set up a BERT classification embedding for intent detection from config.

    Reads ``bert_model_path``, ``sequence_length``, ``layer_nums``,
    ``trainable`` and ``classifier_model`` from ``self.component_config``.
    ``model`` is an optional pre-loaded model kept on ``self.model``.
    """
    super(KashgariIntentClassifier, self).__init__(component_config)
    option = self.component_config.get

    self.classifier_model = option('classifier_model')
    self.bert_embedding = BERTEmbedding(
        option('bert_model_path'),
        task=kashgari.CLASSIFICATION,
        layer_nums=option('layer_nums'),
        trainable=option('trainable'),
        sequence_length=option('sequence_length'),
    )
    self.tokenizer = self.bert_embedding.tokenizer
    self.model = model
def test_build_with_BERT_and_fit(self):
    """Fit a two-output model on the sample BERT checkpoint.

    Checks that ``predict`` returns two outputs, the first of shape
    ``(10, 3)``.
    """
    from kashgari.embeddings import BERTEmbedding
    from tensorflow.python.keras.utils import get_file
    from kashgari.macros import DATA_PATH

    sample_bert_path = get_file('bert_sample_model',
                                "http://s3.bmio.net/kashgari/bert_sample_model.tar.bz2",
                                cache_dir=DATA_PATH,
                                untar=True)
    targets = (output_1, output_2)

    model = MultiOutputModel(
        embedding=BERTEmbedding(model_folder=sample_bert_path,
                                processor=MultiOutputProcessor()))
    model.build_model(train_x, targets)
    model.fit(train_x, targets, epochs=2)

    first_out, *rest = model.predict(train_x[:10])
    assert 1 + len(rest) == 2
    assert first_out.shape == (10, 3)