def __init__(self, device, model, optimizer, loss, val_loader,
             train_loader, flags, global_state):
    """Store training components and select a label converter.

    Only ``flags.Global`` is kept on the instance. Any loss type other
    than 'ctc' falls back to the attention converter.
    """
    self.model = model
    self.optimizer = optimizer
    self.loss_func = loss
    self.train_loader = train_loader
    self.eval_loader = val_loader
    self.to_use_device = device
    self.flags = flags.Global
    self.global_state = global_state
    loss_type = flags.Global.loss_type
    self.converter = (CTCLabelConverter(flags)
                      if loss_type == 'ctc'
                      else AttnLabelConverter(flags))
# --- Example #2 (score: 0) ---
    def __init__(self, flags):
        """Build the SATRN encoder/decoder stacks from configuration.

        Shared sub-modules (attention, feed-forward) are deep-copied into
        each layer so parameters are not tied across layers.
        """
        super(SATRN, self).__init__()
        # Single-channel (grayscale) input keeps one plane, otherwise RGB.
        self.inplanes = 3 if flags.Global.image_shape[0] != 1 else 1
        self.converter = AttnLabelConverter(flags)
        self.num_classes = self.converter.char_num
        tf_cfg = flags.Transformer
        self.d_model = tf_cfg.model_dims
        self.d_ff = tf_cfg.feedforward_dims
        self.num_encoder = tf_cfg.num_encoder
        self.num_decoder = tf_cfg.num_decoder
        self.h = tf_cfg.num_head
        self.dropout = tf_cfg.dropout_rate

        clone = copy.deepcopy
        self.attn = MultiHeadedAttention(self.h, self.d_model)
        self.laff = LocalityAwareFeedForward(self.d_model, self.d_ff,
                                             self.d_model)
        self.ff = PointwiseFeedForward(self.d_model, self.d_ff)
        self.position1d = PositionalEncoding(self.d_model, self.dropout)
        self.position2d = A2DPE(self.d_model, self.dropout)

        # Encoder: self-attention + locality-aware feed-forward per layer.
        self.encoder = Encoder(
            EncoderLayer(self.d_model, clone(self.attn), clone(self.laff),
                         self.dropout), self.num_encoder)
        # Decoder: self-attention, cross-attention and pointwise FF.
        self.decoder = Decoder(
            DecoderLayer(self.d_model, clone(self.attn), clone(self.attn),
                         clone(self.ff), self.dropout), self.num_decoder)
        # Source path gets 2-D positional encoding, target path 1-D.
        self.src_embed = nn.Sequential(ShallowCNN(self.inplanes, self.d_model),
                                       self.position2d)
        self.tgt_embed = nn.Sequential(
            Embeddings(self.num_classes, self.d_model), self.position1d)
        self.generator = Generator(self.d_model, self.num_classes)
# --- Example #3 (score: 0) ---
    def __init__(self, flags):
        """Assemble the DAN pipeline: backbone, attention maps, decoder."""
        super(DAN, self).__init__()
        g = flags.Global
        self.input_shape = g.image_shape
        self.inplanes = self.input_shape[0]
        # Per-stage (h, w) strides for the ResNet backbone.
        self.strides = [(2, 2), (1, 1), (2, 2), (1, 1), (1, 1)]
        self.compress_layer = flags.Architecture.compress_layer
        self.block = BasicBlock
        self.layers = flags.Architecture.layers
        self.maxT = g.batch_max_length
        self.depth = flags.CAM.depth
        self.num_channel = flags.CAM.num_channel
        self.converter = AttnLabelConverter(flags)
        self.num_class = self.converter.char_num
        self.num_steps = g.batch_max_length
        self.is_train = g.is_train

        self.feature_extractor = ResNet(self.inplanes, self.block,
                                        self.strides, self.layers,
                                        self.compress_layer)
        # A throwaway extractor is built only to query the per-scale
        # feature-map shapes needed by the CAM module.
        self.scales = Feature_Extractor(self.input_shape, self.block,
                                        self.strides, self.layers,
                                        self.compress_layer).Iwantshapes()
        self.cam_module = CAM(self.scales, self.maxT, self.depth,
                              self.num_channel)
        self.decoder = DTD(self.num_class, self.num_channel)
