示例#1
0
 def setUp(self):
     """Create a one-epoch Trainer wired to the IAM handwriting fixtures.

     Two independent generator instances are built over the same fixture
     directory (one used for training, one for evaluation); the model is
     sized from the evaluation generator's alphabet.
     NOTE(review): the fixture path is relative to the current working
     directory, so the test only works when run from the expected folder.
     """
     fixture_dir = '../../fixtures/iam_handwriting/'
     self.img_height, self.epochs = 48, 1

     self.train_batch_generator = BatchGeneratorIAMHandwriting(fixture_dir, self.img_height)
     self.test_batch_generator = BatchGeneratorIAMHandwriting(fixture_dir, self.img_height)
     self.alphabet = self.test_batch_generator.alphabet

     self.model = ModelOcropy(self.alphabet, self.img_height)
     self.trainer = Trainer(
         self.model,
         self.train_batch_generator,
         self.test_batch_generator,
         epochs=self.epochs,
     )
示例#2
0
    def start_train(self):
        """Build data generators, an Ocropy model and callbacks, then train.

        Calls ``self.progress()`` first and ``self.end_train()`` after
        ``trainer.train()`` returns. The learning rate and epoch count are
        read from ``self.lrate`` and ``self.nb_epoch``. Always registers a
        best-``val_loss`` weights checkpoint and a ``GUICallback`` bound to
        ``self``.
        """
        self.progress()

        # data generators
        # NOTE(review): hard-coded, machine-specific dataset location.
        data_path = '/home/arsleust/projects/simple-ocr/data/bodmer'
        img_height = 48
        train_data_generator = BatchGeneratorManuscript(data_path,
                                                        img_height=img_height)
        # The evaluation generator reuses the training alphabet so both
        # generators are built with the same character set.
        test_data_generator = BatchGeneratorManuscript(
            data_path,
            img_height=img_height,
            sample_size=10,
            alphabet=train_data_generator.alphabet)

        # model
        self.model = ModelOcropy(train_data_generator.alphabet, img_height)
        # NOTE(review): if summary() is Keras' Model.summary it prints itself
        # and returns None, so this would also print "None" — confirm.
        print(self.model.summary())

        # callbacks
        str_date_time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs("checkpoints", exist_ok=True)
        # NOTE(review): the timestamp contains ':' and ' ', which are not
        # valid in Windows file names.
        checkpoints_path = os.path.join("checkpoints",
                                        str_date_time + '.hdf5')
        callbacks = [
            keras.callbacks.ModelCheckpoint(
                checkpoints_path,
                monitor='val_loss',
                verbose=1,
                save_best_only=True,
                save_weights_only=True),
            GUICallback(test_data_generator, self),
        ]

        # trainer
        trainer = Trainer(self.model,
                          train_data_generator,
                          test_data_generator,
                          lr=self.lrate,
                          epochs=self.nb_epoch,
                          steps_per_epochs=20,
                          callbacks=callbacks)

        trainer.train()
        print("Training done")

        self.end_train()
示例#3
0
def main():
    """Run K-fold cross-validation with an independently initialised model per fold.

    The dataset and the fold split are built once from the first experiment
    config. Every fold after the first gets its own ``init_experiment`` call
    and a freshly initialised model, so no fold starts from weights that were
    already trained on its validation samples.
    """
    args = parse_args()
    model_script = load_module(args.model_path)
    cfg = init_experiment(args)
    model, model_cfg = model_script.init_model(cfg)

    # prepare dataset
    dataset = YKDataset(window_size=cfg.window_size, dic_path=cfg.DIC_PATH)
    indices = list(range(len(dataset)))
    # NOTE(review): no random_state is given, so fold assignment differs
    # between runs.
    kf = KFold(n_splits=cfg.n_folds, shuffle=True)
    for fold, (train_indices, val_indices) in enumerate(kf.split(indices)):
        if fold > 0:
            # Re-initialise per fold. Reusing the previous fold's model would
            # leak data: it has already been trained on samples that are in
            # this fold's validation split.
            cfg = init_experiment(args)
            model, model_cfg = model_script.init_model(cfg)

        logger.info(f'Fold {fold}: train indices {train_indices}')
        logger.info(f'Fold {fold}: val indices {val_indices}')
        train_sampler = SubsetRandomSampler(train_indices)
        val_sampler = SubsetRandomSampler(val_indices)

        trainer = Trainer(model, cfg, model_cfg, dataset, train_sampler,
                          val_sampler, cfg.optimizer)

        logger.info(f'Total Epochs: {cfg.n_epochs}')
        for epoch in range(cfg.n_epochs):
            trainer.training(epoch)
            trainer.validation(epoch)
