def __init__(self, args):
    """Set up evaluation: trained prep model, OCR helper, and test data.

    Args:
        args: parsed CLI namespace; reads show_txt, show_img,
            prep_model_name, prep_path, ocr and dataset.
    """
    self.batch_size = 1
    self.show_txt = args.show_txt
    self.show_img = args.show_img
    self.prep_model_name = args.prep_model_name
    self.prep_model_path = args.prep_path
    self.ocr_name = args.ocr
    self.dataset_name = args.dataset

    # Pick the test split for the requested dataset; both share input_size.
    if self.dataset_name == 'vgg':
        self.test_set = properties.vgg_text_dataset_test
        self.input_size = properties.input_size
    elif self.dataset_name == 'pos':
        self.test_set = properties.patch_dataset_test
        self.input_size = properties.input_size

    use_gpu = torch.cuda.is_available()
    self.device = torch.device("cuda:0" if use_gpu else "cpu")
    model_file = os.path.join(self.prep_model_path, self.prep_model_name)
    self.prep_model = torch.load(model_file).to(self.device)

    self.ocr = get_ocr_helper(self.ocr_name, is_eval=True)

    if self.dataset_name == 'pos':
        # Patch dataset performs its own padding; no DataLoader is built.
        self.dataset = PatchDataset(self.test_set, pad=True)
    else:
        # Pad images to a fixed size before tensor conversion.
        pad_and_tensor = transforms.Compose([
            PadWhite(self.input_size),
            transforms.ToTensor(),
        ])
        self.dataset = ImgDataset(
            self.test_set, transform=pad_and_tensor, include_name=True)
        # NOTE(review): loader_eval only exists for the non-'pos' branch —
        # confirm the 'pos' path never reads it.
        self.loader_eval = torch.utils.data.DataLoader(
            self.dataset, batch_size=self.batch_size,
            num_workers=properties.num_workers)
