def __init__(self, device, model, optimizer, loss, val_loader,
             train_loader, flags, global_state):
    """Bind the training components and build the label converter.

    Args:
        device: torch device the model and batches will be moved to.
        model: recognition network to train.
        optimizer: optimizer updating ``model``'s parameters.
        loss: loss callable; must match ``flags.Global.loss_type``.
        val_loader: data loader used for evaluation.
        train_loader: data loader used for training.
        flags: full config object; only ``flags.Global`` is kept on self.
        global_state: dict carrying resumable state (best model, step).

    Raises:
        NotImplementedError: if ``flags.Global.loss_type`` is neither
            'ctc' nor 'attn'.
    """
    self.model = model
    self.optimizer = optimizer
    self.loss_func = loss
    self.train_loader = train_loader
    self.eval_loader = val_loader
    self.to_use_device = device
    self.flags = flags.Global
    self.global_state = global_state
    # Fix: the original silently fell back to the attention converter for
    # ANY non-'ctc' loss_type; validate explicitly, matching the sibling
    # trainer init elsewhere in this file.
    if flags.Global.loss_type == 'ctc':
        self.converter = CTCLabelConverter(flags)
    elif flags.Global.loss_type == 'attn':
        self.converter = AttnLabelConverter(flags)
    else:
        raise NotImplementedError(
            f"loss_type '{flags.Global.loss_type}' is not supported")
def __init__(self, flags):
    """Assemble the SATRN graph (shallow CNN + transformer encoder/decoder)
    from the dimensions declared in ``flags.Transformer``."""
    super(SATRN, self).__init__()
    # A single-channel image_shape means grayscale input; otherwise RGB.
    self.inplanes = 1 if flags.Global.image_shape[0] == 1 else 3
    self.converter = AttnLabelConverter(flags)
    self.num_classes = self.converter.char_num
    self.d_model = flags.Transformer.model_dims
    self.d_ff = flags.Transformer.feedforward_dims
    self.num_encoder = flags.Transformer.num_encoder
    self.num_decoder = flags.Transformer.num_decoder
    self.h = flags.Transformer.num_head
    self.dropout = flags.Transformer.dropout_rate
    clone = copy.deepcopy
    # Prototype sublayers; each encoder/decoder stack receives its own
    # deep copy so no parameters are shared across layers.
    self.attn = MultiHeadedAttention(self.h, self.d_model)
    self.laff = LocalityAwareFeedForward(self.d_model, self.d_ff,
                                         self.d_model)
    self.ff = PointwiseFeedForward(self.d_model, self.d_ff)
    self.position1d = PositionalEncoding(self.d_model, self.dropout)
    self.position2d = A2DPE(self.d_model, self.dropout)
    self.encoder = Encoder(
        EncoderLayer(self.d_model, clone(self.attn), clone(self.laff),
                     self.dropout),
        self.num_encoder)
    self.decoder = Decoder(
        DecoderLayer(self.d_model, clone(self.attn), clone(self.attn),
                     clone(self.ff), self.dropout),
        self.num_decoder)
    # 2-D positions for the image features, 1-D for the target tokens.
    self.src_embed = nn.Sequential(
        ShallowCNN(self.inplanes, self.d_model), self.position2d)
    self.tgt_embed = nn.Sequential(
        Embeddings(self.num_classes, self.d_model), self.position1d)
    self.generator = Generator(self.d_model, self.num_classes)
def __init__(self, flags):
    """Build the DAN pipeline: ResNet feature extractor, CAM attention
    maps, and the DTD decoder."""
    super(DAN, self).__init__()
    self.input_shape = flags.Global.image_shape
    self.inplanes = self.input_shape[0]
    # Fixed per-stage strides for the backbone.
    self.strides = [(2, 2), (1, 1), (2, 2), (1, 1), (1, 1)]
    self.compress_layer = flags.Architecture.compress_layer
    self.block = BasicBlock
    self.layers = flags.Architecture.layers
    self.maxT = flags.Global.batch_max_length
    self.depth = flags.CAM.depth
    self.num_channel = flags.CAM.num_channel
    self.converter = AttnLabelConverter(flags)
    self.num_class = self.converter.char_num
    self.num_steps = flags.Global.batch_max_length
    self.is_train = flags.Global.is_train
    self.feature_extractor = ResNet(self.inplanes, self.block, self.strides,
                                    self.layers, self.compress_layer)
    # Probe the backbone once for its per-scale output shapes, which CAM
    # needs to size its attention maps.
    self.scales = Feature_Extractor(self.input_shape, self.block,
                                    self.strides, self.layers,
                                    self.compress_layer).Iwantshapes()
    self.cam_module = CAM(self.scales, self.maxT, self.depth,
                          self.num_channel)
    self.decoder = DTD(self.num_class, self.num_channel)
