def __init__(self, args):
    """Set up evaluation: load the trained preprocessor checkpoint, the
    OCR helper, and a DataLoader over the chosen test split.

    Args:
        args: parsed CLI namespace providing show_txt, show_img,
            prep_model_name, prep_path, ocr and dataset.

    Raises:
        ValueError: if args.dataset is neither 'vgg' nor 'pos'.
    """
    self.batch_size = 1  # evaluation runs one image at a time
    self.show_txt = args.show_txt
    self.show_img = args.show_img
    self.prep_model_name = args.prep_model_name
    self.prep_model_path = args.prep_path
    self.ocr_name = args.ocr
    self.dataset_name = args.dataset
    if self.dataset_name == 'vgg':
        self.test_set = properties.vgg_text_dataset_test
    elif self.dataset_name == 'pos':
        self.test_set = properties.patch_dataset_test
    else:
        # Fail fast instead of a later AttributeError on self.test_set.
        raise ValueError(f"Unknown dataset: {self.dataset_name!r}")
    # Both datasets use the same input size; hoisted out of the branches.
    self.input_size = properties.input_size
    self.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")
    # map_location lets a GPU-trained checkpoint load on a CPU-only host.
    self.prep_model = torch.load(
        os.path.join(self.prep_model_path, self.prep_model_name),
        map_location=self.device).to(self.device)
    self.ocr = get_ocr_helper(self.ocr_name, is_eval=True)
    if self.dataset_name == 'pos':
        self.dataset = PatchDataset(self.test_set, pad=True)
    else:
        transform = transforms.Compose([
            PadWhite(self.input_size),
            transforms.ToTensor(),
        ])
        self.dataset = ImgDataset(
            self.test_set, transform=transform, include_name=True)
    self.loader_eval = torch.utils.data.DataLoader(
        self.dataset, batch_size=self.batch_size,
        num_workers=properties.num_workers)
def __init__(self, args):
    """Set up joint training of the UNet preprocessor and the CRNN
    recognizer on the POS patch dataset.

    Args:
        args: parsed CLI namespace providing lr_crnn, lr_prep, epoch,
            inner_limit, crnn_model, scalar, ocr, std and random_std.
    """
    self.batch_size = 1  # patch samples are handled one at a time
    self.lr_crnn = args.lr_crnn
    self.lr_prep = args.lr_prep
    self.max_epochs = args.epoch
    self.inner_limit = args.inner_limit
    self.crnn_model_path = args.crnn_model
    self.sec_loss_scalar = args.scalar
    self.ocr_name = args.ocr
    self.std = args.std
    self.is_random_std = args.random_std
    torch.manual_seed(42)  # fixed seed for reproducible runs
    self.train_set = properties.pos_text_dataset_train
    self.validation_set = properties.pos_text_dataset_dev
    self.input_size = properties.input_size
    self.ocr = get_ocr_helper(self.ocr_name)
    self.char_to_index, self.index_to_char, self.vocab_size = get_char_maps(
        properties.char_set)
    self.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")
    if self.crnn_model_path == '':
        self.crnn_model = CRNN(self.vocab_size, False).to(self.device)
    else:
        # BUG FIX: the original loaded properties.crnn_model_path here,
        # silently ignoring the user-supplied args.crnn_model path.
        # map_location lets a GPU-trained checkpoint load on CPU-only hosts.
        self.crnn_model = torch.load(
            self.crnn_model_path, map_location=self.device).to(self.device)
    # NOTE(review): register_backward_hook is deprecated in recent torch
    # (register_full_backward_hook) — kept as-is for behavioral parity.
    self.crnn_model.register_backward_hook(self.crnn_model.backward_hook)
    self.prep_model = UNet().to(self.device)
    self.dataset = PatchDataset(
        properties.patch_dataset_train, pad=True, include_name=True)
    # Overwrites the dev-split *path* stored above with the dataset object;
    # downstream code reads self.validation_set as a dataset.
    self.validation_set = PatchDataset(
        properties.patch_dataset_dev, pad=True)
    self.loader_train = torch.utils.data.DataLoader(
        self.dataset, batch_size=self.batch_size, shuffle=True,
        drop_last=True, collate_fn=PatchDataset.collate)
    self.train_set_size = len(self.dataset)
    self.val_set_size = len(self.validation_set)
    self.primary_loss_fn = CTCLoss().to(self.device)
    self.secondary_loss_fn = MSELoss().to(self.device)
    self.optimizer_crnn = optim.Adam(
        self.crnn_model.parameters(), lr=self.lr_crnn, weight_decay=0)
    self.optimizer_prep = optim.Adam(
        self.prep_model.parameters(), lr=self.lr_prep, weight_decay=0)