# --- Example #4 (score: 0) ---
 def __init__(self, device, model, optimizer, loss, val_loader,
              train_loader, flags, global_state):
     """Wire up the training state and select the label converter.

     Raises:
         Exception: if ``flags.Global.loss_type`` is neither 'ctc'
             nor 'attn'.
     """
     self.model = model
     self.optimizer = optimizer
     self.loss_func = loss
     self.train_loader = train_loader
     self.eval_loader = val_loader
     self.to_use_device = device
     self.flags = flags.Global
     self.global_state = global_state
     loss_type = flags.Global.loss_type
     if loss_type == 'ctc':
         self.converter = CTCLabelConverter(flags)
     elif loss_type == 'attn':
         self.converter = AttnLabelConverter(flags)
     else:
         raise Exception('Not implemented error!')
     # Log the effective Global config for this run.
     logging.info(self.flags)
# --- Example #5 (score: 0) ---
    def __init__(self):
        """Load config, build the pretrained model and preprocessing params."""
        self.config = build_config()
        model = build_model(self.config)
        device, gpu_count = build_device(self.config)
        optimizer = build_optimizer(self.config, model)
        model, optimizer, global_state = build_pretrained_weights(
            self.config, model, optimizer)
        self.device = device
        # CTC and attention decoders encode/decode labels differently.
        self.converter = (CTCLabelConverter(self.config)
                          if self.config.Global.loss_type == 'ctc'
                          else AttnLabelConverter(self.config))
        self.model = model.to(self.device)

        cfg = self.config
        self.keep_ratio_with_pad = cfg.TrainReader.padding
        # image_shape is laid out as (channels, height, width).
        self.channel, self.imgH, self.imgW = cfg.Global.image_shape[:3]
        # +1 step for the end-of-sequence token.
        self.num_steps = cfg.Global.batch_max_length + 1
# --- Example #6 (score: 0) ---
    def __init__(self):
        """Build config, model (multi-GPU aware), converter and image params."""
        config = build_config()
        model = build_model(config)
        device, gpu_count = build_device(config)
        optimizer = build_optimizer(config, model)
        # Wrap for data-parallel execution when several GPUs are present.
        if gpu_count > 1:
            model = nn.DataParallel(model)
        model, optimizer, global_state = build_pretrained_weights(
            config, model, optimizer)
        self.device = device
        self.converter = (CTCLabelConverter(config)
                          if config.Global.loss_type == 'ctc'
                          else AttnLabelConverter(config))
        self.model = model.to(self.device)

        self.keep_ratio_with_pad = config.TrainReader.padding
        # image_shape is laid out as (channels, height, width).
        self.channel, self.imgH, self.imgW = config.Global.image_shape[:3]
# --- Example #7 (score: 0) ---
    def __init__(self, flags):
        """FAN: ResNet feature extractor with an attention sequence head."""
        super(FAN, self).__init__()
        channels = flags.Global.image_shape[0]
        # Grayscale keeps one input plane, anything else is treated as RGB.
        self.inplanes = 1 if channels == 1 else 3
        rnn_cfg = flags.SeqRNN
        self.num_inputs = rnn_cfg.input_size
        self.num_hiddens = rnn_cfg.hidden_size
        self.converter = AttnLabelConverter(flags)
        self.num_classes = self.converter.char_num

        self.block = BasicBlock
        self.layers = flags.Architecture.layers
        self.feature_extractor = ResNet(self.inplanes, self.num_inputs,
                                        self.block, self.layers)
        self.reshape_layer = ReshapeLayer()
        self.sequence_layer = Attention(self.num_inputs, self.num_hiddens,
                                        self.num_classes)
