def train_embed(data_path, train_list, val_list,
                px_net_path, pt_net_path, qy_net_path, qz_net_path,
                init_path, epochs, preproc_file, output_path,
                freeze_embed, **kwargs):
    """Train a VAE sequence-embedding model and save it under output_path.

    The model is obtained in one of three ways:
      * ``init_path`` is None and ``KML.load_checkpoint`` returns a model:
        training resumes from the returned epoch.
      * ``init_path`` is None and no model is returned: a new ``VAE`` is
        built from the four network architecture files.
      * ``init_path`` is given: the model is loaded from it and its loss
        weights are overwritten from ``kwargs``.

    Args:
        data_path: Passed to the sequence generator ``G`` as data location.
        train_list: Passed to ``G`` to build the training generator.
        val_list: Passed to ``G`` to build the validation generator, or
            None to train without validation.
        px_net_path: Architecture file loaded with ``load_model_arch`` (px net).
        pt_net_path: Architecture file loaded with ``load_model_arch`` (pt net).
        qy_net_path: Architecture file loaded with ``load_model_arch`` (qy net).
        qz_net_path: Architecture file loaded with ``load_model_arch`` (qz net).
        init_path: Model file to initialise from, or None.
        epochs: Final epoch index passed to ``fit_generator`` and to
            ``KML.load_checkpoint``.
        preproc_file: File loaded with ``TransformList.load``, or None for
            no preprocessing transform.
        output_path: Directory used for checkpoints/callbacks; the final
            model is saved to ``output_path + '/model'``.
        freeze_embed: If truthy, ``model.prepool_net.trainable`` is set to
            False before the model is built.
        **kwargs: Extra options, filtered by ``G.filter_args``,
            ``VAE.filter_args``, ``KOF.filter_args`` and ``KCF.filter_args``.
            When ``init_path`` is given it must contain ``px_weight``,
            ``pt_weight``, ``kl_qy_weight`` and ``kl_qz_weight``.

    Raises:
        KeyError: If ``init_path`` is given and one of the loss-weight keys
            is missing from ``kwargs``.
    """
    # Kept bound for the duration of training.
    # NOTE(review): presumably the returned handle holds the GPU
    # reservation -- confirm before removing.
    g = reserve_gpu()
    set_float_cpu(float_keras())

    preproc = None
    if preproc_file is not None:
        preproc = TransformList.load(preproc_file)

    sg_args = G.filter_args(**kwargs)
    sg = G(data_path, train_list,
           shuffle_seqs=True, reset_rng=False,
           transform=preproc, **sg_args)
    max_length = sg.max_seq_length

    # Validation is optional: both stay None when no validation list is given.
    sg_val = None
    gen_val = None
    if val_list is not None:
        sg_val = G(data_path, val_list,
                   transform=preproc,
                   shuffle_seqs=False, reset_rng=True,
                   **sg_args)
        max_length = max(max_length, sg_val.max_seq_length)
        gen_val = data_generator(sg_val, max_length)

    # Built after the validation generator so both use the same max_length.
    gen_train = data_generator(sg, max_length)

    if init_path is None:
        model, init_epoch = KML.load_checkpoint(output_path, epochs)
        if model is None:
            embed_args = VAE.filter_args(**kwargs)
            logging.debug(embed_args)
            px_net = load_model_arch(px_net_path)
            qy_net = load_model_arch(qy_net_path)
            qz_net = load_model_arch(qz_net_path)
            pt_net = load_model_arch(pt_net_path)
            model = VAE(px_net, qy_net, qz_net, pt_net, **embed_args)
        else:
            # Resume: move the training generator to the checkpoint epoch.
            sg.cur_epoch = init_epoch
            sg.reset()
    else:
        logging.info('loading init model: %s', init_path)
        model = KML.load(init_path)
        # Override the loss weights stored in the loaded model.
        model.px_weight = kwargs['px_weight']
        model.pt_weight = kwargs['pt_weight']
        model.kl_qy_weight = kwargs['kl_qy_weight']
        model.kl_qz_weight = kwargs['kl_qz_weight']

    opt_args = KOF.filter_args(**kwargs)
    cb_args = KCF.filter_args(**kwargs)
    logging.debug(sg_args)
    logging.debug(opt_args)
    logging.debug(cb_args)

    logging.info('max length: %d', max_length)

    t1 = time.time()

    if freeze_embed:
        model.prepool_net.trainable = False

    model.build(max_length)
    # Seconds elapsed since t1, i.e. spent freezing/building the model.
    logging.info(time.time() - t1)

    cb = KCF.create_callbacks(model, output_path, **cb_args)
    opt = KOF.create_optimizer(**opt_args)
    model.compile(optimizer=opt)

    # sg_val only exists when a validation list was supplied; reading
    # sg_val.steps_per_epoch unconditionally crashed validation-free runs.
    val_steps = sg_val.steps_per_epoch if sg_val is not None else None

    model.fit_generator(gen_train,
                        validation_data=gen_val,
                        steps_per_epoch=sg.steps_per_epoch,
                        validation_steps=val_steps,
                        initial_epoch=sg.cur_epoch,
                        epochs=epochs,
                        callbacks=cb,
                        max_queue_size=10)

    logging.info('Train elapsed time: %.2f', time.time() - t1)

    model.save(output_path + '/model')
