Exemplo n.º 1
0
def train_embed(data_path, train_list, val_list, px_net_path, pt_net_path,
                qy_net_path, qz_net_path, init_path, epochs, preproc_file,
                output_path, freeze_embed, **kwargs):
    """Train (or resume training of) a VAE embedding model and save it.

    The model is obtained in one of three ways:
      * ``init_path`` given: loaded with ``KML.load(init_path)``.
      * otherwise a checkpoint is looked up with
        ``KML.load_checkpoint(output_path, epochs)``; if one is found the
        training generator is moved to the checkpoint epoch.
      * otherwise a new ``VAE`` is built from the four network
        architecture files.

    Args:
        data_path: Passed to the sequence generator ``G`` as its data source.
        train_list: Passed to ``G`` to build the training generator.
        val_list: Passed to ``G`` to build the validation generator, or
            ``None`` to train without validation.
        px_net_path, pt_net_path, qy_net_path, qz_net_path: Files read with
            ``load_model_arch`` when a new model has to be created.
        init_path: Optional path of a model used for initialization.
        epochs: Final epoch number given to ``fit_generator`` and to the
            checkpoint lookup.
        preproc_file: Optional file loaded with ``TransformList.load`` and
            used as the generators' ``transform``.
        output_path: Directory used for the checkpoint lookup and callbacks;
            the final model is written to ``output_path + '/model'``.
        freeze_embed: If truthy, ``model.prepool_net.trainable`` is set to
            ``False`` before the model is built.
        **kwargs: Must contain ``px_weight``, ``pt_weight``,
            ``kl_qy_weight`` and ``kl_qz_weight``. It is also filtered by
            ``G``, ``VAE``, ``KOF`` and ``KCF`` for their own options.

    Raises:
        KeyError: If any of the four loss-weight keys is missing in kwargs.
    """
    # NOTE(review): the handle is not used afterwards; it is kept bound in
    # case the reservation is tied to the object's lifetime -- confirm.
    g = reserve_gpu()
    set_float_cpu(float_keras())

    preproc = None
    if preproc_file is not None:
        preproc = TransformList.load(preproc_file)

    sg_args = G.filter_args(**kwargs)
    sg = G(data_path,
           train_list,
           shuffle_seqs=True,
           reset_rng=False,
           transform=preproc,
           **sg_args)
    max_length = sg.max_seq_length

    # Both stay None when no validation list is supplied, so that the
    # fit_generator call below never touches an unbound name.
    sg_val = None
    gen_val = None
    if val_list is not None:
        sg_val = G(data_path,
                   val_list,
                   transform=preproc,
                   shuffle_seqs=False,
                   reset_rng=True,
                   **sg_args)
        # Train and validation batches share the same padded length.
        max_length = max(max_length, sg_val.max_seq_length)
        gen_val = data_generator(sg_val, max_length)

    gen_train = data_generator(sg, max_length)

    if init_path is None:
        model, init_epoch = KML.load_checkpoint(output_path, epochs)
        if model is None:
            embed_args = VAE.filter_args(**kwargs)
            logging.debug(embed_args)
            px_net = load_model_arch(px_net_path)
            qy_net = load_model_arch(qy_net_path)
            qz_net = load_model_arch(qz_net_path)
            pt_net = load_model_arch(pt_net_path)

            model = VAE(px_net, qy_net, qz_net, pt_net, **embed_args)
        else:
            # Resume: move the training generator to the checkpoint epoch.
            sg.cur_epoch = init_epoch
            sg.reset()
    else:
        logging.info('loading init model: %s', init_path)
        model = KML.load(init_path)

    # The loss weights are always taken from the current arguments, even
    # for a model restored from a checkpoint or from init_path.
    model.px_weight = kwargs['px_weight']
    model.pt_weight = kwargs['pt_weight']
    model.kl_qy_weight = kwargs['kl_qy_weight']
    model.kl_qz_weight = kwargs['kl_qz_weight']

    opt_args = KOF.filter_args(**kwargs)
    cb_args = KCF.filter_args(**kwargs)
    logging.debug(sg_args)
    logging.debug(opt_args)
    logging.debug(cb_args)

    logging.info('max length: %d', max_length)

    t1 = time.time()

    if freeze_embed:
        model.prepool_net.trainable = False

    model.build(max_length)
    # Logs the time spent building the model.
    logging.info(time.time() - t1)

    cb = KCF.create_callbacks(model, output_path, **cb_args)
    opt = KOF.create_optimizer(**opt_args)
    model.compile(optimizer=opt)

    # Without validation data there is no validation generator to take the
    # step count from; None is the value used when validation is disabled.
    val_steps = sg_val.steps_per_epoch if sg_val is not None else None

    model.fit_generator(gen_train,
                        validation_data=gen_val,
                        steps_per_epoch=sg.steps_per_epoch,
                        validation_steps=val_steps,
                        initial_epoch=sg.cur_epoch,
                        epochs=epochs,
                        callbacks=cb,
                        max_queue_size=10)

    logging.info('Train elapsed time: %.2f', time.time() - t1)

    model.save(output_path + '/model')
