Beispiel #1
0
def main(data, vocab, training, model, output):
    """Train an Img2SeqModel on the configured train/validation data.

    Args:
        data, vocab, training, model: config sources, handed together (in
            this order) as a single list to ``Config``.
        output: directory given to ``config.save`` and to the
            ``Img2SeqModel`` constructor.
    """
    config = Config([data, vocab, training, model])
    config.save(output)
    vocabulary = Vocab(config)

    def make_dataset(path_formulas, dir_images, bucket, path_matching):
        # Only these four settings differ between the two splits; the
        # remaining arguments are the same for both.
        return DataGenerator(
            path_formulas=path_formulas,
            dir_images=dir_images,
            img_prepro=greyscale,
            max_iter=config.max_iter,
            bucket=bucket,
            path_matching=path_matching,
            max_len=config.max_length_formula,
            form_prepro=vocabulary.form_prepro,
        )

    train_set = make_dataset(
        config.path_formulas_train,
        config.dir_images_train,
        config.bucket_train,
        config.path_matching_train,
    )
    val_set = make_dataset(
        config.path_formulas_val,
        config.dir_images_val,
        config.bucket_val,
        config.path_matching_val,
    )

    # Batches per epoch, rounded up so a final partial batch is counted.
    batch_size = config.batch_size
    n_batches_epoch = (len(train_set) + batch_size - 1) // batch_size

    # The config values are multiplied by the batches-per-epoch count
    # before being handed to the schedule.
    start_decay = config.start_decay * n_batches_epoch
    end_decay = config.end_decay * n_batches_epoch
    end_warm = config.end_warm * n_batches_epoch
    lr_schedule = LRSchedule(
        lr_init=config.lr_init,
        start_decay=start_decay,
        end_decay=end_decay,
        end_warm=end_warm,
        lr_warm=config.lr_warm,
        lr_min=config.lr_min,
    )

    # Build the training graph, then run training.
    img2seq = Img2SeqModel(config, output, vocabulary)
    img2seq.build_train(config)
    img2seq.train(config, train_set, val_set, lr_schedule)
Beispiel #2
0
def main(data, vocab, training, model, output):
    """Export the training batches to NumPy files and return.

    Loads the configuration, builds the training ``DataGenerator``, walks it
    in minibatches, pads the formulas of each batch and saves the collected
    images and formulas with ``np.save`` under the names ``np_img`` and
    ``np_formula``. No model is built or trained.

    Args:
        data, vocab, training, model: config sources, passed together (in
            this order) as a single list to ``Config``.
        output: directory given to ``config.save``.

    Returns:
        None.
    """
    # Load configs
    config = Config([data, vocab, training, model])
    config.save(output)
    vocab = Vocab(config)

    # Load the training dataset
    train_set = DataGenerator(path_formulas=config.path_formulas_train,
            dir_images=config.dir_images_train, img_prepro=greyscale,
            max_iter=config.max_iter, bucket=config.bucket_train,
            path_matching=config.path_matching_train,
            max_len=config.max_length_formula,
            form_prepro=vocab.form_prepro)

    all_img = []
    all_formula = []
    # The batch size is read from the config; a bare `batch_size` name is
    # not defined anywhere in this function.
    for _img, _formula in minibatches(train_set, config.batch_size):
        all_img.append(_img)
        if _formula is not None:
            # Only the padded formulas are kept; the second returned value
            # is not needed for the export.
            _formula, _ = pad_batch_formulas(_formula, vocab.id_pad,
                                             vocab.id_end)
        # Appended even when None so both lists stay aligned per batch.
        all_formula.append(_formula)

    # NOTE(review): np.array needs the collected batches to share one shape
    # to form a regular array -- confirm that `minibatches` yields uniformly
    # shaped batches for this dataset.
    np.save('np_formula', np.array(all_formula))
    np.save('np_img', np.array(all_img))

    print("DONE EXPORTING NUMPY FILES")
    return None
Beispiel #3
0
def main(data, vocab, training, model, output):
    """Build datasets and the LR schedule, then train an Img2SeqModel.

    Args:
        data, vocab, training, model: config sources, passed together (in
            this order) as a single list to ``Config``.
        output: directory given to ``config.save`` and to the
            ``Img2SeqModel`` constructor.

    If ``config.transfer_model`` is truthy and names an existing directory,
    ``model.restore_session`` is called with it before training starts.
    """
    # Load configs
    dir_output = output
    config = Config([data, vocab, training, model])
    config.save(dir_output)
    vocab = Vocab(config)

    # Load datasets
    train_set = DataGenerator(
        index_file=config.index_train,
        path_formulas=config.path_formulas_train,
        dir_images=config.dir_images_train,
        max_iter=config.max_iter,
        path_matching=config.path_matching_train,
        max_len=config.max_length_formula,
        form_prepro=vocab.form_prepro)
    val_set = DataGenerator(
        index_file=config.index_val,
        path_formulas=config.path_formulas_val,
        dir_images=config.dir_images_val,
        max_iter=config.max_iter,
        path_matching=config.path_matching_val,
        max_len=config.max_length_formula,
        form_prepro=vocab.form_prepro)

    # Define learning rate schedule.
    # Batches per epoch, rounded up so a final partial batch is counted.
    n_batches_epoch = ((len(train_set) + config.batch_size - 1) //
                       config.batch_size)

    # Single-argument print(...) calls produce the same output under
    # Python 2 and Python 3, unlike the print statement form.
    print(len(train_set))
    print(config.batch_size)
    print(n_batches_epoch)

    # The config values are multiplied by the batches-per-epoch count
    # before being handed to the schedule.
    lr_schedule = LRSchedule(lr_init=config.lr_init,
                             start_decay=config.start_decay * n_batches_epoch,
                             end_decay=config.end_decay * n_batches_epoch,
                             end_warm=config.end_warm * n_batches_epoch,
                             lr_warm=config.lr_warm,
                             lr_min=config.lr_min)

    transfer_model = config.transfer_model

    # Build model and train
    model = Img2SeqModel(config, dir_output, vocab)
    model.build_train(config)
    # Restore a previous session only when a usable directory is configured.
    if transfer_model and os.path.isdir(transfer_model):
        model.restore_session(transfer_model)

    model.train(config, train_set, val_set, lr_schedule)
Beispiel #4
0
# _attn_cell_config = {
#     'cell_type': 'lstm',
#     'num_units': 12,
#     'dim_e'    : 14,
#     'dim_o'    : 16,
#     'dim_embeddings': 32,
# }

# Output directory for this run; the merged configuration is saved there.
dir_output = "results/small/"

# The configuration is assembled from four JSON files.
config = Config(["configs/data_small.json",
                 "configs/vocab_small.json",
                 "configs/training_small.json",
                 "configs/model.json"])
config.save(dir_output)
vocab = Vocab(config)

# Training data generator, parameterised entirely from the loaded config.
train_set = DataGenerator(path_formulas=config.path_formulas_train,
                          dir_images=config.dir_images_train,
                          img_prepro=greyscale,
                          max_iter=config.max_iter,
                          bucket=config.bucket_train,
                          path_matching=config.path_matching_train,
                          max_len=config.max_length_formula,
                          form_prepro=vocab.form_prepro)
val_set = DataGenerator(
    path_formulas=config.path_formulas_val,
    dir_images=config.dir_images_val,