Beispiel #1
0
def main(folder_path,
         base_path_character=None,
         base_path_affinity=None,
         base_path_bbox=None,
         model_path=None,
         model=None):
    """
    Entry function for synthesising character and affinity heatmaps on the
    images given in a folder using a pre-trained model.

    Any output path left as ``None`` defaults to a sibling of ``folder_path``
    (i.e. a directory under the parent of ``folder_path``):
    ``target_character``, ``target_affinity`` and ``word_bbox`` respectively.

    :param folder_path: Path of folder where the images are
    :param base_path_character: Path where to store the character heatmap
    :param base_path_affinity: Path where to store the affinity heatmap
    :param base_path_bbox: Path where to store the generated word_bbox
        overlapped on the image
    :param model_path: Path where the pre-trained model is stored; only read
        when ``model`` is None. The checkpoint must contain a 'state_dict' key.
    :param model: If model is provided directly use it instead of loading it
    :return: None
    """

    # Resolve the default output locations BEFORE creating any directory:
    # os.makedirs(None) raises TypeError, so the defaults must be filled first.
    parent_path = '/'.join(folder_path.split('/')[:-1])

    if base_path_character is None:
        base_path_character = parent_path + '/target_character'
    if base_path_affinity is None:
        base_path_affinity = parent_path + '/target_affinity'
    if base_path_bbox is None:
        base_path_bbox = parent_path + '/word_bbox'

    os.makedirs(base_path_affinity, exist_ok=True)
    os.makedirs(base_path_character, exist_ok=True)
    os.makedirs(base_path_bbox, exist_ok=True)

    # Dataloader to pre-process images given in the folder

    infer_dataloader = DataLoaderEval(folder_path)

    infer_dataloader = DataLoader(infer_dataloader,
                                  batch_size=2,
                                  shuffle=True,
                                  num_workers=2)

    if model is None:

        # If model has not been provided, loading it from the path provided

        if config.model_architecture == 'UNET_ResNet':
            from src.UNET_ResNet import UNetWithResnet50Encoder
            model = UNetWithResnet50Encoder()
        else:
            from src.craft_model import CRAFT
            model = CRAFT()
        model = DataParallelModel(model)

        if config.use_cuda:
            model = model.cuda()
            saved_model = torch.load(model_path)
        else:
            # Remap tensors to CPU so a checkpoint saved on GPU still loads
            saved_model = torch.load(model_path, map_location='cpu')
        model.load_state_dict(saved_model['state_dict'])

    synthesize(infer_dataloader, model, base_path_affinity,
               base_path_character, base_path_bbox)
def generator_(base_target_path, model_path=None, model=None):
    """
    Generator function to generate weighted heat-maps for weak-supervision
    training.

    Creates ``base_target_path`` plus the ``<base_target_path>_predicted/...``
    and ``<base_target_path>_next_target/...`` output directories, builds the
    'train' dataloader and runs ``synthesize_with_score`` on it.

    :param base_target_path: Path where to store the generated annotations
    :param model_path: If model is not provided then load from model_path.
        The checkpoint must contain a 'state_dict' key.
    :param model: Pytorch Model can be directly provided for inference
    :return: None
    """

    # Imported inside the function, as in the original code
    # (presumably to avoid a circular import -- TODO confirm).
    from train_weak_supervision.dataloader import DataLoaderEvalOther_Datapile

    os.makedirs(base_target_path, exist_ok=True)

    output_sub_dirs = (
        # Storing Predicted
        '_predicted/affinity',
        '_predicted/character',
        '_predicted/word_bbox',
        # Storing Targets for next iteration
        '_next_target/affinity',
        '_next_target/character',
        '_next_target/affinity_weight',
        '_next_target/character_weight',
        '_next_target/word_bbox',
    )
    for sub_dir in output_sub_dirs:
        os.makedirs(base_target_path + sub_dir, exist_ok=True)

    # Dataloader to pre-process images given in the dataset and provide
    # annotations to generate weight

    infer_dataloader = DataLoaderEvalOther_Datapile('train')

    infer_dataloader = DataLoader(infer_dataloader,
                                  batch_size=config.batch_size['test'],
                                  shuffle=False,
                                  num_workers=config.num_workers['test'],
                                  worker_init_fn=_init_fn)

    if model is None:

        # If model has not been provided, loading it from the path provided

        if config.model_architecture == 'UNET_ResNet':
            from src.UNET_ResNet import UNetWithResnet50Encoder
            model = UNetWithResnet50Encoder()
        else:
            from src.craft_model import CRAFT
            model = CRAFT()

        model = DataParallelModel(model)

        if config.use_cuda:
            model = model.cuda()
            saved_model = torch.load(model_path)
        else:
            # Without CUDA, remap the checkpoint tensors to CPU; a checkpoint
            # saved on GPU cannot be deserialised otherwise.
            saved_model = torch.load(model_path, map_location='cpu')
        model.load_state_dict(saved_model['state_dict'])

    synthesize_with_score(infer_dataloader, model, base_target_path)