def train_embed(data_path, train_list, val_list,
                enc_net_path, pt_net_path, loss,
                init_path, epochs, preproc_file, output_path,
                freeze_enc, freeze_enc_layers, freeze_pt_layers, **kwargs):
    """Train a SeqEmbed sequence-embedding model and save it under output_path.

    The model is obtained in one of three ways:
      * ``init_path`` is None and ``KML.load_checkpoint`` returns a model:
        ``kwargs['init_epoch']`` is set to the returned epoch + 1 before
        the generator arguments are filtered.
      * ``init_path`` is None and no model is returned: a new ``SeqEmbed``
        is built from the encoder and pt network architecture files.
      * ``init_path`` is given: the model is loaded from it.

    Args:
        data_path: Passed to the sequence generator ``G`` as data location.
        train_list: Passed to ``G`` to build the training generator.
        val_list: Passed to ``G`` to build the validation generator, or
            None to train without validation.
        enc_net_path: Architecture file loaded with ``load_model_arch``
            (encoder net).
        pt_net_path: Architecture file loaded with ``load_model_arch``
            (pt net).
        loss: Forwarded to ``SeqEmbed`` as ``loss`` when a new model is built.
        init_path: Model file to initialise from, or None.
        epochs: Final epoch index passed to ``fit_generator`` and to
            ``KML.load_checkpoint``.
        preproc_file: File loaded with ``TransformList.load``, or None for
            no preprocessing transform.
        output_path: Directory used for checkpoints/callbacks; the final
            model is saved to ``output_path + '/model'``.
        freeze_enc: If truthy, ``model.freeze_enc_net()`` is called.
        freeze_enc_layers: If not None, passed to
            ``model.freeze_enc_layers``.
        freeze_pt_layers: If not None, passed to ``model.freeze_pt_layers``.
        **kwargs: Extra options, filtered by ``SeqEmbed.filter_args``,
            ``G.filter_args``, ``KOF.filter_args`` and ``KCF.filter_args``.
    """
    # Kept bound for the duration of training.
    # NOTE(review): presumably the returned handle holds the GPU
    # reservation -- confirm before removing.
    g = reserve_gpu()
    set_float_cpu(float_keras())

    if init_path is None:
        model, cur_epoch = KML.load_checkpoint(output_path, epochs)
        if model is None:
            emb_args = SeqEmbed.filter_args(**kwargs)
            enc_net = load_model_arch(enc_net_path)
            pt_net = load_model_arch(pt_net_path)
            model = SeqEmbed(enc_net, pt_net, loss=loss, **emb_args)
        else:
            # Resume: must be set before G.filter_args(**kwargs) below so
            # the generators can pick it up.
            kwargs['init_epoch'] = cur_epoch + 1
    else:
        logging.info('loading init model: %s', init_path)
        model = KML.load(init_path)

    sg_args = G.filter_args(**kwargs)
    opt_args = KOF.filter_args(**kwargs)
    cb_args = KCF.filter_args(**kwargs)
    logging.debug(sg_args)
    logging.debug(opt_args)
    logging.debug(cb_args)

    preproc = None
    if preproc_file is not None:
        preproc = TransformList.load(preproc_file)

    sg = G(data_path, train_list,
           shuffle_seqs=True, reset_rng=False,
           transform=preproc, **sg_args)
    max_length = sg.max_seq_length

    # Validation is optional: both stay None when no validation list is given.
    sg_val = None
    gen_val = None
    if val_list is not None:
        sg_val = G(data_path, val_list,
                   transform=preproc,
                   shuffle_seqs=False, reset_rng=True,
                   **sg_args)
        max_length = max(max_length, sg_val.max_seq_length)
        gen_val = data_generator(sg_val, max_length)

    sys.stdout.flush()

    # Built after the validation generator so both use the same max_length.
    gen_train = data_generator(sg, max_length)

    logging.info('max length: %d', max_length)

    t1 = time.time()

    if freeze_enc:
        model.freeze_enc_net()

    if freeze_enc_layers is not None:
        model.freeze_enc_layers(freeze_enc_layers)

    if freeze_pt_layers is not None:
        model.freeze_pt_layers(freeze_pt_layers)

    model.build(max_length)

    cb = KCF.create_callbacks(model, output_path, **cb_args)
    opt = KOF.create_optimizer(**opt_args)
    model.compile(metrics=['accuracy'], optimizer=opt)

    # sg_val only exists when a validation list was supplied; reading
    # sg_val.steps_per_epoch unconditionally crashed validation-free runs.
    val_steps = sg_val.steps_per_epoch if sg_val is not None else None

    model.fit_generator(gen_train,
                        validation_data=gen_val,
                        steps_per_epoch=sg.steps_per_epoch,
                        validation_steps=val_steps,
                        initial_epoch=sg.cur_epoch,
                        epochs=epochs,
                        callbacks=cb,
                        max_queue_size=10)

    logging.info('Train elapsed time: %.2f', time.time() - t1)

    model.save(output_path + '/model')