Example #1
class TextRecognitionDemo:
    def __init__(self, config):
        self.config = config
        self.model_path = config.get('model_path')
        self.vocab = read_vocab(config.get('vocab_path'))
        self.transform = create_list_of_transforms(config.get('transforms_list'))
        self.use_ctc = self.config.get('use_ctc')
        self.model = TextRecognitionModel(config.get('backbone_config'), len(self.vocab),
                                          config.get('head', {}), config.get('transformation', {}))
        if self.model_path is not None:
            self.model.load_weights(self.model_path, map_location=config.get('map_location', 'cpu'))
        self.model.eval()
        self.device = config.get('device', 'cpu')
        self.model = self.model.to(self.device)

    def __call__(self, img):
        img = self.transform(img)
        img = img[0].unsqueeze(0)
        img = img.to(self.device)
        logits, pred = self.model(img)
        if self.use_ctc:
            pred = torch.nn.functional.log_softmax(logits.detach(), dim=2)
            pred = ctc_greedy_search(pred, 0)

        return self.vocab.construct_phrase(pred[0], ignore_end_token=self.use_ctc)
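
Below is a minimal usage sketch for TextRecognitionDemo; the config keys mirror the ones read in __init__ above, while the concrete paths, the empty sub-configs, and the OpenCV image loading are illustrative assumptions only.

# Hypothetical usage of TextRecognitionDemo (keys taken from __init__ above;
# paths and sub-config contents are placeholder assumptions).
import cv2

demo_config = {
    'model_path': 'model_checkpoints/best_model.pth',  # assumed checkpoint path
    'vocab_path': 'vocabs/vocab.json',                  # assumed vocab path
    'transforms_list': [],                              # project-specific transform configs
    'use_ctc': False,
    'backbone_config': {},                              # project-specific backbone settings
    'head': {},
    'transformation': {},
    'device': 'cpu',
}
demo = TextRecognitionDemo(demo_config)
image = cv2.imread('sample.png')  # any HxWxC numpy image accepted by the transforms
print(demo(image))                # recognized phrase as a string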
class PyTorchRunner(BaseRunner):
    def load_model(self):
        self.vocab_len = len(read_vocab(self.config.get('vocab_path')))
        self.use_ctc = self.config.get('use_ctc')
        out_size = self.vocab_len + 1 if self.use_ctc else self.vocab_len
        self.model = TextRecognitionModel(
            self.config.get('backbone_config'), out_size,
            self.config.get('head', {}), self.config.get('transformation', {}))
        self.device = self.config.get('device', 'cpu')
        self.model.load_weights(self.config.get('model_path'),
                                map_location=self.device)
        self.model = self.model.to(self.device)
        self.model.eval()

    def run_model(self, img):
        img = img.to(self.device)
        logits, pred = self.model(img)
        if self.use_ctc:
            pred = torch.nn.functional.log_softmax(logits.detach(), dim=2)
            pred = ctc_greedy_search(pred, 0)
        return pred

    def openvino_transform(self):
        return False

    def reload_model(self, new_model_path):
        self.model.load_weights(new_model_path, map_location=self.device)
        self.model = self.model.to(self.device)
        self.model.eval()
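
For the CTC branch in run_model above, the project-specific helper ctc_greedy_search is applied to per-step log-probabilities with blank index 0. The following is a minimal, self-contained sketch of what greedy CTC decoding typically does (argmax per time step, collapse repeats, drop the blank), assuming batch-first (batch, time, vocab) log-probabilities; the real helper's signature and behaviour may differ.

# Sketch of greedy CTC decoding; an assumption about what the project's
# ctc_greedy_search does, not its actual implementation.
import torch

def ctc_greedy_decode(log_probs: torch.Tensor, blank: int = 0):
    """log_probs: (batch, time, vocab) log-probabilities."""
    best = log_probs.argmax(dim=2)  # best class index per time step
    decoded = []
    for seq in best:
        prev = blank
        tokens = []
        for idx in seq.tolist():
            # keep a symbol only when it changes and is not the blank
            if idx != prev and idx != blank:
                tokens.append(idx)
            prev = idx
        decoded.append(tokens)
    return decoded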
Example #3
 def __init__(self, work_dir, config, rank=0):
     self.rank = rank
     self.config = config
     if self.rank == 0:
         seed_worker(self.config.get('seed'))
     self.model_path = config.get('model_path')
     self.train_paths = config.get('train_paths')
     self.val_path = config.get('val_path')
     self.vocab = read_vocab(config.get('vocab_path'))
     self.train_transforms_list = config.get('train_transforms_list')
     self.val_transforms_list = config.get('val_transforms_list')
     self.loss_type = config.get('loss_type', 'NLL')
     self.total_epochs = config.get('epochs', 30)
     self.learing_rate = config.get('learning_rate', 1e-3)
     self.clip = config.get('clip_grad', 5.0)
     self.work_dir = os.path.abspath(work_dir)
     self.save_dir = os.path.join(self.work_dir, config.get('save_dir', 'model_checkpoints'))
     self.val_results_path = os.path.join(self.work_dir, 'val_results')
     self.step = 0
     self.global_step = 0
     self._test_steps = config.get('_test_steps', 1e18)
     self.epoch = 1
     self.best_val_loss = 1e18
     self.best_val_accuracy = 0.0
     self.best_val_loss_test = 1e18
     self.best_val_accuracy_test = 0.0
     self.print_freq = config.get('print_freq', 16)
     self.save_freq = config.get('save_freq', 2000)
     self.val_freq = config.get('val_freq', 5000)
     self.logs_path = os.path.join(self.work_dir, config.get('log_path', 'logs'))
     if self.rank == 0:
         self.writer = SummaryWriter(self.logs_path)
         self.writer.add_text('General info', pformat(config))
     self.device = config.get('device', 'cpu')
     self.multi_gpu = config.get('multi_gpu')
     if self.multi_gpu:
         torch.distributed.init_process_group("nccl", init_method="env://")
         self.device = torch.device(f'cuda:{self.rank}')
     self.create_dirs()
     self.load_dataset()
     self.loss = torch.nn.CTCLoss(blank=0, zero_infinity=self.config.get(
         'CTCLossZeroInf', False)) if self.loss_type == 'CTC' else None
     self.out_size = len(self.vocab) + 1 if self.loss_type == 'CTC' else len(self.vocab)
     self.model = TextRecognitionModel(config.get('backbone_config'), self.out_size,
                                       config.get('head', {}), config.get('transformation', {}))
     print(self.model)
     if self.model_path is not None:
         self.model.load_weights(self.model_path, map_location=self.device)
     self.model = self.model.to(self.device)
     if self.multi_gpu:
         if torch.cuda.device_count() > 1:
             self.model = torch.nn.parallel.DistributedDataParallel(
                 self.model, device_ids=[self.rank], output_device=self.rank)
     self.optimizer = getattr(optim, config.get('optimizer', 'Adam'))(self.model.parameters(), self.learing_rate)
     self.lr_scheduler = getattr(optim.lr_scheduler, self.config.get('scheduler', 'ReduceLROnPlateau'))(
         self.optimizer, **self.config.get('scheduler_params', {}))
     self.time = get_timestamp()
     self.use_lang_model = self.config.get("head").get("use_semantics")
     if self.use_lang_model:
         self.fasttext_model = fasttext.load_model(self.config.get("language_model_path"))
Example #4
 def __init__(self, work_dir, config):
     self.config = config
     self.model_path = config.get('model_path')
     self.train_paths = config.get('train_paths')
     self.val_path = config.get('val_path')
     self.vocab = read_vocab(config.get('vocab_path'))
     self.train_transforms_list = config.get('train_transforms_list')
     self.val_transforms_list = config.get('val_transforms_list')
     self.loss_type = config.get('loss_type', 'NLL')
     self.total_epochs = config.get('epochs', 30)
     self.learing_rate = config.get('learning_rate', 1e-3)
     self.clip = config.get('clip_grad', 5.0)
     self.work_dir = os.path.abspath(work_dir)
     self.save_dir = os.path.join(
         self.work_dir, config.get('save_dir', 'model_checkpoints'))
     self.val_results_path = os.path.join(self.work_dir, 'val_results')
     self.step = 0
     self.global_step = 0
     self._test_steps = config.get('_test_steps', 1e18)
     self.epoch = 1
     self.best_val_loss = 1e18
     self.best_val_accuracy = 0.0
     self.best_val_loss_test = 1e18
     self.best_val_accuracy_test = 0.0
     self.print_freq = config.get('print_freq', 16)
     self.save_freq = config.get('save_freq', 2000)
     self.val_freq = config.get('val_freq', 5000)
     self.logs_path = os.path.join(self.work_dir,
                                   config.get('log_path', 'logs'))
     self.writer = SummaryWriter(self.logs_path)
     self.device = config.get('device', 'cpu')
     self.writer.add_text('General info', pformat(config))
     self.create_dirs()
     self.load_dataset()
     self.loss = torch.nn.CTCLoss(
         blank=0, zero_infinity=self.config.get(
             'CTCLossZeroInf', False)) if self.loss_type == 'CTC' else None
      self.out_size = len(self.vocab) + 1 if self.loss_type == 'CTC' else len(self.vocab)
     self.model = TextRecognitionModel(config.get('backbone_config'),
                                       self.out_size,
                                       config.get('head', {}))
     print(self.model)
     if self.model_path is not None:
         self.model.load_weights(self.model_path, map_location=self.device)
     self.model = self.model.to(self.device)
     self.optimizer = getattr(optim,
                              config.get('optimizer',
                                         'Adam'))(self.model.parameters(),
                                                  self.learing_rate)
     self.lr_scheduler = getattr(
         optim.lr_scheduler,
         self.config.get('scheduler', 'ReduceLROnPlateau'))(
             self.optimizer, **self.config.get('scheduler_params', {}))
     self.time = get_timestamp()
 def load_model(self):
     self.vocab_len = len(read_vocab(self.config.get('vocab_path')))
     self.use_ctc = self.config.get('use_ctc')
     out_size = self.vocab_len + 1 if self.use_ctc else self.vocab_len
     self.model = TextRecognitionModel(
         self.config.get('backbone_config'), out_size,
         self.config.get('head', {}), self.config.get('transformation', {}))
     self.device = self.config.get('device', 'cpu')
     self.model.load_weights(self.config.get('model_path'),
                             map_location=self.device)
     self.model = self.model.to(self.device)
     self.model.eval()
Example #6
 def __init__(self, config):
     self.config = config
     self.model_path = config.get('model_path')
     self.vocab = read_vocab(config.get('vocab_path'))
     self.transform = create_list_of_transforms(config.get('transforms_list'))
     self.use_ctc = self.config.get('use_ctc')
      self.model = TextRecognitionModel(config.get('backbone_config'), len(self.vocab),
                                        config.get('head', {}), config.get('transformation', {}))
     if self.model_path is not None:
         self.model.load_weights(self.model_path, map_location=config.get('map_location', 'cpu'))
     self.model.eval()
     self.device = config.get('device', 'cpu')
     self.model = self.model.to(self.device)
Example #7
 def __init__(self, config):
     self.config = config
     self.model_path = config.get('model_path')
     self.vocab = read_vocab(config.get('vocab_path'))
     self.use_ctc = self.config.get('use_ctc')
     self.out_size = len(self.vocab) + 1 if self.use_ctc else len(self.vocab)
     self.model = TextRecognitionModel(config.get('backbone_config'), self.out_size, config.get('head', {}))
     self.model.eval()
     if self.model_path is not None:
         self.model.load_weights(self.model_path)
     self.img_for_export = torch.rand(self.config.get('input_shape_encoder', self.config.get('input_shape')))
     if not self.use_ctc:
         self.encoder = self.model.get_encoder_wrapper(self.model)
         self.encoder.eval()
         self.decoder = self.model.get_decoder_wrapper(self.model)
         self.decoder.eval()
Example #8
class Exporter:
    def __init__(self, config):
        self.config = config
        self.model_path = config.get('model_path')
        self.vocab = read_vocab(config.get('vocab_path'))
        self.use_ctc = self.config.get('use_ctc')
        self.out_size = len(self.vocab) + 1 if self.use_ctc else len(self.vocab)
        self.model = TextRecognitionModel(config.get('backbone_config'), self.out_size, config.get('head', {}))
        self.model.eval()
        if self.model_path is not None:
            self.model.load_weights(self.model_path)
        self.img_for_export = torch.rand(self.config.get('input_shape_encoder', self.config.get('input_shape')))
        if not self.use_ctc:
            self.encoder = self.model.get_encoder_wrapper(self.model)
            self.encoder.eval()
            self.decoder = self.model.get_decoder_wrapper(self.model)
            self.decoder.eval()

    def export_complete_model(self):
        model_inputs = [self.config.get('model_input_names')]
        model_outputs = self.config.get('model_output_names').split(',')
        print(f"Saving model to {self.config.get('res_model_name')}")
        res_path = os.path.join(os.path.split(self.model_path)[0], self.config.get('res_model_name'))
        torch.onnx.export(self.model, self.img_for_export, res_path,
                          opset_version=11, input_names=model_inputs, output_names=model_outputs,
                          dynamic_axes={model_inputs[0]: {0: 'batch', 1: 'channels', 2: 'height', 3: 'width'},
                                        model_outputs[0]: {0: 'batch', 1: 'max_len', 2: 'vocab_len'},
                                        })

    def export_encoder(self):
        encoder_inputs = self.config.get('encoder_input_names', ENCODER_INPUTS).split(',')
        encoder_outputs = self.config.get('encoder_output_names', ENCODER_OUTPUTS).split(',')
        res_encoder_path = os.path.join(os.path.split(self.model_path)[0], self.config.get('res_encoder_name'))
        torch.onnx.export(self.encoder, self.img_for_export, res_encoder_path,
                          opset_version=OPSET_VERSION,
                          input_names=encoder_inputs,
                          output_names=encoder_outputs,
                          )

    def export_decoder(self):
        decoder_inputs = self.config.get('decoder_input_names', DECODER_INPUTS).split(',')
        decoder_outputs = self.config.get('decoder_output_names', DECODER_OUTPUTS).split(',')
        input_shapes = self.config.get('decoder_input_shapes', [
                                       HIDDEN_SHAPE, CONTEXT_SHAPE, OUTPUT_SHAPE, FEATURES_SHAPE, TGT_SHAPE])
        inputs = [torch.rand(shape) for shape in input_shapes]
        res_decoder_path = os.path.join(os.path.split(self.model_path)[0], self.config.get('res_decoder_name'))
        torch.onnx.export(self.decoder,
                          inputs,
                          res_decoder_path,
                          opset_version=OPSET_VERSION,
                          input_names=decoder_inputs,
                          output_names=decoder_outputs
                          )

    def export_complete_model_ir(self):
        input_model = os.path.join(os.path.split(self.model_path)[0], self.config.get('res_model_name'))
        input_shape = self.config.get('input_shape')
        output_names = self.config.get('model_output_names')
        output_dir = os.path.split(self.model_path)[0]
        export_command = f"""{OPENVINO_DIR}/bin/setupvars.sh && \
        python {OPENVINO_DIR}/deployment_tools/model_optimizer/mo.py \
        --framework onnx \
        --input_model {input_model} \
        --input_shape "{input_shape}" \
        --output "{output_names}" \
        --log_level={LOG_LEVEL} \
        --output_dir {output_dir} \
        --scale_values 'imgs[255]'"""
        if self.config.get('verbose_export'):
            print(export_command)
        subprocess.run(export_command, shell=True, check=True)

    def export_encoder_ir(self):
        input_model = os.path.join(os.path.split(self.model_path)[0], self.config.get('res_encoder_name'))
        input_shape = self.config.get('input_shape_encoder')
        num_channels = input_shape[1]
        scale_values = '[255]' if num_channels == 1 else '[255,255,255]'
        reverse_channels = '' if num_channels == 1 else '--reverse_input_channels'
        input_names = self.config.get("encoder_input_names", ENCODER_INPUTS)
        output_names = self.config.get('encoder_output_names', ENCODER_OUTPUTS)
        output_dir = os.path.split(self.model_path)[0]
        export_command = f"""{OPENVINO_DIR}/bin/setupvars.sh && \
        python {OPENVINO_DIR}/deployment_tools/model_optimizer/mo.py \
        --framework onnx \
        --input_model {input_model} \
        --input_shape "{input_shape}" \
        --output "{output_names}" \
        {reverse_channels} \
        --log_level={LOG_LEVEL} \
        --output_dir {output_dir} \
        --scale_values '{input_names}{scale_values}'"""
        if self.config.get('verbose_export'):
            print(export_command)
        subprocess.run(export_command, shell=True, check=True)

    def export_decoder_ir(self):
        input_shape_decoder = self.config.get('decoder_input_shapes', [
            HIDDEN_SHAPE, CONTEXT_SHAPE, OUTPUT_SHAPE, FEATURES_SHAPE, TGT_SHAPE])
        input_shape_decoder = ', '.join(str(shape) for shape in input_shape_decoder)
        input_model = os.path.join(os.path.split(self.model_path)[0], self.config.get('res_decoder_name'))
        input_names = self.config.get('decoder_input_names', DECODER_INPUTS)
        output_names = self.config.get('decoder_output_names', DECODER_OUTPUTS)
        output_dir = os.path.split(self.model_path)[0]
        export_command = f"""{OPENVINO_DIR}/bin/setupvars.sh &&
        python {OPENVINO_DIR}/deployment_tools/model_optimizer/mo.py \
        --framework onnx \
        --input_model {input_model} \
        --input '{input_names}' \
        --input_shape '{input_shape_decoder}' \
        --log_level={LOG_LEVEL} \
        --output_dir {output_dir} \
        --output '{output_names}'"""
        if self.config.get('verbose_export'):
            print(export_command)
        subprocess.run(export_command,
                       shell=True, check=True
                       )

    def export_to_onnx_model_if_not_yet(self, model, model_type):
        """Wrapper for _export_model_if_not_yet. Exports model to ONNX

        Args:
            model (str): Path to the model file
            model_type (str): Encoder or decoder
        """
        self._export_model_if_not_yet(model, model_type, ir=False)

    def export_to_ir_model_if_not_yet(self, model, model_type):
        """Wrapper for _export_model_if_not_yet. Exports model to ONNX and to OpenVINO IR

        Args:
            model (str): Path to the model file
            model_type (str): Encoder or decoder
        """
        self._export_model_if_not_yet(model, model_type, ir=False)
        self._export_model_if_not_yet(model, model_type, ir=True)

    def _export_model_if_not_yet(self, model, model_type, ir=False):
        """Checks if given model file exists and if not runs model export

        Args:
            model (str): Path to the model file
            model_type (str): encoder or decoder
            ir (bool, optional): Export to OpenVINO IR. Defaults to False.
        """

        export_function_template = 'export_{}{}'
        if not self.use_ctc:
            assert model_type in ('encoder', 'decoder')

        result_model_exists = os.path.exists(model)
        if ir:
            model_xml = model.replace('.onnx', '.xml')
            model_bin = model.replace('.onnx', '.bin')
            result_model_exists = os.path.exists(model_xml) and os.path.exists(model_bin)
        if not result_model_exists:
            print(f'Model {model} does not exist, exporting it...')
            ir_suffix = '_ir' if ir else ''
            if not self.use_ctc:
                export_function_name = export_function_template.format(model_type, ir_suffix)
            else:
                export_function_name = export_function_template.format('complete_model', ir_suffix)
            getattr(self, export_function_name)()
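
The Exporter above works in two stages: torch.onnx.export serializes the PyTorch model (or its encoder/decoder wrappers) to ONNX, and the *_ir methods then shell out to OpenVINO's Model Optimizer (mo.py) to produce IR .xml/.bin files. A hedged usage sketch for the CTC case follows; every config key is one the class reads, but the file names, shapes, and tensor names are assumptions.

# Hypothetical Exporter usage for a CTC model (values are illustrative assumptions).
import os

export_config = {
    'model_path': 'model_checkpoints/best_model.pth',
    'vocab_path': 'vocabs/vocab.json',
    'use_ctc': True,
    'backbone_config': {},
    'head': {},
    'input_shape': [1, 3, 96, 480],          # NCHW shape of the dummy export input
    'model_input_names': 'imgs',
    'model_output_names': 'logits',
    'res_model_name': 'text_recognition.onnx',
    'verbose_export': True,
}
exporter = Exporter(export_config)
onnx_path = os.path.join(os.path.dirname(export_config['model_path']),
                         export_config['res_model_name'])
# Exports to ONNX first, then converts to OpenVINO IR if the files are missing.
exporter.export_to_ir_model_if_not_yet(onnx_path, model_type=None)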
Example #9
class Trainer:
    def __init__(self, work_dir, config, rank=0):
        self.rank = rank
        self.config = config
        if self.rank == 0:
            seed_worker(self.config.get('seed'))
        self.model_path = config.get('model_path')
        self.train_paths = config.get('train_paths')
        self.val_path = config.get('val_path')
        self.vocab = read_vocab(config.get('vocab_path'))
        self.train_transforms_list = config.get('train_transforms_list')
        self.val_transforms_list = config.get('val_transforms_list')
        self.loss_type = config.get('loss_type', 'NLL')
        self.total_epochs = config.get('epochs', 30)
        self.learing_rate = config.get('learning_rate', 1e-3)
        self.clip = config.get('clip_grad', 5.0)
        self.work_dir = os.path.abspath(work_dir)
        self.save_dir = os.path.join(self.work_dir, config.get('save_dir', 'model_checkpoints'))
        self.val_results_path = os.path.join(self.work_dir, 'val_results')
        self.step = 0
        self.global_step = 0
        self._test_steps = config.get('_test_steps', 1e18)
        self.epoch = 1
        self.best_val_loss = 1e18
        self.best_val_accuracy = 0.0
        self.best_val_loss_test = 1e18
        self.best_val_accuracy_test = 0.0
        self.print_freq = config.get('print_freq', 16)
        self.save_freq = config.get('save_freq', 2000)
        self.val_freq = config.get('val_freq', 5000)
        self.logs_path = os.path.join(self.work_dir, config.get('log_path', 'logs'))
        if self.rank == 0:
            self.writer = SummaryWriter(self.logs_path)
            self.writer.add_text('General info', pformat(config))
        self.device = config.get('device', 'cpu')
        self.multi_gpu = config.get('multi_gpu')
        if self.multi_gpu:
            torch.distributed.init_process_group("nccl", init_method="env://")
            self.device = torch.device(f'cuda:{self.rank}')
        self.create_dirs()
        self.load_dataset()
        self.loss = torch.nn.CTCLoss(blank=0, zero_infinity=self.config.get(
            'CTCLossZeroInf', False)) if self.loss_type == 'CTC' else None
        self.out_size = len(self.vocab) + 1 if self.loss_type == 'CTC' else len(self.vocab)
        self.model = TextRecognitionModel(config.get('backbone_config'), self.out_size,
                                          config.get('head', {}), config.get('transformation', {}))
        print(self.model)
        if self.model_path is not None:
            self.model.load_weights(self.model_path, map_location=self.device)
        self.model = self.model.to(self.device)
        if self.multi_gpu:
            if torch.cuda.device_count() > 1:
                self.model = torch.nn.parallel.DistributedDataParallel(
                    self.model, device_ids=[self.rank], output_device=self.rank)
        self.optimizer = getattr(optim, config.get('optimizer', 'Adam'))(self.model.parameters(), self.learing_rate)
        self.lr_scheduler = getattr(optim.lr_scheduler, self.config.get('scheduler', 'ReduceLROnPlateau'))(
            self.optimizer, **self.config.get('scheduler_params', {}))
        self.time = get_timestamp()
        self.use_lang_model = self.config.get("head").get("use_semantics")
        if self.use_lang_model:
            self.fasttext_model = fasttext.load_model(self.config.get("language_model_path"))

    def create_dirs(self):
        os.makedirs(self.logs_path, exist_ok=True)
        print('Created logs folder: {}'.format(self.logs_path))
        os.makedirs(self.save_dir, exist_ok=True)
        print('Created save folder: {}'.format(self.save_dir))
        os.makedirs(self.val_results_path, exist_ok=True)
        print('Created validation results folder: {}'.format(self.val_results_path))

    def load_dataset(self):
        for section in self.config.get('datasets').keys():
            if section == 'train':
                train_datasets = self._load_section(section)
            elif section == 'validate':
                val_datasets = self._load_section(section)
            else:
                raise ValueError(f'Wrong section name {section}')


        pprint('Creating training transforms list: {}'.format(self.train_transforms_list), indent=4, width=120)
        batch_transform_train = create_list_of_transforms(self.train_transforms_list)

        train_dataset = ConcatDataset(train_datasets)
        if self.multi_gpu:
            train_sampler = DistributedSampler(dataset=train_dataset,
                                               shuffle=True,
                                               rank=self.rank,
                                               num_replicas=torch.cuda.device_count())
            self.train_loader = DataLoader(
                train_dataset,
                sampler=train_sampler,
                collate_fn=partial(collate_fn, self.vocab.sign2id,
                                   batch_transform=batch_transform_train,
                                   use_ctc=(self.loss_type == 'CTC')),
                num_workers=self.config.get('num_workers', 4),
                batch_size=self.config.get('batch_size', 4),
                pin_memory=True)
        else:
            train_sampler = BatchRandomSampler(dataset=train_dataset, batch_size=self.config.get('batch_size', 4))

            self.train_loader = DataLoader(
                train_dataset,
                batch_sampler=train_sampler,
                collate_fn=partial(collate_fn, self.vocab.sign2id,
                                   batch_transform=batch_transform_train,
                                   use_ctc=(self.loss_type == 'CTC')),
                num_workers=self.config.get('num_workers', 4),
                pin_memory=True)
        pprint('Creating val transforms list: {}'.format(self.val_transforms_list), indent=4, width=120)
        batch_transform_val = create_list_of_transforms(self.val_transforms_list)
        self.val_loaders = [
            DataLoader(
                ds,
                collate_fn=partial(collate_fn, self.vocab.sign2id,
                                   batch_transform=batch_transform_val, use_ctc=(self.loss_type == 'CTC')),
                batch_size=self.config.get('val_batch_size', 1),
                num_workers=self.config.get('num_workers', 4)
            )
            for ds in val_datasets
        ]
        print('num workers: ', self.config.get('num_workers'))

    def _load_section(self, section):
        datasets = []
        for param in self.config.get('datasets')[section]:
            dataset_type = param.pop('type')
            dataset = str_to_class[dataset_type](**param)
            datasets.append(dataset)
        return datasets

    def train(self):
        losses = 0.0
        accuracies = 0.0
        while self.epoch <= self.total_epochs:
            for _, target_lengths, imgs, training_gt, loss_computation_gt in self.train_loader:
                step_loss, step_accuracy = self.train_step(imgs, target_lengths, training_gt, loss_computation_gt)
                losses += step_loss
                accuracies += step_accuracy
                if self.rank == 0:
                    self.writer.add_scalar('Train loss', step_loss, self.global_step)
                    self.writer.add_scalar('Train accuracy', step_accuracy, self.global_step)

                # log message
                if self.global_step % self.print_freq == 0 and self.rank == 0:
                    total_step = len(self.train_loader)
                    print('Epoch {}, step:{}/{} {:.2f}%, Loss:{:.4f}, accuracy: {:.4f}'.format(
                        self.epoch, self.step, total_step,
                        100 * self.step / total_step,
                        losses / self.print_freq, accuracies / self.print_freq
                    ))
                    losses = 0.0
                    accuracies = 0.0
                if self.global_step % self.save_freq == 0 and self.rank == 0:
                    self.save_model('model_epoch_{}_step_{}_{}.pth'.format(
                        self.epoch,
                        self.step,
                        self.time,
                    ))
                    self.writer.add_scalar('Learning rate', self.learing_rate, self.global_step)
                if self.global_step % self.val_freq == 0 and self.rank == 0:

                    step_loss, step_accuracy = self.validate(use_gt_token=False)
                    self.writer.add_scalar('Loss/test_mode_validation', step_loss, self.global_step)
                    self.writer.add_scalar('Accuracy/test_mode_validation', step_accuracy, self.global_step)
                    if step_loss < self.best_val_loss_test:
                        self.best_val_loss_test = step_loss
                        self.save_model('loss_test_best_model_{}.pth'.format(self.time))
                    if step_accuracy > self.best_val_accuracy_test:
                        self.best_val_accuracy_test = step_accuracy
                        self.save_model('accuracy_test_best_model_{}.pth'.format(self.time))

                    self.lr_scheduler.step(step_loss)
                self.current_loss = losses
                if self.global_step >= self._test_steps:
                    return

            self.epoch += 1
            self.step = 0

    def train_step(self, imgs, target_lengths, training_gt, loss_computation_gt):
        self.optimizer.zero_grad()
        imgs = imgs.to(self.device)
        training_gt = training_gt.to(self.device)
        loss_computation_gt = loss_computation_gt.to(self.device)
        semantic_loss = None
        if self.use_lang_model:
            logits, _, semantic_info = self.model(imgs, training_gt)
            gt_strs = [self.vocab.construct_phrase(gt).replace(' ', '') for gt in loss_computation_gt]
            device = imgs.device
            lm_embs = torch.Tensor([self.fasttext_model[s] for s in gt_strs]).to(device)
            # since semantic info should be as close to the language model embedding
            # as possible, target should be 1
            semantic_loss = torch.nn.CosineEmbeddingLoss()(
                semantic_info, lm_embs, target=torch.ones(lm_embs.shape[0], device=device))
        else:
            logits, _ = self.model(imgs, training_gt)
        cut = self.loss_type != 'CTC'
        loss, accuracy = calculate_loss(logits, loss_computation_gt, target_lengths, should_cut_by_min=cut,
                                        ctc_loss=self.loss)
        self.step += 1
        self.global_step += 1
        if semantic_loss:
            loss += semantic_loss
        loss.backward()
        clip_grad_norm_(self.model.parameters(), self.clip)
        self.optimizer.step()
        return loss.item(), accuracy

    def validate(self, use_gt_token=True):
        self.model.eval()
        val_avg_loss = 0.0
        val_avg_accuracy = 0.0
        print('Validation started')
        with torch.no_grad():
            filename = VAL_FILE_NAME_TEMPLATE.format(self.val_results_path, self.epoch, self.step, self.time)
            with open(filename, 'w') as output_file:
                for loader in self.val_loaders:
                    val_loss, val_acc = 0, 0
                    for img_name, target_lengths, imgs, training_gt, loss_computation_gt in tqdm(loader):

                        imgs = imgs.to(self.device)
                        training_gt = training_gt.to(self.device)
                        loss_computation_gt = loss_computation_gt.to(self.device)
                        logits, pred = self.model(imgs, training_gt if use_gt_token else None)
                        if self.loss_type == 'CTC':
                            pred = torch.nn.functional.log_softmax(logits.detach(), dim=2)
                            pred = ctc_greedy_search(pred, blank_token=self.loss.blank)
                        for j, phrase in enumerate(pred):
                            gold_phrase_str = self.vocab.construct_phrase(
                                loss_computation_gt[j], ignore_end_token=self.config.get('use_ctc'))
                            pred_phrase_str = self.vocab.construct_phrase(phrase,
                                                                          max_len=1 + len(gold_phrase_str.split()),
                                                                          ignore_end_token=self.config.get('use_ctc')
                                                                          )
                            gold_phrase_str = gold_phrase_str.lower()
                            pred_phrase_str = pred_phrase_str.lower()
                            output_file.write(img_name[j] + '\t' + pred_phrase_str + '\t' + gold_phrase_str + '\n')
                            val_acc += int(pred_phrase_str == gold_phrase_str)
                        cut = self.loss_type != 'CTC'
                        loss, _ = calculate_loss(logits, loss_computation_gt, target_lengths,
                                                 should_cut_by_min=cut, ctc_loss=self.loss)
                        loss = loss.detach()
                        val_loss += loss
                    val_loss = val_loss / len(loader.dataset)
                    val_acc = val_acc / len(loader.dataset)
                    dataset_name = os.path.split(loader.dataset.data_path)[-1]
                    print('Epoch {}, dataset {} loss: {:.4f}'.format(
                        self.epoch, dataset_name, val_loss
                    ))
                    self.writer.add_scalar(f'Loss {dataset_name}', val_loss, self.global_step)
                    print('Epoch {}, dataset {} accuracy: {:.4f}'.format(
                        self.epoch, dataset_name, val_acc
                    ))
                    self.writer.add_scalar(f'Accuracy {dataset_name}', val_acc, self.global_step)
                    weight = len(loader.dataset) / sum(map(lambda ld: len(ld.dataset), self.val_loaders))
                    val_avg_loss += val_loss * weight
                    val_avg_accuracy += val_acc * weight
        print('Epoch {}, validation average loss: {:.4f}'.format(
            self.epoch, val_avg_loss
        ))
        print('Epoch {}, validation average accuracy: {:.4f}'.format(
            self.epoch, val_avg_accuracy
        ))
        self.save_model('validation_epoch_{}_step_{}_{}.pth'.format(self.epoch, self.step, self.time))
        self.model.train()
        return val_avg_loss, val_avg_accuracy

    def save_model(self, name):
        print('Saving model as name ', name)
        torch.save(self.model.state_dict(), os.path.join(self.save_dir, name))
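
The Trainer is driven entirely by its config dict. The sketch below lists only keys that __init__, load_dataset, and train actually read, with illustrative values; the dataset entries must carry a 'type' key understood by the project's str_to_class mapping, which is why they are left empty here.

# Hypothetical minimal Trainer config; every key is read by the class above,
# but the values are placeholder assumptions.
trainer_config = {
    'seed': 42,
    'model_path': None,                 # None trains from scratch
    'vocab_path': 'vocabs/vocab.json',
    'train_transforms_list': [],
    'val_transforms_list': [],
    'loss_type': 'CTC',                 # or 'NLL'
    'epochs': 30,
    'learning_rate': 1e-3,
    'device': 'cpu',
    'multi_gpu': False,
    'backbone_config': {},
    'head': {'use_semantics': False},
    'transformation': {},
    'batch_size': 4,
    'num_workers': 4,
    'datasets': {
        'train': [],                    # dataset configs, each with a 'type' key
        'validate': [],
    },
}
# Trainer('work_dir', trainer_config).train()  # requires non-empty dataset sections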
Example #10
class Trainer:
    def __init__(self, work_dir, config):
        self.config = config
        self.model_path = config.get('model_path')
        self.train_paths = config.get('train_paths')
        self.val_path = config.get('val_path')
        self.vocab = read_vocab(config.get('vocab_path'))
        self.train_transforms_list = config.get('train_transforms_list')
        self.val_transforms_list = config.get('val_transforms_list')
        self.loss_type = config.get('loss_type', 'NLL')
        self.total_epochs = config.get('epochs', 30)
        self.learing_rate = config.get('learning_rate', 1e-3)
        self.clip = config.get('clip_grad', 5.0)
        self.work_dir = os.path.abspath(work_dir)
        self.save_dir = os.path.join(
            self.work_dir, config.get('save_dir', 'model_checkpoints'))
        self.val_results_path = os.path.join(self.work_dir, 'val_results')
        self.step = 0
        self.global_step = 0
        self._test_steps = config.get('_test_steps', 1e18)
        self.epoch = 1
        self.best_val_loss = 1e18
        self.best_val_accuracy = 0.0
        self.best_val_loss_test = 1e18
        self.best_val_accuracy_test = 0.0
        self.print_freq = config.get('print_freq', 16)
        self.save_freq = config.get('save_freq', 2000)
        self.val_freq = config.get('val_freq', 5000)
        self.logs_path = os.path.join(self.work_dir,
                                      config.get('log_path', 'logs'))
        self.writer = SummaryWriter(self.logs_path)
        self.device = config.get('device', 'cpu')
        self.writer.add_text('General info', pformat(config))
        self.create_dirs()
        self.load_dataset()
        self.loss = torch.nn.CTCLoss(
            blank=0, zero_infinity=self.config.get(
                'CTCLossZeroInf', False)) if self.loss_type == 'CTC' else None
        self.out_size = len(self.vocab) + 1 if self.loss_type == 'CTC' else len(self.vocab)
        self.model = TextRecognitionModel(config.get('backbone_config'),
                                          self.out_size,
                                          config.get('head', {}))
        print(self.model)
        if self.model_path is not None:
            self.model.load_weights(self.model_path, map_location=self.device)
        self.model = self.model.to(self.device)
        self.optimizer = getattr(optim,
                                 config.get('optimizer',
                                            'Adam'))(self.model.parameters(),
                                                     self.learing_rate)
        self.lr_scheduler = getattr(
            optim.lr_scheduler,
            self.config.get('scheduler', 'ReduceLROnPlateau'))(
                self.optimizer, **self.config.get('scheduler_params', {}))
        self.time = get_timestamp()

    def create_dirs(self):
        if not os.path.exists(self.logs_path):
            os.makedirs(self.logs_path)
        print('Created logs folder: {}'.format(self.logs_path))
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        print('Created save folder: {}'.format(self.save_dir))
        if not os.path.exists(self.val_results_path):
            os.makedirs(self.val_results_path)
        print('Created validation results folder: {}'.format(
            self.val_results_path))

    def load_dataset(self):
        train_datasets = []
        val_datasets = []
        for param in self.config.get('datasets'):
            dataset_type = param.pop('type')
            subset = param.pop('subset')
            dataset = str_to_class[dataset_type](**param)
            if subset == 'train':
                train_datasets.append(dataset)
            elif subset == 'validate':
                val_datasets.append(dataset)

        train_dataset = ConcatDataset(train_datasets)
        train_sampler = BatchRandomSampler(dataset=train_dataset,
                                           batch_size=self.config.get(
                                               'batch_size', 4))
        pprint('Creating training transforms list: {}'.format(
            self.train_transforms_list),
               indent=4,
               width=120)
        batch_transform_train = create_list_of_transforms(
            self.train_transforms_list)
        self.train_loader = DataLoader(
            train_dataset,
            batch_sampler=train_sampler,
            collate_fn=partial(collate_fn,
                               self.vocab.sign2id,
                               batch_transform=batch_transform_train,
                               use_ctc=(self.loss_type == 'CTC')),
            num_workers=self.config.get('num_workers', 4))
        val_samplers = [
            BatchRandomSampler(dataset=ds, batch_size=1) for ds in val_datasets
        ]
        pprint('Creating val transforms list: {}'.format(
            self.val_transforms_list),
               indent=4,
               width=120)
        batch_transform_val = create_list_of_transforms(
            self.val_transforms_list)
        self.val_loaders = [
            DataLoader(ds,
                       batch_sampler=sampler,
                       collate_fn=partial(collate_fn,
                                          self.vocab.sign2id,
                                          batch_transform=batch_transform_val,
                                          use_ctc=(self.loss_type == 'CTC')),
                       num_workers=self.config.get('num_workers', 4))
            for ds, sampler in zip(val_datasets, val_samplers)
        ]
        print('num workers: ', self.config.get('num_workers'))

    def train(self):
        losses = 0.0
        accuracies = 0.0
        while self.epoch <= self.total_epochs:
            for _, target_lengths, imgs, training_gt, loss_computation_gt in self.train_loader:
                step_loss, step_accuracy = self.train_step(
                    imgs, target_lengths, training_gt, loss_computation_gt)
                losses += step_loss
                accuracies += step_accuracy
                self.writer.add_scalar('Train loss', step_loss,
                                       self.global_step)
                self.writer.add_scalar('Train accuracy', step_accuracy,
                                       self.global_step)

                # log message
                if self.global_step % self.print_freq == 0:
                    total_step = len(self.train_loader)
                    print(
                        'Epoch {}, step:{}/{} {:.2f}%, Loss:{:.4f}, accuracy: {:.4f}'
                        .format(self.epoch, self.step, total_step,
                                100 * self.step / total_step,
                                losses / self.print_freq,
                                accuracies / self.print_freq))
                    losses = 0.0
                    accuracies = 0.0
                if self.global_step % self.save_freq == 0:
                    self.save_model('model_epoch_{}_step_{}_{}.pth'.format(
                        self.epoch,
                        self.step,
                        self.time,
                    ))
                    self.writer.add_scalar('Learning rate', self.learing_rate,
                                           self.global_step)
                if self.global_step % self.val_freq == 0:

                    step_loss, step_accuracy = self.validate(
                        use_gt_token=False)
                    self.writer.add_scalar('Loss/test_mode_validation',
                                           step_loss, self.global_step)
                    self.writer.add_scalar('Accuracy/test_mode_validation',
                                           step_accuracy, self.global_step)
                    if step_loss < self.best_val_loss_test:
                        self.best_val_loss_test = step_loss
                        self.save_model('loss_test_best_model_{}.pth'.format(
                            self.time))
                    if step_accuracy > self.best_val_accuracy_test:
                        self.best_val_accuracy_test = step_accuracy
                        self.save_model(
                            'accuracy_test_best_model_{}.pth'.format(
                                self.time))

                    self.lr_scheduler.step(step_loss)
                self.current_loss = losses
                if self.global_step >= self._test_steps:
                    return

            self.epoch += 1
            self.step = 0

    def train_step(self, imgs, target_lengths, training_gt,
                   loss_computation_gt):
        self.optimizer.zero_grad()
        imgs = imgs.to(self.device)
        training_gt = training_gt.to(self.device)
        loss_computation_gt = loss_computation_gt.to(self.device)
        logits, _ = self.model(imgs, training_gt)
        cut = self.loss_type != 'CTC'
        loss, accuracy = calculate_loss(logits,
                                        loss_computation_gt,
                                        target_lengths,
                                        should_cut_by_min=cut,
                                        ctc_loss=self.loss)
        self.step += 1
        self.global_step += 1
        loss.backward()
        clip_grad_norm_(self.model.parameters(), self.clip)
        self.optimizer.step()
        return loss.item(), accuracy

    def validate(self, use_gt_token=True):
        self.model.eval()
        val_avg_loss = 0.0
        val_avg_accuracy = 0.0
        print('Validation started')
        with torch.no_grad():
            filename = VAL_FILE_NAME_TEMPLATE.format(self.val_results_path,
                                                     self.epoch, self.step,
                                                     self.time)
            with open(filename, 'w') as output_file:
                for loader in self.val_loaders:
                    val_loss, val_acc = 0, 0
                    for img_name, target_lengths, imgs, training_gt, loss_computation_gt in tqdm(
                            loader):

                        imgs = imgs.to(self.device)
                        training_gt = training_gt.to(self.device)
                        loss_computation_gt = loss_computation_gt.to(
                            self.device)
                        logits, pred = self.model(
                            imgs, training_gt if use_gt_token else None)
                        if self.loss_type == 'CTC':
                            pred = torch.nn.functional.log_softmax(
                                logits.detach(), dim=2)
                            pred = ctc_greedy_search(
                                pred, blank_token=self.loss.blank)
                        for j, phrase in enumerate(pred):
                            gold_phrase_str = self.vocab.construct_phrase(
                                loss_computation_gt[j],
                                ignore_end_token=self.config.get('use_ctc'))
                            pred_phrase_str = self.vocab.construct_phrase(
                                phrase,
                                max_len=1 + len(gold_phrase_str.split()),
                                ignore_end_token=self.config.get('use_ctc'))
                            output_file.write(img_name[j] + '\t' +
                                              pred_phrase_str + '\t' +
                                              gold_phrase_str + '\n')
                            val_acc += int(pred_phrase_str == gold_phrase_str)
                        cut = self.loss_type != 'CTC'
                        loss, _ = calculate_loss(logits,
                                                 loss_computation_gt,
                                                 target_lengths,
                                                 should_cut_by_min=cut,
                                                 ctc_loss=self.loss)
                        loss = loss.detach()
                        val_loss += loss
                    val_loss = val_loss / len(loader)
                    val_acc = val_acc / len(loader)
                    dataset_name = os.path.split(loader.dataset.data_path)[-1]
                    print('Epoch {}, dataset {} loss: {:.4f}'.format(
                        self.epoch, dataset_name, val_loss))
                    print('Epoch {}, dataset {} accuracy: {:.4f}'.format(
                        self.epoch, dataset_name, val_acc))
                    weight = len(loader) / sum(map(len, self.val_loaders))
                    val_avg_loss += val_loss * weight
                    val_avg_accuracy += val_acc * weight
        print('Epoch {}, validation average loss: {:.4f}'.format(
            self.epoch, val_avg_loss))
        print('Epoch {}, validation average accuracy: {:.4f}'.format(
            self.epoch, val_avg_accuracy))
        self.save_model('validation_epoch_{}_step_{}_{}.pth'.format(
            self.epoch, self.step, self.time))
        self.model.train()
        return val_avg_loss, val_avg_accuracy

    def save_model(self, name):
        print('Saving model as name ', name)
        torch.save(self.model.state_dict(), os.path.join(self.save_dir, name))
Example #11
class Exporter:
    def __init__(self, config):
        self.config = config
        self.model_path = config.get('model_path')
        self.vocab = read_vocab(config.get('vocab_path'))
        self.use_ctc = self.config.get('use_ctc')
        self.out_size = len(self.vocab) + 1 if self.use_ctc else len(self.vocab)
        self.model = TextRecognitionModel(config.get('backbone_config'),
                                          self.out_size,
                                          config.get('head', {}))
        self.model.eval()
        if self.model_path is not None:
            self.model.load_weights(self.model_path)
        self.img_for_export = torch.rand(
            self.config.get('input_shape_decoder',
                            self.config.get('input_shape')))
        if not self.use_ctc:
            self.encoder = self.model.get_encoder_wrapper(self.model)
            self.encoder.eval()
            self.decoder = self.model.get_decoder_wrapper(self.model)
            self.decoder.eval()

    def export_complete_model(self):
        model_inputs = [self.config.get('model_input_names')]
        model_outputs = self.config.get('model_output_names').split(',')
        print(f"Saving model to {self.config.get('res_model_name')}")
        res_path = os.path.join(
            os.path.split(self.model_path)[0],
            self.config.get('res_model_name'))
        torch.onnx.export(self.model,
                          self.img_for_export,
                          res_path,
                          opset_version=11,
                          input_names=model_inputs,
                          output_names=model_outputs,
                          dynamic_axes={
                              model_inputs[0]: {
                                  0: 'batch',
                                  1: 'channels',
                                  2: 'height',
                                  3: 'width'
                              },
                              model_outputs[0]: {
                                  0: 'batch',
                                  1: 'max_len',
                                  2: 'vocab_len'
                              },
                          })

    def export_encoder(self):
        encoder_inputs = self.config.get('encoder_input_names',
                                         ENCODER_INPUTS).split(',')
        encoder_outputs = self.config.get('encoder_output_names',
                                          ENCODER_OUTPUTS).split(',')
        res_encoder_path = os.path.join(
            os.path.split(self.model_path)[0],
            self.config.get('res_encoder_name'))
        torch.onnx.export(
            self.encoder,
            self.img_for_export,
            res_encoder_path,
            opset_version=OPSET_VERSION,
            input_names=encoder_inputs,
            output_names=encoder_outputs,
            dynamic_axes={
                encoder_inputs[0]: {
                    0: 'batch',
                    1: 'channels',
                    2: 'height',
                    3: 'width'
                },
                encoder_outputs[0]: {
                    0: 'batch',
                    1: 'H',
                    2: 'W'
                },
            },
        )

    def export_decoder(self):
        tgt = np.array([[START_TOKEN]] * 1)
        decoder_inputs = self.config.get('decoder_input_names',
                                         DECODER_INPUTS).split(',')
        decoder_outputs = self.config.get('decoder_output_names',
                                          DECODER_OUTPUTS).split(',')
        row_enc_out = torch.rand(FEATURES_SHAPE)
        hidden = torch.randn(HIDDEN_SHAPE)
        context = torch.rand(CONTEXT_SHAPE)
        output = torch.rand(OUTPUT_SHAPE)
        res_decoder_path = os.path.join(
            os.path.split(self.model_path)[0],
            self.config.get('res_decoder_name'))
        torch.onnx.export(
            self.decoder,
            (hidden, context, output, row_enc_out,
             torch.tensor(tgt, dtype=torch.long)),
            res_decoder_path,
            opset_version=OPSET_VERSION,
            input_names=decoder_inputs,
            output_names=decoder_outputs,
            dynamic_axes={
                decoder_inputs[3]: {  # row_enc_out name should be here
                    0: 'batch',
                    1: 'H',
                    2: 'W'
                }
            })

    def export_complete_model_ir(self):
        input_model = os.path.join(
            os.path.split(self.model_path)[0],
            self.config.get('res_model_name'))
        input_shape = self.config.get('input_shape')
        output_names = self.config.get('model_output_names')
        output_dir = os.path.split(self.model_path)[0]
        export_command = f"""{OPENVINO_DIR}/bin/setupvars.sh && \
        python {OPENVINO_DIR}/deployment_tools/model_optimizer/mo.py \
        --framework onnx \
        --input_model {input_model} \
        --input_shape "{input_shape}" \
        --output "{output_names}" \
        --log_level={LOG_LEVEL} \
        --output_dir {output_dir} \
        --scale_values 'imgs[255]'"""
        if self.config.get('verbose_export'):
            print(export_command)
        subprocess.run(export_command, shell=True, check=True)

    def export_encoder_ir(self):
        input_model = os.path.join(
            os.path.split(self.model_path)[0],
            self.config.get('res_encoder_name'))
        input_shape = self.config.get('input_shape_decoder')
        output_names = self.config.get('encoder_output_names', ENCODER_OUTPUTS)
        output_dir = os.path.split(self.model_path)[0]
        export_command = f"""{OPENVINO_DIR}/bin/setupvars.sh && \
        python {OPENVINO_DIR}/deployment_tools/model_optimizer/mo.py \
        --framework onnx \
        --input_model {input_model} \
        --input_shape "{input_shape}" \
        --output "{output_names}" \
        --reverse_input_channels \
        --log_level={LOG_LEVEL} \
        --output_dir {output_dir} \
        --scale_values 'imgs[255,255,255]'"""
        if self.config.get('verbose_export'):
            print(export_command)
        subprocess.run(export_command, shell=True, check=True)

    def export_decoder_ir(self):
        input_shape_decoder = self.config.get('input_shape_decoder')
        output_h, output_w = input_shape_decoder[2] / 32, input_shape_decoder[3] / 32
        if self.config['backbone_config']['disable_layer_4']:
            output_h, output_w = output_h * 2, output_w * 2
        if self.config['backbone_config']['disable_layer_3']:
            output_h, output_w = output_h * 2, output_w * 2
        output_h, output_w = math.ceil(output_h), math.ceil(output_w)
        input_shape = [
            [1, self.config.get('head', {}).get('decoder_hidden_size', 512)],
            [1, self.config.get('head', {}).get('decoder_hidden_size', 512)],
            [1, self.config.get('head', {}).get('encoder_hidden_size', 256)],
            [
                1, output_h, output_w,
                self.config.get('head', {}).get('decoder_hidden_size', 512)
            ], [1, 1]
        ]
        input_shape = '{}, {}, {}, {}, {}'.format(*input_shape)
        input_model = os.path.join(
            os.path.split(self.model_path)[0],
            self.config.get('res_decoder_name'))
        input_names = self.config.get('decoder_input_names', DECODER_INPUTS)
        output_names = self.config.get('decoder_output_names', DECODER_OUTPUTS)
        output_dir = os.path.split(self.model_path)[0]
        export_command = f"""{OPENVINO_DIR}/bin/setupvars.sh &&
        python {OPENVINO_DIR}/deployment_tools/model_optimizer/mo.py \
        --framework onnx \
        --input_model {input_model} \
        --input {input_names} \
        --input_shape '{input_shape}' \
        --log_level={LOG_LEVEL} \
        --output_dir {output_dir} \
        --output {output_names}"""
        if self.config.get('verbose_export'):
            print(export_command)
        subprocess.run(export_command, shell=True, check=True)

    def export_to_onnx_model_if_not_yet(self, model, model_type):
        """Wrapper for _export_model_if_not_yet. Exports model to ONNX

        Args:
            model (str): Path to the model file
            model_type (str): Encoder or decoder
        """
        self._export_model_if_not_yet(model, model_type, ir=False)

    def export_to_ir_model_if_not_yet(self, model, model_type):
        """Wrapper for _export_model_if_not_yet. Exports model to ONNX and to OpenVINO IR

        Args:
            model (str): Path to the model file
            model_type (str): Encoder or decoder
        """
        self._export_model_if_not_yet(model, model_type, ir=False)
        self._export_model_if_not_yet(model, model_type, ir=True)

    def _export_model_if_not_yet(self, model, model_type, ir=False):
        """Checks if given model file exists and if not runs model export

        Args:
            model (str): Path to the model file
            model_type (str): encoder or decoder
            ir (bool, optional): Export to OpenVINO IR. Defaults to False.
        """

        export_function_template = 'export_{}{}'
        if not self.use_ctc:
            assert model_type in ('encoder', 'decoder')

        result_model_exists = os.path.exists(model)
        if ir:
            model_xml = model.replace('.onnx', '.xml')
            model_bin = model.replace('.onnx', '.bin')
            result_model_exists = os.path.exists(model_xml) and os.path.exists(
                model_bin)
        if not result_model_exists:
            print(f'Model {model} does not exist, exporting it...')
            ir_suffix = '_ir' if ir else ''
            if not self.use_ctc:
                export_function_name = export_function_template.format(
                    model_type, ir_suffix)
            else:
                export_function_name = export_function_template.format(
                    'complete_model', ir_suffix)
            getattr(self, export_function_name)()
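
In export_decoder_ir above, the spatial size of the encoder feature map fed to the decoder is derived from input_shape_decoder: height and width are divided by the backbone stride of 32, doubled once for each disabled ResNet stage, and then rounded up. As a worked example, assuming input_shape_decoder = [1, 3, 96, 480] with both disable_layer_3 and disable_layer_4 set, the feature map would be ceil(96 / 32 * 4) x ceil(480 / 32 * 4) = 12 x 60.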