Beispiel #3
0
def main():
    """
    Training entry point: build the configured model, optionally resume from
    a checkpoint, train it and save the final model and loss curve.

    Everything is driven by ``config``:

    * ``config.model_architecture`` selects 'UNET_ResNet' or (otherwise) CRAFT.
    * When ``config.pretrained`` is truthy, model and optimizer state are
      restored from ``config.pretrained_path``. The starting iteration is the
      integer before the first '_' in that file's name, and the loss history
      is read from ``config.pretrained_loss_plot_training``.
    * Outputs written under ``config.save_path``: a copy of
      'train_synth/config.py', 'final_model.pkl', 'loss_plot_training.npy'
      and 'loss_plot_training.png'.

    :return: None
    """

    seed()

    # Keep a copy of the configuration next to the training outputs
    copyfile('train_synth/config.py', config.save_path + '/config.py')

    if config.model_architecture == 'UNET_ResNet':
        from src.UNET_ResNet import UNetWithResnet50Encoder
        model = UNetWithResnet50Encoder()
    else:
        from src.craft_model import CRAFT
        model = CRAFT()

    # Count only the parameters that will receive gradients
    params = sum(
        np.prod(p.size()) for p in model.parameters() if p.requires_grad)

    print('Total number of trainable parameters: ', params)

    model = DataParallelModel(model)
    loss_criterian = DataParallelCriterion(Criterian())

    if config.use_cuda:
        model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr[1])

    if config.pretrained:
        # map_location=None is torch.load's default behaviour; on a CPU-only
        # run the tensors are remapped to CPU so a checkpoint saved on GPU
        # still loads.
        saved_model = torch.load(
            config.pretrained_path,
            map_location=None if config.use_cuda else 'cpu')
        model.load_state_dict(saved_model['state_dict'])
        optimizer.load_state_dict(saved_model['optimizer'])

        # Checkpoint file names are expected to start with "<iteration>_"
        checkpoint_name = config.pretrained_path.split('/')[-1]
        starting_no = int(checkpoint_name.split('_')[0])
        all_loss = np.load(config.pretrained_loss_plot_training).tolist()
        print('Loaded the model')

    else:
        starting_no = 0
        all_loss = []

    all_accuracy = []

    print('Loading the dataloader')

    train_dataloader = DataLoaderSYNTH('train')
    train_dataloader = DataLoader(
        train_dataloader, batch_size=config.batch_size['train'],
        shuffle=True, num_workers=config.num_workers['train'])

    print('Loaded the dataloader')

    all_loss = train(
        train_dataloader, loss_criterian, model, optimizer,
        starting_no=starting_no, all_loss=all_loss, all_accuracy=all_accuracy)

    torch.save(
        {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, config.save_path + '/final_model.pkl')

    # Persist the loss history both as raw data and as a plot
    np.save(config.save_path + '/loss_plot_training.npy', all_loss)
    plt.plot(all_loss)
    plt.savefig(config.save_path + '/loss_plot_training.png')
    plt.clf()

    print("Saved Final Model")