def extract_embed(seq_file, model_file, preproc_file, output_path,
                  layer_names, max_seq_length, **kwargs):
    """Extract one embedding vector per input sequence and write them to disk.

    Loads a trained SeqEmbed model, builds its embedding sub-network for the
    requested layers, reads sequences one at a time from ``seq_file`` and
    writes the stacked embedding matrix with DWF.

    Args:
      seq_file: path/spec of the sequence data, consumed by SDRF.create.
      model_file: saved SeqEmbed model to load.
      preproc_file: optional TransformList file applied while reading.
      output_path: destination passed to DWF.create.
      layer_names: layers whose outputs form the embedding
        (forwarded to model.build_embed).
      max_seq_length: maximum sequence length used to build the model graph.
      **kwargs: extra options filtered by SDRF.filter_args.
    """
    set_float_cpu('float32')

    sr_args = SDRF.filter_args(**kwargs)
    if preproc_file is not None:
        preproc = TransformList.load(preproc_file)
    else:
        preproc = None

    sr = SDRF.create(seq_file, transform=preproc, **sr_args)

    t1 = time.time()
    model = SeqEmbed.load(model_file)
    model.build(max_seq_length)
    model.build_embed(layer_names)
    y_dim = model.embed_dim

    # Count sequences up front so the output matrix can be preallocated.
    _, seq_lengths = sr.read_num_rows()
    sr.reset()
    num_seqs = len(seq_lengths)

    y = np.zeros((num_seqs, y_dim), dtype=float_keras())
    keys = []
    for i in range(num_seqs):  # FIX: range instead of Py2-only xrange
        ti1 = time.time()
        key, data = sr.read(1)  # reads a single sequence: key/data are 1-element lists
        ti2 = time.time()
        logging.info('Extracting embeddings %d/%d for %s, num_frames: %d'
                     % (i, num_seqs, key[0], data[0].shape[0]))
        keys.append(key[0])
        y[i] = model.predict_embed(data[0])
        ti4 = time.time()
        # FIX: log key[0] (the id string) instead of key (a list), matching
        # the message above.
        logging.info(
            'Elapsed time embeddings %d/%d for %s, total: %.2f read: %.2f, vae: %.2f'
            % (i, num_seqs, key[0], ti4 - ti1, ti2 - ti1, ti4 - ti2))

    logging.info('Extract elapsed time: %.2f' % (time.time() - t1))

    hw = DWF.create(output_path)
    hw.write(keys, y)