示例#4
0
def main():
    """Train with a single hold-out validation split.

    The first ``VAL_SPLIT`` fraction of the sample indices is held out for
    validation and the rest is used for training. When ``SHUFFLE_DATASET``
    is set, the indices are shuffled first after seeding NumPy's global RNG
    with ``RANDOM_SEED``.
    """
    args = parse_args()
    model_script = load_module(args.model_path)
    cfg = init_experiment(args)
    model, model_cfg = model_script.init_model(cfg)

    # Build the dataset and decide how many samples are held out.
    dataset = YKDataset(window_size=cfg.window_size, dic_path=cfg.DIC_PATH)
    n_samples = len(dataset)
    order = list(range(n_samples))
    n_val = int(np.floor(VAL_SPLIT * n_samples))
    if SHUFFLE_DATASET:
        # Seeding the global RNG makes the split reproducible (and also
        # affects any later use of np.random).
        np.random.seed(RANDOM_SEED)
        np.random.shuffle(order)
    val_indices = order[:n_val]
    train_indices = order[n_val:]

    trainer = Trainer(
        model,
        cfg,
        model_cfg,
        dataset,
        SubsetRandomSampler(train_indices),
        SubsetRandomSampler(val_indices),
        cfg.optimizer,
    )

    logger.info(f'Total Epochs: {cfg.n_epochs}')
    for epoch in range(cfg.n_epochs):
        trainer.training(epoch)
        trainer.validation(epoch)
示例#5
0
文件: GRCNN.py 项目: happog/FudanOCR
 def __init__(self, modelObject, opt, train_loader, val_loader):
     """Initialise by delegating all four arguments to ``Trainer.__init__``.

     This constructor adds no state of its own. The base class is called
     explicitly rather than through ``super()``.
     """
     Trainer.__init__(self, modelObject, opt, train_loader, val_loader)
示例#6
0
 def __init__(self, modelObject, opt, train_loader, val_loader):
     """Initialise the base ``Trainer``, then call TextSnake's ``global_data._init()``.

     ``global_data`` is imported locally, after the base initialiser has run,
     rather than at module import time.
     NOTE(review): ``_init()`` presumably sets up module-level shared state
     used elsewhere in the TextSnake code — confirm in ``util/global_data``.
     """
     Trainer.__init__(self, modelObject, opt, train_loader, val_loader)
     # Local import: this module is only loaded when this trainer is built.
     from model.detection_model.TextSnake_pytorch.util import global_data
     global_data._init()
示例#7
0
    # Build the training/validation loaders and run the trainer.
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.train_num_workers,
        collate_fn=collate_func,
        drop_last=False,
        pin_memory=True)
    # Record the dataset length on the loader under ``total_item_len``.
    train_loader.total_item_len = len(train_set)
    val_set = Featset(sample=1000)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.val_batch,
                                             shuffle=False,
                                             num_workers=args.val_num_workers,
                                             collate_fn=collate_func,
                                             drop_last=False,
                                             pin_memory=True)
    val_loader.total_item_len = len(val_set)
    cudnn.benchmark = True

    # Zero-pad month/day/hour. With plain %d, e.g. Jan 11 01h and Nov 1 01h
    # both rendered as "1111", so different runs could share a directory.
    now = datetime.datetime.now()
    filename = '%s_%02d%02d%02d_' % (args.discribe, now.month, now.day, now.hour)
    save_dir = os.path.join(Config.save_dir, filename)
    os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): save_dir is created here but not passed to Trainer —
    # confirm the trainer derives the same directory itself.

    trainer = Trainer(Config, model, train_loader, val_loader)
    trainer.run()
示例#8
0
def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True``
    for every non-empty string (including ``"False"``). With that, a flag
    defaulting to True could never be switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_parser():
    """Return the argument parser for the OCR training command line."""
    # The first positional argument of ArgumentParser is `prog`, so the text
    # must be passed as `description=` to appear as the description.
    parser = argparse.ArgumentParser(
        description="A Python command-line tool for training ocr models")
    parser.add_argument('generator', choices=['iam', 'bodmer'])
    parser.add_argument('data_path', type=str)
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--steps-epochs', type=int, default=None)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--plateau-reduce-lr', type=_str_to_bool, default=True)
    parser.add_argument('--image-height', type=int, default=48)
    parser.add_argument('--levenshtein', type=_str_to_bool, default=True)
    parser.add_argument('--tensorboard', type=_str_to_bool, default=True)
    return parser