def __init__(self, device, model, optimizer, loss, val_loader,
             train_loader, flags, global_state):
    """Bind the training components and build the label converter.

    Args:
        device: torch device the model and batches will be moved to.
        model: recognition network to train.
        optimizer: optimizer updating ``model``'s parameters.
        loss: loss callable; must match ``flags.Global.loss_type``.
        val_loader: data loader used for evaluation.
        train_loader: data loader used for training.
        flags: full config object; only ``flags.Global`` is kept on self.
        global_state: dict carrying resumable state (best model, step).

    Raises:
        NotImplementedError: if ``flags.Global.loss_type`` is neither
            'ctc' nor 'attn'.
    """
    self.model = model
    self.optimizer = optimizer
    self.loss_func = loss
    self.train_loader = train_loader
    self.eval_loader = val_loader
    self.to_use_device = device
    self.flags = flags.Global
    self.global_state = global_state
    if flags.Global.loss_type == 'ctc':
        self.converter = CTCLabelConverter(flags)
    elif flags.Global.loss_type == 'attn':
        self.converter = AttnLabelConverter(flags)
    else:
        # Fix: raise the specific builtin instead of a bare Exception with
        # a vague message; still caught by callers catching Exception.
        raise NotImplementedError(
            f"loss_type '{flags.Global.loss_type}' is not supported")
    logging.info(self.flags)
def __init__(self):
    """Load the config, build and restore the model, and cache the
    preprocessing parameters used at inference time."""
    self.config = build_config()
    net = build_model(self.config)
    dev, _n_gpu = build_device(self.config)
    optim = build_optimizer(self.config, net)
    # Restore pretrained weights (and optimizer state) when available.
    net, optim, _state = build_pretrained_weights(self.config, net, optim)
    self.device = dev
    if self.config.Global.loss_type == 'ctc':
        self.converter = CTCLabelConverter(self.config)
    else:
        self.converter = AttnLabelConverter(self.config)
    self.model = net.to(self.device)
    # Preprocessing parameters: padding mode and target image geometry.
    self.keep_ratio_with_pad = self.config.TrainReader.padding
    self.channel = self.config.Global.image_shape[0]
    self.imgH = self.config.Global.image_shape[1]
    self.imgW = self.config.Global.image_shape[2]
    # +1 accounts for the end-of-sequence token.
    self.num_steps = self.config.Global.batch_max_length + 1
def __init__(self):
    """Build the model (multi-GPU aware), restore pretrained weights, and
    cache the preprocessing parameters."""
    cfg = build_config()
    net = build_model(cfg)
    dev, n_gpu = build_device(cfg)
    optim = build_optimizer(cfg, net)
    # Wrap for data parallelism only when more than one GPU is present.
    if n_gpu > 1:
        net = nn.DataParallel(net)
    net, optim, _state = build_pretrained_weights(cfg, net, optim)
    self.device = dev
    if cfg.Global.loss_type == 'ctc':
        self.converter = CTCLabelConverter(cfg)
    else:
        self.converter = AttnLabelConverter(cfg)
    self.model = net.to(self.device)
    # Preprocessing parameters: padding mode and target image geometry.
    self.keep_ratio_with_pad = cfg.TrainReader.padding
    self.channel = cfg.Global.image_shape[0]
    self.imgH = cfg.Global.image_shape[1]
    self.imgW = cfg.Global.image_shape[2]