# --- Example #8 (score: 0) ---
    def __init__(self, flags):
        """SAR: ResNet backbone followed by LSTM encoder/decoder heads."""
        super(SAR, self).__init__()
        channels = flags.Global.image_shape[0]
        self.inplanes = 1 if channels == 1 else 3
        rnn_cfg = flags.SeqRNN
        self.input_size = rnn_cfg.input_size
        self.en_hidden_size = rnn_cfg.en_hidden_size
        self.de_hidden_size = rnn_cfg.de_hidden_size
        self.converter = AttnLabelConverter(flags)
        self.num_classes = self.converter.char_num

        self.block = BasicBlock
        self.layers = flags.Architecture.layers
        self.feature_extractor = ResNet(self.inplanes, self.input_size,
                                        self.block, self.layers)
        self.lstm_encoder = LSTMEncoder(self.input_size, self.en_hidden_size)
        self.lstm_decoder = LSTMDecoder(self.input_size, self.en_hidden_size,
                                        self.de_hidden_size, self.num_classes)
# --- Example #9 (score: 0) ---
class Recoginizer(object):
    """Single-image text recognizer.

    Loads a pretrained recognition model once, then converts PIL images to
    normalized tensors and decodes the model output to strings.
    (Class name kept as-is — with its typo — for backward compatibility.)
    """

    def __init__(self):
        config = build_config()
        model = build_model(config)
        device, gpu_count = build_device(config)
        optimizer = build_optimizer(config, model)
        # Wrap for data-parallel execution when several GPUs are present.
        if gpu_count > 1:
            model = nn.DataParallel(model)
        model, optimizer, global_state = build_pretrained_weights(
            config, model, optimizer)
        self.device = device
        if config.Global.loss_type == 'ctc':
            self.converter = CTCLabelConverter(config)
        else:
            self.converter = AttnLabelConverter(config)
        self.model = model.to(self.device)

        # Build the ToTensor transform once here; the original recreated it
        # on every preprocess() call.
        self.transform = transforms.ToTensor()
        self.keep_ratio_with_pad = config.TrainReader.padding
        # image_shape is laid out as (channels, height, width).
        self.channel = config.Global.image_shape[0]
        self.imgH = config.Global.image_shape[1]
        self.imgW = config.Global.image_shape[2]

    def _to_normalized_tensor(self, image):
        """Convert an image to a tensor normalized from [0, 1] to [-1, 1]."""
        return self.transform(image).sub(0.5).div(0.5)

    def preprocess(self, image):
        """Resize and normalize a PIL image to shape (1, C, imgH, imgW).

        With ``keep_ratio_with_pad`` the aspect ratio is preserved and the
        right side is padded by replicating the last column; otherwise the
        image is simply stretched to (imgW, imgH).
        """
        if self.keep_ratio_with_pad:
            w, h = image.size
            ratio = w / float(h)
            resized_W = math.ceil(ratio * self.imgH)
            if resized_W > self.imgW:
                # Too wide even after height-scaling: fall back to a stretch.
                imgP = self._to_normalized_tensor(
                    image.resize((self.imgW, self.imgH), Image.BICUBIC))
            else:
                resized_image = self._to_normalized_tensor(
                    image.resize((resized_W, self.imgH), Image.BICUBIC))
                c, h, w = resized_image.size()
                # torch.zeros replaces FloatTensor(...).fill_(0); same
                # float32 CPU result, clearer intent.
                imgP = torch.zeros(self.channel, self.imgH, self.imgW)
                imgP[:, :, :w] = resized_image
                # Pad the remainder by repeating the right-most column.
                imgP[:, :, w:] = resized_image[:, :, w - 1].unsqueeze(2).expand(
                    c, h, self.imgW - w)
        else:
            imgP = self._to_normalized_tensor(
                image.resize((self.imgW, self.imgH), Image.BICUBIC))

        # Add the batch dimension expected by the model.
        return imgP.unsqueeze(0)

    def predict(self, image_tensor):
        """Run the model on a (1, C, H, W) tensor and decode to a string."""
        self.model.eval()
        with torch.no_grad():
            image_tensor = image_tensor.to(self.device)
            outputs = self.model(image_tensor)
            # Per-timestep class probabilities, moved to CPU for decoding.
            outputs = outputs.softmax(dim=2).detach().cpu().numpy()
            preds_str = self.converter.decode(outputs)
        return preds_str

    def __call__(self, image):
        """Convenience wrapper: preprocess then predict."""
        return self.predict(self.preprocess(image))