# Beispiel #2
# (scrape artifact — the following lines were a separate example snippet)
    def __init__(self, args):
        """Set up joint prep (UNet) + CRNN training on the POS patch dataset.

        Args:
            args: parsed CLI namespace; reads lr_crnn, lr_prep, epoch,
                inner_limit, crnn_model, scalar, ocr, std and random_std.
        """
        self.batch_size = 1
        self.lr_crnn = args.lr_crnn
        self.lr_prep = args.lr_prep
        self.max_epochs = args.epoch
        self.inner_limit = args.inner_limit
        self.crnn_model_path = args.crnn_model
        self.sec_loss_scalar = args.scalar
        self.ocr_name = args.ocr
        self.std = args.std
        self.is_random_std = args.random_std
        # Fixed seed for reproducible runs.
        torch.manual_seed(42)

        self.train_set = properties.pos_text_dataset_train
        self.validation_set = properties.pos_text_dataset_dev
        self.input_size = properties.input_size

        self.ocr = get_ocr_helper(self.ocr_name)

        self.char_to_index, self.index_to_char, self.vocab_size = get_char_maps(
            properties.char_set)
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        if self.crnn_model_path == '':
            # No checkpoint supplied: train a fresh CRNN.
            self.crnn_model = CRNN(self.vocab_size, False).to(self.device)
        else:
            # BUG FIX: load the checkpoint passed on the command line; the
            # original ignored it and always loaded the hard-coded
            # properties.crnn_model_path.
            self.crnn_model = torch.load(
                self.crnn_model_path).to(self.device)
        self.crnn_model.register_backward_hook(self.crnn_model.backward_hook)
        self.prep_model = UNet().to(self.device)

        # NOTE: self.validation_set is rebound here from a path to a dataset.
        self.dataset = PatchDataset(
            properties.patch_dataset_train, pad=True, include_name=True)
        self.validation_set = PatchDataset(
            properties.patch_dataset_dev, pad=True)
        self.loader_train = torch.utils.data.DataLoader(
            self.dataset, batch_size=self.batch_size, shuffle=True,
            drop_last=True, collate_fn=PatchDataset.collate)

        self.train_set_size = len(self.dataset)
        self.val_set_size = len(self.validation_set)

        self.primary_loss_fn = CTCLoss().to(self.device)
        self.secondary_loss_fn = MSELoss().to(self.device)
        self.optimizer_crnn = optim.Adam(
            self.crnn_model.parameters(), lr=self.lr_crnn, weight_decay=0)
        self.optimizer_prep = optim.Adam(
            self.prep_model.parameters(), lr=self.lr_prep, weight_decay=0)
    def __init__(self, args):
        """Set up preprocessor (UNet) training on the VGG text dataset.

        Args:
            args: parsed CLI namespace; reads ocr, batch_size, lr, epoch,
                std, p and scalar.
        """
        self.ocr_name = args.ocr
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.epochs = args.epoch
        self.std = args.std
        self.ocr = args.ocr
        self.p_samples = args.p
        self.sec_loss_scalar = args.scalar

        self.train_set = properties.vgg_text_dataset_train
        self.validation_set = properties.vgg_text_dataset_dev
        self.input_size = properties.input_size

        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda:0" if has_cuda else "cpu")
        self.prep_model = UNet().to(self.device)
        # Rebinds self.ocr from the raw OCR name to the helper object.
        self.ocr = get_ocr_helper(self.ocr)

        self.char_to_index, self.index_to_char, self.vocab_size = get_char_maps(
            properties.char_set)

        # Per-sample CTC losses (reduction='none') are combined downstream.
        self.loss_fn = CTCLoss(reduction='none').to(self.device)

        pad_and_tensor = transforms.Compose([
            PadWhite(self.input_size),
            transforms.ToTensor(),
        ])
        self.dataset = ImgDataset(
            self.train_set, transform=pad_and_tensor, include_name=True)
        # Rebinds self.validation_set from a path to a dataset object.
        self.validation_set = ImgDataset(
            self.validation_set, transform=pad_and_tensor, include_name=True)
        self.loader_train = torch.utils.data.DataLoader(
            self.dataset, batch_size=self.batch_size,
            shuffle=True, drop_last=True)
        self.loader_validation = torch.utils.data.DataLoader(
            self.validation_set, batch_size=self.batch_size, drop_last=True)

        self.val_set_size = len(self.validation_set)
        self.train_set_size = len(self.dataset)

        self.optimizer = optim.Adam(
            self.prep_model.parameters(), lr=self.lr, weight_decay=0)
        self.secondary_loss_fn = MSELoss().to(self.device)
    def __init__(self, args):
        """Set up CRNN training against OCR-generated labels with noise
        augmentation.

        Args:
            args: parsed CLI namespace; reads batch_size, random_seed, lr,
                epoch, ocr, std, random_std and dataset.

        Raises:
            ValueError: if no OCR helper can be created for ``args.ocr``.
        """
        self.batch_size = args.batch_size
        self.random_seed = args.random_seed
        self.lr = args.lr
        self.max_epochs = args.epoch
        self.ocr = args.ocr
        self.std = args.std
        self.is_random_std = args.random_std
        self.dataset_name = args.dataset

        # Step-decay schedule parameters consumed by the scheduler below.
        self.decay = 0.8
        self.decay_step = 10
        torch.manual_seed(self.random_seed)
        # NOTE(review): np.random.seed requires a value < 2**32; fine here
        # because initial_seed() echoes the manual seed just set — confirm
        # random_seed stays in range.
        np.random.seed(torch.initial_seed())

        if self.dataset_name == 'pos':
            self.train_set = properties.pos_text_dataset_train
            self.validation_set = properties.pos_text_dataset_dev
        elif self.dataset_name == 'vgg':
            self.train_set = properties.vgg_text_dataset_train
            self.validation_set = properties.vgg_text_dataset_dev

        self.input_size = properties.input_size

        self.char_to_index, self.index_to_char, self.vocab_size = get_char_maps(
            properties.char_set)
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        self.model = CRNN(self.vocab_size, False).to(self.device)
        self.model.register_backward_hook(self.model.backward_hook)

        # Rebinds self.ocr from the OCR name to the helper object (may be None).
        self.ocr = get_ocr_helper(self.ocr)

        transform = transforms.Compose([
            PadWhite(self.input_size),
            transforms.ToTensor(),
        ])
        if self.ocr is not None:
            # Training images get Gaussian noise on top of pad + tensorize.
            noisy_transform = transforms.Compose([
                PadWhite(self.input_size),
                transforms.ToTensor(),
                AddGaussianNoice(
                    std=self.std, is_stochastic=self.is_random_std)
            ])
            dataset = OCRDataset(
                self.train_set, transform=noisy_transform, ocr_helper=self.ocr)
            self.loader_train = torch.utils.data.DataLoader(
                dataset, batch_size=self.batch_size, drop_last=True, shuffle=True)

            validation_set = OCRDataset(
                self.validation_set, transform=transform, ocr_helper=self.ocr)
            self.loader_validation = torch.utils.data.DataLoader(
                validation_set, batch_size=self.batch_size, drop_last=True)

            # BUG FIX: these lengths were computed outside this conditional,
            # raising NameError on `dataset` whenever the OCR helper was None.
            self.train_set_size = len(dataset)
            self.val_set_size = len(validation_set)
        else:
            # Fail fast with a clear message instead of an obscure NameError.
            raise ValueError(
                "No OCR helper available for '{}'".format(args.ocr))

        self.loss_function = CTCLoss().to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.decay_step, gamma=self.decay)