def __init__(self, flags):
    """Build FAN: ResNet features reshaped into a sequence, followed by an
    attention-based sequence layer."""
    super(FAN, self).__init__()
    # Single-channel image_shape means grayscale input; otherwise RGB.
    self.inplanes = 1 if flags.Global.image_shape[0] == 1 else 3
    self.num_inputs = flags.SeqRNN.input_size
    self.num_hiddens = flags.SeqRNN.hidden_size
    self.converter = AttnLabelConverter(flags)
    self.num_classes = self.converter.char_num
    self.block = BasicBlock
    self.layers = flags.Architecture.layers
    self.feature_extractor = ResNet(self.inplanes, self.num_inputs,
                                    self.block, self.layers)
    self.reshape_layer = ReshapeLayer()
    self.sequence_layer = Attention(self.num_inputs, self.num_hiddens,
                                    self.num_classes)
def __init__(self, flags):
    """Build SAR: ResNet backbone plus LSTM encoder/decoder heads."""
    super(SAR, self).__init__()
    # Single-channel image_shape means grayscale input; otherwise RGB.
    self.inplanes = 1 if flags.Global.image_shape[0] == 1 else 3
    self.input_size = flags.SeqRNN.input_size
    self.en_hidden_size = flags.SeqRNN.en_hidden_size
    self.de_hidden_size = flags.SeqRNN.de_hidden_size
    self.converter = AttnLabelConverter(flags)
    self.num_classes = self.converter.char_num
    self.block = BasicBlock
    self.layers = flags.Architecture.layers
    self.feature_extractor = ResNet(self.inplanes, self.input_size,
                                    self.block, self.layers)
    self.lstm_encoder = LSTMEncoder(self.input_size, self.en_hidden_size)
    self.lstm_decoder = LSTMDecoder(self.input_size, self.en_hidden_size,
                                    self.de_hidden_size, self.num_classes)
class Recoginizer(object):
    """Single-image text recognizer.

    Builds the model once at construction; calling the instance maps a PIL
    image to a decoded string via ``preprocess`` then ``predict``.
    """

    def __init__(self):
        cfg = build_config()
        net = build_model(cfg)
        dev, n_gpu = build_device(cfg)
        optim = build_optimizer(cfg, net)
        # Wrap for data parallelism only when more than one GPU exists.
        if n_gpu > 1:
            net = nn.DataParallel(net)
        net, optim, _state = build_pretrained_weights(cfg, net, optim)
        self.device = dev
        if cfg.Global.loss_type == 'ctc':
            self.converter = CTCLabelConverter(cfg)
        else:
            self.converter = AttnLabelConverter(cfg)
        self.model = net.to(self.device)
        # Preprocessing parameters: padding mode and target geometry.
        self.keep_ratio_with_pad = cfg.TrainReader.padding
        self.channel = cfg.Global.image_shape[0]
        self.imgH = cfg.Global.image_shape[1]
        self.imgW = cfg.Global.image_shape[2]

    def preprocess(self, image):
        """Resize/normalize a PIL image into a (1, C, H, W) tensor in [-1, 1]."""
        self.transform = transforms.ToTensor()
        if not self.keep_ratio_with_pad:
            # Simple mode: stretch to the target size, ignoring aspect ratio.
            stretched = image.resize((self.imgW, self.imgH), Image.BICUBIC)
            normed = self.transform(stretched).sub(0.5).div(0.5)
            return normed.unsqueeze(0)
        w, h = image.size
        aspect = w / float(h)
        target_w = math.ceil(aspect * self.imgH)
        if target_w > self.imgW:
            # Too wide to pad: fall back to stretching to the full width.
            squeezed = image.resize((self.imgW, self.imgH), Image.BICUBIC)
            padded = self.transform(squeezed).sub(0.5).div(0.5)
        else:
            # Keep the aspect ratio, then right-pad by replicating the last
            # valid pixel column out to the target width.
            scaled = image.resize((target_w, self.imgH), Image.BICUBIC)
            scaled = self.transform(scaled).sub(0.5).div(0.5)
            c, hh, ww = scaled.size()
            padded = torch.FloatTensor(self.channel, self.imgH,
                                       self.imgW).fill_(0)
            padded[:, :, :ww] = scaled
            padded[:, :, ww:] = scaled[:, :, ww - 1].unsqueeze(2).expand(
                c, hh, self.imgW - ww)
        return padded.unsqueeze(0)

    def predict(self, image_tensor):
        """Run inference on a preprocessed tensor and decode to strings."""
        self.model.eval()
        with torch.no_grad():
            batch = image_tensor.to(self.device)
            probs = self.model(batch).softmax(dim=2).detach().cpu().numpy()
            decoded = self.converter.decode(probs)
        return decoded

    def __call__(self, image):
        return self.predict(self.preprocess(image))