# --- Example #10 (score: 0) ---
class TrainerRec(object):
    """Training driver for text-recognition models.

    Owns the train/eval loop, checkpointing and metric logging. The label
    converter is chosen from ``flags.Global.loss_type``.
    """

    def __init__(self, device, model, optimizer, loss, val_loader, \
                  train_loader, flags, global_state):
        """Store training components and pick the label converter.

        Args:
            device: device that the model and batches are moved to.
            model: the recognition network.
            optimizer: optimizer stepping ``model``'s parameters.
            loss: loss callable taking (predictions, batch_data).
            val_loader: iterable of (img, label) evaluation batches.
            train_loader: loader exposing ``get_batch()`` for training.
            flags: full config; only ``flags.Global`` is kept on self.
            global_state: dict carrying ``global_step``/``best_model``
                across restarts (may be empty).
        """
        self.model = model
        self.optimizer = optimizer
        self.loss_func = loss
        self.train_loader = train_loader
        self.eval_loader = val_loader
        self.to_use_device = device
        self.flags = flags.Global
        self.global_state = global_state
        # Any non-CTC loss type silently falls back to the attention
        # converter (a sibling variant in this file raises instead).
        if flags.Global.loss_type == 'ctc':
            self.converter = CTCLabelConverter(flags)
        else:
            self.converter = AttnLabelConverter(flags)

    def train(self):
        """Train until ``num_iters`` or interruption.

        Evaluates every ``eval_batch_step`` iterations, checkpoints best
        and/or per-iteration models, and always logs the best metrics on
        exit (via ``finally``).
        """
        self.metric = RecMetric(self.converter)
        self.model = self.model.to(self.to_use_device)
        logging.info(self.to_use_device)
        logging.info('Training...')
        all_step = self.flags.num_iters
        # Resume bookkeeping from a previous run when global_state is
        # non-empty; otherwise start fresh.
        if len(self.global_state) > 0:
            best_model = self.global_state['best_model']
            global_step = self.global_state['global_step']
        else:
            best_model = {
                'best_acc': 0,
                'eval_loss': 0,
                'eval_acc': 0,
                'norm_edit_dis': 0
            }
            global_step = 0
        try:
            while True:
                self.model.train()
                start_time = time.time()
                batch_data = self.train_loader.get_batch()
                cur_batch_size = batch_data['img'].shape[0]
                # Encode string labels into target indices and lengths.
                targets, targets_lengths = self.converter.encode(
                    batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                batch_data['img'] = batch_data['img'].to(self.to_use_device)
                batch_data['targets'] = batch_data['targets'].to(
                    self.to_use_device)

                self.optimizer.zero_grad()
                if self.flags.loss_type == 'ctc':
                    predicts = self.model.forward(batch_data['img'])
                else:
                    # Attention decoding: feed the targets without their
                    # final token (teacher-forcing input).
                    predicts = self.model.forward(
                        batch_data['img'], batch_data['targets'][:, :-1])
                loss = self.loss_func(predicts, batch_data)
                loss.backward()
                # Clip gradient norm at 5 to keep training stable.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
                self.optimizer.step()

                acc_dict = self.metric(predicts, batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                # Metric returns a summed distance; convert to a batch-mean
                # similarity in [0, 1].
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (global_step + 1) % self.flags.print_batch_step == 0:
                    interval_batch_time = time.time() - start_time
                    logging.info(
                        f"[{global_step + 1} / {all_step}] - "
                        f"loss:{loss:.4f} - "
                        f"acc:{acc:.4f} - "
                        f"norm_edit_dis:{norm_edit_dis:.4f} - "
                        f"interval_batch_time:{interval_batch_time:.4f} - ")
                # Periodic evaluation + checkpointing.
                if (global_step + 1) >= self.flags.eval_batch_step and (
                        global_step + 1) % self.flags.eval_batch_step == 0:
                    self.global_state['global_step'] = global_step
                    eval_dict = self.evaluate()
                    # NOTE(review): eval_dict has no 'best_acc' key, so
                    # best_model['best_acc'] stays at its initial value and
                    # this branch triggers on every positive eval_acc —
                    # confirm whether 'eval_acc' was intended here.
                    if eval_dict['eval_acc'] > best_model['best_acc']:
                        best_model.update(eval_dict)

                        self.global_state['best_model'] = best_model
                        model_save_path = f"{self.flags.save_model_dir}/best_acc.pth"
                        save_checkpoint(model_save_path,
                                        self.model,
                                        self.optimizer,
                                        global_state=self.global_state)
                    # When not restricted to best-only saving, also keep a
                    # per-iteration snapshot.
                    if not self.flags.highest_acc_save_type:
                        model_save_path = f"{self.flags.save_model_dir}/iter_{global_step + 1}.pth"
                        save_checkpoint(model_save_path,
                                        self.model,
                                        self.optimizer,
                                        global_state=self.global_state)

                if global_step == self.flags.num_iters:
                    print('end the training')
                    # NOTE(review): this StopIteration is caught by the bare
                    # `except` below, so normal completion is logged as an
                    # error traceback — confirm intended.
                    raise StopIteration
                global_step += 1
        except KeyboardInterrupt:
            # Ctrl-C: save a final checkpoint before exiting.
            save_checkpoint(os.path.join(self.flags.save_model_dir,
                                         'final.pth'),
                            self.model,
                            self.optimizer,
                            global_state=self.global_state)
        except:
            # NOTE(review): bare except swallows every other error and only
            # logs the traceback; training ends without re-raising.
            error_msg = traceback.format_exc()
            logging.error(error_msg)
        finally:
            # Always report the best metrics observed so far.
            for k, v in best_model.items():
                logging.info(f'{k}: {v}')

    def evaluate(self):
        """Run one pass over the eval loader.

        Returns a dict with mean eval_loss (per batch), eval_acc (per
        sample) and norm_edit_dis similarity; also logs them. Restores the
        model to train mode before returning. Assumes ``self.metric`` was
        set by ``train()``.
        """
        logging.info('start evaluate')
        self.model.eval()
        nums = 0
        result_dict = {'eval_loss': 0., 'eval_acc': 0., 'norm_edit_dis': 0.}
        show_str = []
        with torch.no_grad():
            for (img, label) in tqdm(self.eval_loader):
                batch_data = {}
                batch_data['img'], batch_data['label'] = img, label
                targets, targets_lengths = self.converter.encode(
                    batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                batch_data['img'] = batch_data['img'].to(self.to_use_device)
                batch_data['targets'] = batch_data['targets'].to(
                    self.to_use_device)
                if self.flags.loss_type == 'ctc':
                    output = self.model.forward(batch_data['img'])
                else:
                    output = self.model.forward(batch_data['img'],
                                                batch_data['targets'][:, :-1])
                loss = self.loss_func(output, batch_data)

                nums += batch_data['img'].shape[0]
                acc_dict = self.metric(output, batch_data['label'])
                result_dict['eval_loss'] += loss.item()
                result_dict['eval_acc'] += acc_dict['n_correct']
                result_dict['norm_edit_dis'] += acc_dict['norm_edit_dis']
                show_str.extend(acc_dict['show_str'])

        # Averages: loss per batch, accuracy per sample; edit distance is
        # converted to a similarity in [0, 1].
        result_dict['eval_loss'] /= len(self.eval_loader)
        result_dict['eval_acc'] /= nums
        result_dict['norm_edit_dis'] = 1 - result_dict['norm_edit_dis'] / nums
        logging.info(f"eval_loss:{result_dict['eval_loss']}")
        logging.info(f"eval_acc:{result_dict['eval_acc']}")
        logging.info(f"norm_edit_dis:{result_dict['norm_edit_dis']}")

        # Show a few qualitative prediction samples.
        for s in show_str[:10]:
            logging.info(s)
        self.model.train()
        return result_dict