def main():
    """Train an Ocropy OCR model from command-line arguments.

    The steps are:
    - Build train/test batch generators of the chosen type.
    - Create the model.
    - Assemble the optional LR-plateau, Levenshtein and TensorBoard
      callbacks, plus an always-on best-``val_loss`` weights checkpoint.
    - Run ``Trainer.train()``.
    """
    args = _build_parser().parse_args()
    img_height = args.image_height
    data_path = args.data_path

    # data generators
    if args.generator == 'iam':
        generator_cls = BatchGeneratorIAMHandwriting
    elif args.generator == 'bodmer':
        generator_cls = BatchGeneratorManuscript
    else:
        # Unreachable from the CLI (argparse enforces `choices`); kept as a guard.
        raise Exception("Data generator is not defined.")
    train_data_generator = generator_cls(data_path, img_height=img_height)
    # The test generator samples 100 items and reuses the training alphabet.
    test_data_generator = generator_cls(
        data_path,
        img_height=img_height,
        sample_size=100,
        alphabet=train_data_generator.alphabet)

    # model
    model = ModelOcropy(train_data_generator.alphabet, img_height)
    # NOTE(review): if summary() is Keras' Model.summary it prints itself and
    # returns None, so this would also print "None" — confirm.
    print(model.summary())

    # callbacks
    # NOTE(review): the timestamp contains ':' and ' ', which are not valid
    # in Windows paths.
    str_date_time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    callbacks = []
    if args.plateau_reduce_lr:
        callbacks.append(keras.callbacks.ReduceLROnPlateau(
            monitor='val_ctc_loss', factor=0.1, patience=4, verbose=1))
    if args.levenshtein:
        callbacks.append(LevenshteinCallback(test_data_generator, size=10))
    if args.tensorboard:
        callbacks.append(keras.callbacks.TensorBoard(
            log_dir=os.path.join("logs", str_date_time),
            batch_size=1,
        ))
    # Checkpointing is always enabled; exist_ok avoids a check-then-create race.
    os.makedirs("checkpoints", exist_ok=True)
    checkpoints_path = os.path.join("checkpoints", str_date_time + '.hdf5')
    callbacks.append(keras.callbacks.ModelCheckpoint(
        checkpoints_path,
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        save_weights_only=True))

    # trainer
    trainer = Trainer(model,
                      train_data_generator,
                      test_data_generator,
                      lr=args.lr,
                      epochs=args.epochs,
                      steps_per_epochs=args.steps_epochs,
                      callbacks=callbacks)

    trainer.train()
示例#9
0
def train(cfg):
    """Train SqueezeDet on ``cfg.dataset`` with per-epoch checkpointing.

    Every epoch the latest weights are saved to ``model_last.pth``. Every
    ``cfg.save_intervals`` epochs a numbered snapshot is also written. Every
    ``cfg.val_intervals`` epochs (when positive) the model is validated, and
    ``model_best.pth`` is refreshed when the selection score improves:
    - with ``cfg.no_eval``, the score is the lowest validation loss;
    - otherwise, the score is the highest mAP from ``eval_dataset`` on the
      checkpoint just written.
    """
    dataset_cls = load_dataset(cfg.dataset)
    train_dataset = dataset_cls('train', cfg)
    val_dataset = dataset_cls('val', cfg)
    cfg = Config().update_dataset_info(cfg, train_dataset)
    Config().print(cfg)
    logger = Logger(cfg)

    model = SqueezeDetWithLoss(cfg)
    if cfg.load_model != '':
        # These two checkpoint names are loaded with the "official" loader.
        official_suffixes = ('f364aa15.pth', 'a815701f.pth')
        if cfg.load_model.endswith(official_suffixes):
            model = load_official_model(model, cfg.load_model)
        else:
            model = load_model(model, cfg.load_model)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=cfg.lr,
                                momentum=cfg.momentum,
                                weight_decay=cfg.weight_decay)
    lr_scheduler = StepLR(optimizer, 60, gamma=0.5)

    trainer = Trainer(model, optimizer, lr_scheduler, cfg)

    shared_loader_opts = dict(batch_size=cfg.batch_size,
                              num_workers=cfg.num_workers,
                              pin_memory=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               shuffle=True,
                                               drop_last=True,
                                               **shared_loader_opts)
    val_loader = torch.utils.data.DataLoader(val_dataset, **shared_loader_opts)

    # Selection criterion: minimise loss without eval, maximise mAP with it.
    if cfg.no_eval:
        metrics, best, is_better = trainer.metrics, 1E9, operator.lt
    else:
        metrics, best, is_better = trainer.metrics + ['mAP'], 0, operator.gt

    for epoch in range(1, cfg.num_epochs + 1):
        logger.update(trainer.train_epoch(epoch, train_loader),
                      phase='train', epoch=epoch)

        # `latest_ckpt` ends up pointing at the most recently written file,
        # which is what eval_dataset is run on below.
        latest_ckpt = os.path.join(cfg.save_dir, 'model_last.pth')
        save_model(model, latest_ckpt, epoch)
        if epoch % cfg.save_intervals == 0:
            latest_ckpt = os.path.join(cfg.save_dir, f'model_{epoch}.pth')
            save_model(model, latest_ckpt, epoch)

        if cfg.val_intervals > 0 and epoch % cfg.val_intervals == 0:
            val_stats = trainer.val_epoch(epoch, val_loader)
            logger.update(val_stats, phase='val', epoch=epoch)

            if cfg.no_eval:
                score = val_stats['loss']
            else:
                aps = eval_dataset(val_dataset, latest_ckpt, cfg)
                logger.update(aps, phase='val', epoch=epoch)
                score = aps['mAP']

            if is_better(score, best):
                best = score
                save_model(model, os.path.join(cfg.save_dir, 'model_best.pth'), epoch)

        logger.plot(metrics)
        logger.print_bests(metrics)

    torch.cuda.empty_cache()