Exemplo n.º 2
0
def train_embed(data_path, train_list, val_list,
                train_list_adapt, val_list_adapt,
                prepool_net_path, postpool_net_path,
                init_path,
                epochs,
                preproc_file, output_path,
                freeze_prepool, freeze_postpool_layers,
                **kwargs):
    """Train (or resume training of) a SeqEmbed model and save it.

    The model is obtained in one of three ways:
      * ``init_path`` given: loaded with ``KML.load(init_path)``.
      * otherwise a checkpoint is looked up with
        ``KML.load_checkpoint(output_path, epochs)``; if one is found its
        epoch is stored in ``kwargs['init_epoch']`` before the generator
        arguments are filtered.
      * otherwise a new ``SeqEmbed`` is built from the pre-pooling and
        post-pooling architecture files with a categorical cross-entropy
        loss.

    Args:
        data_path: Passed to the sequence generator ``G`` as its data source.
        train_list, train_list_adapt: Passed to ``G`` to build the training
            generator.
        val_list, val_list_adapt: Passed to ``G`` to build the validation
            generator; validation is skipped when ``val_list`` is ``None``.
        prepool_net_path, postpool_net_path: Files read with
            ``load_model_arch`` when a new model has to be created.
        init_path: Optional path of a model used for initialization.
        epochs: Final epoch number given to ``fit_generator`` and to the
            checkpoint lookup.
        preproc_file: Optional file loaded with ``TransformList.load`` and
            used as the generators' ``transform``.
        output_path: Directory used for the checkpoint lookup and callbacks;
            the final model is written to ``output_path + '/model'``.
        freeze_prepool: If truthy, ``model.freeze_prepool_net()`` is called.
        freeze_postpool_layers: If not ``None``, forwarded to
            ``model.freeze_postpool_net_layers``.
        **kwargs: Extra options, filtered by ``SeqEmbed``, ``G``, ``KOF``
            and ``KCF`` for their own arguments.
    """
    set_float_cpu(float_keras())

    if init_path is None:
        model, init_epoch = KML.load_checkpoint(output_path, epochs)
        if model is None:
            emb_args = SeqEmbed.filter_args(**kwargs)
            prepool_net = load_model_arch(prepool_net_path)
            postpool_net = load_model_arch(postpool_net_path)

            model = SeqEmbed(prepool_net, postpool_net,
                             loss='categorical_crossentropy',
                             **emb_args)
        else:
            # Resume: expose the checkpoint epoch to the argument filters
            # that run below.
            kwargs['init_epoch'] = init_epoch
    else:
        logging.info('loading init model: %s', init_path)
        model = KML.load(init_path)

    sg_args = G.filter_args(**kwargs)
    opt_args = KOF.filter_args(**kwargs)
    cb_args = KCF.filter_args(**kwargs)
    logging.debug(sg_args)
    logging.debug(opt_args)
    logging.debug(cb_args)

    preproc = None
    if preproc_file is not None:
        preproc = TransformList.load(preproc_file)

    sg = G(data_path, train_list, train_list_adapt,
           shuffle_seqs=True, reset_rng=False,
           transform=preproc, **sg_args)
    max_length = sg.max_seq_length

    # Both stay None when no validation list is supplied, so that the
    # fit_generator call below never touches an unbound name.
    sg_val = None
    gen_val = None
    if val_list is not None:
        sg_val = G(data_path, val_list, val_list_adapt,
                   transform=preproc,
                   shuffle_seqs=False, reset_rng=True,
                   **sg_args)
        # Train and validation batches share the same padded length.
        max_length = max(max_length, sg_val.max_seq_length)
        gen_val = data_generator(sg_val, max_length)

    gen_train = data_generator(sg, max_length)

    logging.info('max length: %d', max_length)

    t1 = time.time()
    if freeze_prepool:
        model.freeze_prepool_net()

    if freeze_postpool_layers is not None:
        model.freeze_postpool_net_layers(freeze_postpool_layers)

    model.build(max_length)

    cb = KCF.create_callbacks(model, output_path, **cb_args)
    opt = KOF.create_optimizer(**opt_args)
    model.compile(optimizer=opt)

    # Without validation data there is no validation generator to take the
    # step count from; None is the value used when validation is disabled.
    val_steps = sg_val.steps_per_epoch if sg_val is not None else None

    model.fit_generator(gen_train, validation_data=gen_val,
                        steps_per_epoch=sg.steps_per_epoch,
                        validation_steps=val_steps,
                        initial_epoch=sg.cur_epoch,
                        epochs=epochs, callbacks=cb, max_queue_size=10)

    logging.info('Train elapsed time: %.2f', time.time() - t1)

    model.save(output_path + '/model')