def train_embed(data_path, train_list, val_list, train_list_adapt,
                val_list_adapt, prepool_net_path, postpool_net_path,
                init_path, epochs, preproc_file, output_path,
                freeze_prepool, freeze_postpool_layers, **kwargs):
    """Train a SeqEmbed model (prepool + postpool networks) with fit_generator.

    Resumes from the latest checkpoint in ``output_path`` when possible,
    otherwise builds a fresh model from the two architecture files, or loads
    an explicit initial model from ``init_path``.

    Args:
      data_path: root of the training data, consumed by the generator G.
      train_list / val_list: training / optional validation lists.
      train_list_adapt / val_list_adapt: adaptation lists passed to G.
      prepool_net_path / postpool_net_path: architecture files for the two
        sub-networks.
      init_path: optional pre-trained model to start from (skips checkpoint
        lookup).
      epochs: total number of training epochs.
      preproc_file: optional TransformList applied by the generators.
      output_path: directory for checkpoints/callbacks and the final model.
      freeze_prepool: if True, freeze the whole prepool network.
      freeze_postpool_layers: optional list of postpool layer names to freeze.
      **kwargs: extra options filtered per-component (SeqEmbed, G, KOF, KCF).
    """
    set_float_cpu(float_keras())

    if init_path is None:
        # Try to resume from a checkpoint; fall back to a fresh model built
        # from the architecture files.
        model, init_epoch = KML.load_checkpoint(output_path, epochs)
        if model is None:
            emb_args = SeqEmbed.filter_args(**kwargs)
            prepool_net = load_model_arch(prepool_net_path)
            postpool_net = load_model_arch(postpool_net_path)
            model = SeqEmbed(prepool_net, postpool_net,
                             loss='categorical_crossentropy', **emb_args)
        else:
            # Let the generator resume from the checkpointed epoch.
            kwargs['init_epoch'] = init_epoch
    else:
        logging.info('loading init model: %s' % init_path)
        model = KML.load(init_path)

    sg_args = G.filter_args(**kwargs)
    opt_args = KOF.filter_args(**kwargs)
    cb_args = KCF.filter_args(**kwargs)
    logging.debug(sg_args)
    logging.debug(opt_args)
    logging.debug(cb_args)

    if preproc_file is not None:
        preproc = TransformList.load(preproc_file)
    else:
        preproc = None

    sg = G(data_path, train_list, train_list_adapt,
           shuffle_seqs=True, reset_rng=False,
           transform=preproc, **sg_args)
    max_length = sg.max_seq_length

    gen_val = None
    # FIX: the original passed validation_steps=sg_val.steps_per_epoch
    # unconditionally, which raises UnboundLocalError when val_list is None.
    val_steps = None
    if val_list is not None:
        sg_val = G(data_path, val_list, val_list_adapt,
                   transform=preproc, shuffle_seqs=False, reset_rng=True,
                   **sg_args)
        max_length = max(max_length, sg_val.max_seq_length)
        gen_val = data_generator(sg_val, max_length)
        val_steps = sg_val.steps_per_epoch

    gen_train = data_generator(sg, max_length)

    logging.info('max length: %d' % max_length)

    t1 = time.time()
    if freeze_prepool:
        model.freeze_prepool_net()
    if freeze_postpool_layers is not None:
        model.freeze_postpool_net_layers(freeze_postpool_layers)

    model.build(max_length)

    cb = KCF.create_callbacks(model, output_path, **cb_args)
    opt = KOF.create_optimizer(**opt_args)
    model.compile(optimizer=opt)

    h = model.fit_generator(gen_train,
                            validation_data=gen_val,
                            steps_per_epoch=sg.steps_per_epoch,
                            validation_steps=val_steps,
                            initial_epoch=sg.cur_epoch,
                            epochs=epochs,
                            callbacks=cb,
                            max_queue_size=10)
    logging.info('Train elapsed time: %.2f' % (time.time() - t1))

    model.save(output_path + '/model')
fromfile_prefix_chars='@', description='Train sequence embeddings') parser.add_argument('--data-path', dest='data_path', required=True) parser.add_argument('--train-list', dest='train_list', required=True) parser.add_argument('--val-list', dest='val_list', default=None) parser.add_argument('--train-list-adapt', dest='train_list_adapt', required=True) parser.add_argument('--val-list-adapt', dest='val_list_adapt', default=None) parser.add_argument('--prepool-net', dest='prepool_net_path', required=True) parser.add_argument('--postpool-net', dest='postpool_net_path', required=True) parser.add_argument('--init-path', dest='init_path', default=None) parser.add_argument('--preproc-file', dest='preproc_file', default=None) parser.add_argument('--output-path', dest='output_path', required=True) SeqEmbed.add_argparse_args(parser) G.add_argparse_args(parser) KOF.add_argparse_args(parser) KCF.add_argparse_args(parser) parser.add_argument('--freeze-prepool', dest='freeze_prepool', default=False, action='store_true') parser.add_argument('--freeze-postpool-layers', dest='freeze_postpool_layers', nargs='+', default=None) parser.add_argument('--epochs', dest='epochs', default=1000, type=int) parser.add_argument('-v', '--verbose', dest='verbose', default=1, choices=[0, 1, 2, 3], type=int) args=parser.parse_args() config_logger(args.verbose) del args.verbose logging.debug(args)