示例#10
0
    # Copy command-line options onto the config; training unless --test.
    cfg.MODEL.IS_TRAIN = not args.test
    cfg.TRAIN.TUNE = args.tune
    # cfg.DATASET.NAME = args.dataset
    # cfg.DATASET.ROOT = args.dataset_dir
    # cfg.DATASET.CONT_ROOT = args.cont_dataset_dir
    # cfg.DATASET.IMAGENET = args.imagenet
    cfg.TEST.WEIGHTS = args.weights
    # With raindrop tuning enabled, the same --weights path is also used as
    # the raindrop weights.
    if cfg.MODEL.RAINDROP_TUNE:
        cfg.MODEL.RAINDROP_WEIGHTS = args.weights
    # TEST.ABLATION / TEST.MODE are not set from the command line here (the
    # assignments are commented out), so they keep whatever value cfg holds.
    # cfg.TEST.ABLATION = args.ablation
    # cfg.TEST.MODE = args.test_mode
    # cfg.freeze()
    print(cfg)

    if cfg.MODEL.IS_TRAIN:
        # RaindropTrainer when raindrop tuning is enabled, otherwise Trainer.
        trainer = Trainer(
            cfg) if not cfg.MODEL.RAINDROP_TUNE else RaindropTrainer(cfg)
        trainer.run()
    else:
        tester = Tester(cfg)
        if cfg.TEST.ABLATION:
            # Exhaustive sweep: img_id 250-499, c_img_id 185-374, mode 1-8
            # (range ends are exclusive), i.e. 250 * 190 * 8 = 380,000 calls.
            for i_id in list(range(250, 500)):
                for c_i_id in list(range(185, 375)):
                    for mode in list(range(1, 9)):
                        tester.do_ablation(mode=mode,
                                           img_id=i_id,
                                           c_img_id=c_i_id)
                        log.info("I: {}, C: {}, Mode:{}".format(
                            i_id, c_i_id, mode))
        else:
            # Qualitative evaluation on one fixed image.
            # NOTE(review): this snippet is truncated here; the code that
            # uses img_path is not shown.
            img_path = "datasets/ffhq/images1024x1024/07000/07042.png"
示例#11
0
# Parse integers explicitly: without `type=int`, a value given on the command
# line arrives as a string (e.g. "16"), while the default stays an int.
parser.add_argument("--num_step", default=0, type=int, help="current step for training")
parser.add_argument("--batch_size", default=8, type=int, help="batch size for training")

parser.add_argument("--test", action="store_true")

args = parser.parse_args()

if __name__ == '__main__':
    cfg = get_cfg_defaults()

    # Copy command-line options onto the default config.
    cfg.DATASET.NAME = args.dataset
    # IS_TRAIN is the negation of --test, matching the train/test dispatch
    # below (it was previously assigned args.test, i.e. inverted).
    cfg.TRAIN.IS_TRAIN = not args.test
    cfg.TRAIN.START_STEP = args.num_step
    cfg.TRAIN.BATCH_SIZE = args.batch_size
    cfg.MODEL.CKPT = args.weights
    print(cfg)

    if not args.test:
        trainer = Trainer(cfg)
        trainer.run()
    else:
        tester = Tester(cfg)
        tester.eval()