class TrainerRec(object):
    """Iteration-based training driver for a text-recognition model.

    Holds the model, optimizer, loss and data loaders, and exposes
    ``train()`` (main loop with logging, evaluation and checkpointing) and
    ``evaluate()`` (one pass over the validation loader).
    """

    def __init__(self, device, model, optimizer, loss, val_loader,
                 train_loader, flags, global_state):
        """Bind training components and build the label converter.

        Args:
            device: torch device that model and batches are moved to.
            model: recognition network to train.
            optimizer: optimizer updating ``model``'s parameters.
            loss: loss callable; must match ``flags.Global.loss_type``.
            val_loader: iterable of (img, label) batches for evaluation.
            train_loader: loader exposing ``get_batch()`` for training.
            flags: full config; only ``flags.Global`` is kept on self.
            global_state: resumable state dict ('best_model',
                'global_step'); an empty dict starts from scratch.

        Raises:
            NotImplementedError: for a loss_type other than 'ctc'/'attn'.
        """
        self.model = model
        self.optimizer = optimizer
        self.loss_func = loss
        self.train_loader = train_loader
        self.eval_loader = val_loader
        self.to_use_device = device
        self.flags = flags.Global
        self.global_state = global_state
        # Fix: fail loudly on an unknown loss_type instead of silently
        # falling back to the attention converter (matches the validating
        # trainer init elsewhere in this file).
        if flags.Global.loss_type == 'ctc':
            self.converter = CTCLabelConverter(flags)
        elif flags.Global.loss_type == 'attn':
            self.converter = AttnLabelConverter(flags)
        else:
            raise NotImplementedError(
                f"loss_type '{flags.Global.loss_type}' is not supported")

    def train(self):
        """Run the training loop until ``num_iters`` is reached or Ctrl-C."""
        self.metric = RecMetric(self.converter)
        self.model = self.model.to(self.to_use_device)
        logging.info(self.to_use_device)
        logging.info('Training...')
        all_step = self.flags.num_iters
        # Resume from checkpointed state when one was loaded.
        if len(self.global_state) > 0:
            best_model = self.global_state['best_model']
            global_step = self.global_state['global_step']
        else:
            best_model = {
                'best_acc': 0,
                'eval_loss': 0,
                'eval_acc': 0,
                'norm_edit_dis': 0
            }
            global_step = 0
        try:
            while True:
                self.model.train()
                start_time = time.time()
                batch_data = self.train_loader.get_batch()
                cur_batch_size = batch_data['img'].shape[0]
                targets, targets_lengths = self.converter.encode(
                    batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                batch_data['img'] = batch_data['img'].to(self.to_use_device)
                batch_data['targets'] = batch_data['targets'].to(
                    self.to_use_device)
                self.optimizer.zero_grad()
                # CTC heads take only the image; attention-style heads are
                # teacher-forced with the targets minus their last token.
                if self.flags.loss_type == 'ctc':
                    predicts = self.model.forward(batch_data['img'])
                else:
                    predicts = self.model.forward(
                        batch_data['img'], batch_data['targets'][:, :-1])
                loss = self.loss_func(predicts, batch_data)
                loss.backward()
                # Clip gradient norm at 5 to stabilize training.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
                self.optimizer.step()
                acc_dict = self.metric(predicts, batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (global_step + 1) % self.flags.print_batch_step == 0:
                    interval_batch_time = time.time() - start_time
                    logging.info(
                        f"[{global_step + 1} / {all_step}] - "
                        f"loss:{loss:.4f} - "
                        f"acc:{acc:.4f} - "
                        f"norm_edit_dis:{norm_edit_dis:.4f} - "
                        f"interval_batch_time:{interval_batch_time:.4f} - ")
                if (global_step + 1) >= self.flags.eval_batch_step and (
                        global_step + 1) % self.flags.eval_batch_step == 0:
                    self.global_state['global_step'] = global_step
                    eval_dict = self.evaluate()
                    # Checkpoint whenever the eval accuracy improves.
                    if eval_dict['eval_acc'] > best_model['best_acc']:
                        best_model.update(eval_dict)
                        self.global_state['best_model'] = best_model
                        model_save_path = f"{self.flags.save_model_dir}/best_acc.pth"
                        save_checkpoint(model_save_path,
                                        self.model,
                                        self.optimizer,
                                        global_state=self.global_state)
                    # Optionally also keep a per-iteration checkpoint.
                    if not self.flags.highest_acc_save_type:
                        model_save_path = f"{self.flags.save_model_dir}/iter_{global_step + 1}.pth"
                        save_checkpoint(model_save_path,
                                        self.model,
                                        self.optimizer,
                                        global_state=self.global_state)
                # NOTE(review): global_step starts at 0, so this runs
                # num_iters + 1 iterations before stopping — confirm intended.
                if global_step == self.flags.num_iters:
                    print('end the training')
                    raise StopIteration
                global_step += 1
        except KeyboardInterrupt:
            # Manual interrupt: persist a final checkpoint before exiting.
            save_checkpoint(os.path.join(self.flags.save_model_dir,
                                         'final.pth'),
                            self.model,
                            self.optimizer,
                            global_state=self.global_state)
        except StopIteration:
            # Fix: normal completion previously fell into a bare `except:`
            # and was logged as an error traceback; treat it as a clean stop.
            pass
        except Exception:
            # Fix: narrowed from a bare `except:` so SystemExit and other
            # non-Exception signals are no longer swallowed; unexpected
            # errors are still logged rather than crashing the run.
            error_msg = traceback.format_exc()
            logging.error(error_msg)
        finally:
            for k, v in best_model.items():
                logging.info(f'{k}: {v}')

    def evaluate(self):
        """Evaluate on the full eval loader.

        Returns:
            dict with 'eval_loss' (mean per batch), 'eval_acc' (per sample)
            and 'norm_edit_dis' (1 - mean normalized edit distance).
        """
        logging.info('start evaluate')
        self.model.eval()
        nums = 0
        result_dict = {'eval_loss': 0., 'eval_acc': 0., 'norm_edit_dis': 0.}
        show_str = []
        with torch.no_grad():
            for (img, label) in tqdm(self.eval_loader):
                batch_data = {}
                batch_data['img'], batch_data['label'] = img, label
                targets, targets_lengths = self.converter.encode(
                    batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                batch_data['img'] = batch_data['img'].to(self.to_use_device)
                batch_data['targets'] = batch_data['targets'].to(
                    self.to_use_device)
                if self.flags.loss_type == 'ctc':
                    output = self.model.forward(batch_data['img'])
                else:
                    output = self.model.forward(batch_data['img'],
                                                batch_data['targets'][:, :-1])
                loss = self.loss_func(output, batch_data)
                nums += batch_data['img'].shape[0]
                acc_dict = self.metric(output, batch_data['label'])
                result_dict['eval_loss'] += loss.item()
                result_dict['eval_acc'] += acc_dict['n_correct']
                result_dict['norm_edit_dis'] += acc_dict['norm_edit_dis']
                show_str.extend(acc_dict['show_str'])
        # NOTE(review): assumes a non-empty eval loader; an empty one
        # divides by zero here.
        result_dict['eval_loss'] /= len(self.eval_loader)
        result_dict['eval_acc'] /= nums
        result_dict['norm_edit_dis'] = 1 - result_dict['norm_edit_dis'] / nums
        logging.info(f"eval_loss:{result_dict['eval_loss']}")
        logging.info(f"eval_acc:{result_dict['eval_acc']}")
        logging.info(f"norm_edit_dis:{result_dict['norm_edit_dis']}")
        for s in show_str[:10]:
            logging.info(s)
        # Restore training mode for the caller's loop.
        self.model.train()
        return result_dict