def __init__(self, args):
    """Set up training of the UNet preprocessor on the VGG text dataset.

    Args:
        args: parsed CLI namespace providing ocr, batch_size, lr, epoch,
            std, p and scalar.
    """
    self.ocr_name = args.ocr
    self.batch_size = args.batch_size
    self.lr = args.lr
    self.epochs = args.epoch
    self.std = args.std
    self.p_samples = args.p
    self.sec_loss_scalar = args.scalar
    self.train_set = properties.vgg_text_dataset_train
    self.input_size = properties.input_size
    self.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")
    self.prep_model = UNet().to(self.device)
    # FIX: the original stored args.ocr in self.ocr only to overwrite it
    # with the helper object; build the helper from ocr_name directly.
    self.ocr = get_ocr_helper(self.ocr_name)
    self.char_to_index, self.index_to_char, self.vocab_size = get_char_maps(
        properties.char_set)
    # reduction='none' keeps per-sample CTC losses (consumed elsewhere —
    # not visible in this block).
    self.loss_fn = CTCLoss(reduction='none').to(self.device)
    transform = transforms.Compose([
        PadWhite(self.input_size),
        transforms.ToTensor(),
    ])
    self.dataset = ImgDataset(
        self.train_set, transform=transform, include_name=True)
    # self.validation_set holds the dataset object (the original briefly
    # stored the dev-split path in the same attribute before overwriting).
    self.validation_set = ImgDataset(
        properties.vgg_text_dataset_dev, transform=transform,
        include_name=True)
    self.loader_train = torch.utils.data.DataLoader(
        self.dataset, batch_size=self.batch_size, shuffle=True,
        drop_last=True)
    self.loader_validation = torch.utils.data.DataLoader(
        self.validation_set, batch_size=self.batch_size, drop_last=True)
    self.val_set_size = len(self.validation_set)
    self.train_set_size = len(self.dataset)
    self.optimizer = optim.Adam(
        self.prep_model.parameters(), lr=self.lr, weight_decay=0)
    self.secondary_loss_fn = MSELoss().to(self.device)
def __init__(self, args):
    """Set up standalone CRNN training against OCR-generated labels.

    NOTE(review): this block was recovered from a whitespace-flattened
    source; the statements grouped under ``if self.ocr is not None:``
    below are a reconstruction of the lost indentation — confirm against
    the original file.
    """
    self.batch_size = args.batch_size
    self.random_seed = args.random_seed
    self.lr = args.lr
    self.max_epochs = args.epoch
    self.ocr = args.ocr  # raw OCR name; replaced by a helper object below
    self.std = args.std
    self.is_random_std = args.random_std
    self.dataset_name = args.dataset
    # Step-decay schedule: lr multiplied by 0.8 every 10 epochs (scheduler).
    self.decay = 0.8
    self.decay_step = 10
    torch.manual_seed(self.random_seed)
    # NOTE(review): np.random.seed requires a value < 2**32; this works
    # only while torch.initial_seed() (== random_seed here) stays in range.
    np.random.seed(torch.initial_seed())
    if self.dataset_name == 'pos':
        self.train_set = properties.pos_text_dataset_train
        self.validation_set = properties.pos_text_dataset_dev
    elif self.dataset_name == 'vgg':
        self.train_set = properties.vgg_text_dataset_train
        self.validation_set = properties.vgg_text_dataset_dev
    self.input_size = properties.input_size
    self.char_to_index, self.index_to_char, self.vocab_size = get_char_maps(
        properties.char_set)
    self.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")
    self.model = CRNN(self.vocab_size, False).to(self.device)
    self.model.register_backward_hook(self.model.backward_hook)
    self.ocr = get_ocr_helper(self.ocr)
    transform = transforms.Compose([
        PadWhite(self.input_size),
        transforms.ToTensor(),
    ])
    if self.ocr is not None:
        # Training images get additive Gaussian noise; validation does not.
        noisy_transform = transforms.Compose([
            PadWhite(self.input_size),
            transforms.ToTensor(),
            AddGaussianNoice(
                std=self.std, is_stochastic=self.is_random_std)
        ])
        dataset = OCRDataset(
            self.train_set, transform=noisy_transform, ocr_helper=self.ocr)
        self.loader_train = torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, drop_last=True,
            shuffle=True)
        validation_set = OCRDataset(
            self.validation_set, transform=transform, ocr_helper=self.ocr)
        self.loader_validation = torch.utils.data.DataLoader(
            validation_set, batch_size=self.batch_size, drop_last=True)
        self.train_set_size = len(dataset)
        self.val_set_size = len(validation_set)
    self.loss_function = CTCLoss().to(self.device)
    self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
    self.scheduler = optim.lr_scheduler.StepLR(
        self.optimizer, step_size=self.decay_step, gamma=self.decay)