def run_model(self, img):
    img = img.to(self.device)
    logits, pred = self.model(img)
    if self.use_ctc:
        # CTC head: decode greedily from per-step log-probabilities (blank id 0)
        pred = torch.nn.functional.log_softmax(logits.detach(), dim=2)
        pred = ctc_greedy_search(pred, 0)
    return pred
def _run_ctc_head(self, img):
    # run the OpenVINO network and read the first declared output blob
    out_name = self.config.get('model_output_names').split(',')[0]
    logits = self.exec_net.infer(
        inputs={self.config.get('model_input_names'): img})[out_name]
    pred = log_softmax(logits, axis=2)
    pred = ctc_greedy_search(pred, 0)
    return pred
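ctc_greedy_search itself is not shown in these snippets. A minimal sketch of greedy CTC decoding, consistent with how it is called here (log-probabilities as the first argument, the blank index as the second, one decoded sequence per batch element), might look as follows; the [T, B, C] shape is an assumption taken from the CTC branch of calculate_loss below:

import torch

def ctc_greedy_search(log_probs, blank_token):
    # assumed shape: log_probs is [T, B, C] (time, batch, classes)
    # greedy CTC decoding: argmax class at every time step,
    # collapse consecutive repeats, then drop blanks
    best_path = torch.argmax(log_probs, dim=2)  # [T, B]
    decoded = []
    for b in range(best_path.size(1)):
        seq, prev = [], blank_token
        for t in best_path[:, b].tolist():
            if t != prev and t != blank_token:
                seq.append(t)
            prev = t
        decoded.append(torch.tensor(seq, dtype=torch.long))
    return decoded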
Example #3
def calculate_loss(logits,
                   targets,
                   target_lengths,
                   should_cut_by_min=False,
                   ctc_loss=None):
    """args:
        logits: probability distribution return by model
                [B, MAX_LEN, voc_size]
        targets: target formulas
                [B, MAX_LEN]
    """

    if ctc_loss is None:
        if should_cut_by_min:
            required_len = min(logits.size(1), targets.size(1))
            if required_len < targets.size(1):
                warn('Cutting tensor leads to tensor sized less than target')
            logits = logits.narrow(1, 0, required_len)
            targets = targets.narrow(1, 0, required_len)
        else:
            # narrow dim 1 (the sequence dim) to the target length: narrow(dim, start, length)
            logits = logits.narrow(1, 0, targets.size(1))

        padding = torch.ones_like(targets) * PAD_TOKEN
        mask_for_tgt = (targets != padding)
        b_size, max_len, vocab_size = logits.shape  # batch size, length of the formula, vocab size
        targets = targets.masked_select(mask_for_tgt)
        mask_for_logits = mask_for_tgt.unsqueeze(2).expand(-1, -1, vocab_size)
        logits = logits.masked_select(mask_for_logits).contiguous().view(
            -1, vocab_size)
        # the model outputs probabilities, so take the log before nll_loss
        logits = torch.log(logits)
        assert logits.size(0) == targets.size(0)
        pred = torch.max(logits.data, 1)[1]

        accuracy = (pred == targets)
        accuracy = accuracy.cpu().numpy().astype(np.uint32)
        accuracy = np.sum(accuracy) / len(accuracy)
        accuracy = accuracy.item()
        loss = torch.nn.functional.nll_loss(logits, targets)
    else:
        logits = torch.nn.functional.log_softmax(logits, dim=2)
        max_len, b_size, vocab_size = logits.shape  # sequence length, batch size, vocab size
        input_lengths = torch.full(size=(b_size, ),
                                   fill_value=max_len,
                                   dtype=torch.long)
        loss = ctc_loss(logits,
                        targets,
                        input_lengths=input_lengths,
                        target_lengths=target_lengths)

        predictions = ctc_greedy_search(logits.detach(), ctc_loss.blank)
        accuracy = 0
        for i in range(b_size):
            gt = targets[i][:target_lengths[i]].cpu()
            if len(predictions[i]) == len(gt) and torch.all(
                    predictions[i].eq(gt)):
                accuracy += 1
        accuracy /= b_size
    return loss, accuracy
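A hedged usage sketch covering both branches; probs, logits, targets, and target_lengths are illustrative names, and torch.nn.CTCLoss is the standard PyTorch criterion whose blank attribute the CTC branch reads:

import torch

# attention/NLL branch: probabilities shaped [B, MAX_LEN, voc_size]
loss, accuracy = calculate_loss(probs, targets, target_lengths)

# CTC branch: logits shaped [MAX_LEN, B, voc_size], blank id 0
ctc = torch.nn.CTCLoss(blank=0, zero_infinity=True)
loss, accuracy = calculate_loss(logits, targets, target_lengths, ctc_loss=ctc)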
Example #4
    def validate(self, use_gt_token=True):
        self.model.eval()
        val_avg_loss = 0.0
        val_avg_accuracy = 0.0
        print('Validation started')
        with torch.no_grad():
            filename = VAL_FILE_NAME_TEMPLATE.format(self.val_results_path, self.epoch, self.step, self.time)
            with open(filename, 'w') as output_file:
                for loader in self.val_loaders:
                    val_loss, val_acc = 0, 0
                    for img_name, target_lengths, imgs, training_gt, loss_computation_gt in tqdm(loader):

                        imgs = imgs.to(self.device)
                        training_gt = training_gt.to(self.device)
                        loss_computation_gt = loss_computation_gt.to(self.device)
                        logits, pred = self.model(imgs, training_gt if use_gt_token else None)
                        if self.loss_type == 'CTC':
                            pred = torch.nn.functional.log_softmax(logits.detach(), dim=2)
                            pred = ctc_greedy_search(pred, blank_token=self.loss.blank)
                        for j, phrase in enumerate(pred):
                            gold_phrase_str = self.vocab.construct_phrase(
                                loss_computation_gt[j], ignore_end_token=self.config.get('use_ctc'))
                            pred_phrase_str = self.vocab.construct_phrase(phrase,
                                                                          max_len=1 + len(gold_phrase_str.split()),
                                                                          ignore_end_token=self.config.get('use_ctc')
                                                                          )
                            gold_phrase_str = gold_phrase_str.lower()
                            pred_phrase_str = pred_phrase_str.lower()
                            output_file.write(img_name[j] + '\t' + pred_phrase_str + '\t' + gold_phrase_str + '\n')
                            val_acc += int(pred_phrase_str == gold_phrase_str)
                        cut = self.loss_type != 'CTC'
                        loss, _ = calculate_loss(logits, loss_computation_gt, target_lengths,
                                                 should_cut_by_min=cut, ctc_loss=self.loss)
                        loss = loss.detach()
                        val_loss += loss
                    val_loss = val_loss / len(loader.dataset)
                    val_acc = val_acc / len(loader.dataset)
                    dataset_name = os.path.split(loader.dataset.data_path)[-1]
                    print('Epoch {}, dataset {} loss: {:.4f}'.format(
                        self.epoch, dataset_name, val_loss
                    ))
                    self.writer.add_scalar(f'Loss {dataset_name}', val_loss, self.global_step)
                    print('Epoch {}, dataset {} accuracy: {:.4f}'.format(
                        self.epoch, dataset_name, val_acc
                    ))
                    self.writer.add_scalar(f'Accuracy {dataset_name}', val_acc, self.global_step)
                    weight = len(loader.dataset) / sum(map(lambda ld: len(ld.dataset), self.val_loaders))
                    val_avg_loss += val_loss * weight
                    val_avg_accuracy += val_acc * weight
        print('Epoch {}, validation average loss: {:.4f}'.format(
            self.epoch, val_avg_loss
        ))
        print('Epoch {}, validation average accuracy: {:.4f}'.format(
            self.epoch, val_avg_accuracy
        ))
        self.save_model('validation_epoch_{}_step_{}_{}.pth'.format(self.epoch, self.step, self.time))
        self.model.train()
        return val_avg_loss, val_avg_accuracy
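Note that validate weighs each dataset by its share of the total validation samples (weight = len(loader.dataset) / sum(...)). For example, with two validation sets of 100 and 300 samples and accuracies of 0.80 and 0.90, the reported average accuracy is 0.80 * 0.25 + 0.90 * 0.75 = 0.875.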
def run_complete_model(self, img):
    # run the exported ONNX model end to end, then decode with greedy CTC search
    model_output_names = get_onnx_outputs(self.model)
    model_input_names = get_onnx_inputs(self.model)[0]
    logits, _ = self.model.run(
        model_output_names,
        {model_input_names: np.array(img, dtype=np.float32)})
    pred = log_softmax(logits, axis=2)
    pred = ctc_greedy_search(pred, 0)
    return pred
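get_onnx_inputs and get_onnx_outputs are helpers that are not shown in these snippets. Assuming self.model is an onnxruntime.InferenceSession, a minimal sketch could be:

def get_onnx_inputs(session):
    # names of the graph's input tensors, in declaration order
    return [node.name for node in session.get_inputs()]

def get_onnx_outputs(session):
    # names of the graph's output tensors, in declaration order
    return [node.name for node in session.get_outputs()]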
Example #6
def calculate_loss(logits,
                   targets,
                   target_lengths,
                   should_cut_by_min=False,
                   ctc_loss=None):
    """args:
        logits: probability distribution return by model
                [B, MAX_LEN, voc_size]
        targets: target formulas
                [B, MAX_LEN]
    """

    if ctc_loss is None:
        if should_cut_by_min:
            required_len = min(logits.size(1), targets.size(1))
            if required_len < targets.size(1):
                warn('Cutting tensor leads to tensor sized less than target')
            logits = logits.narrow(1, 0, required_len)
            targets = targets.narrow(1, 0, required_len)
        else:
            # narrow dim 1 (the sequence dim) to the target length: narrow(dim, start, length)
            logits = logits.narrow(1, 0, targets.size(1))
        # nll_loss over sequences expects class scores as [B, C, T]
        logits = logits.permute(0, 2, 1)
        loss = torch.nn.functional.nll_loss(logits,
                                            targets,
                                            ignore_index=PAD_TOKEN)

        assert logits.size(0) == targets.size(0)
        pred = torch.max(logits.data, 1)[1]

        accuracy = (pred == targets)
        accuracy = accuracy.cpu().numpy().astype(np.uint32)
        accuracy = np.hstack(
            [accuracy[i][:l] for i, l in enumerate(target_lengths)])
        accuracy = np.sum(accuracy) / np.prod(accuracy.shape)
        accuracy = accuracy.item()

    else:
        logits = torch.nn.functional.log_softmax(logits, dim=2)
        max_len, b_size, _ = logits.shape  # sequence length, batch size, vocab size
        input_lengths = torch.full(size=(b_size, ),
                                   fill_value=max_len,
                                   dtype=torch.long)
        loss = ctc_loss(logits,
                        targets,
                        input_lengths=input_lengths,
                        target_lengths=target_lengths)

        predictions = ctc_greedy_search(logits.detach(), ctc_loss.blank)
        accuracy = 0
        for i in range(b_size):
            gt = targets[i][:target_lengths[i]].cpu()
            if len(predictions[i]) == len(gt) and torch.all(
                    predictions[i].eq(gt)):
                accuracy += 1
        accuracy /= b_size
    return loss, accuracy
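Compared with the first calculate_loss above, this variant skips the explicit masked_select masking: padding is excluded from the loss via ignore_index=PAD_TOKEN, and accuracy is averaged only over the first target_lengths[i] positions of each sequence. The CTC branch is identical in both variants.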
Example #7
    def __call__(self, img):
        img = self.transform(img)
        img = img[0].unsqueeze(0)  # keep the first image and restore the batch dimension
        img = img.to(self.device)
        logits, pred = self.model(img)
        if self.use_ctc:
            pred = torch.nn.functional.log_softmax(logits.detach(), dim=2)
            pred = ctc_greedy_search(pred, 0)

        return self.vocab.construct_phrase(pred[0], ignore_end_token=self.use_ctc)
Example #8
    def run_model(self, img):
        if torch.is_tensor(img):
            img = img.clone().detach().numpy()
        if self.use_ctc:
            logits = self.exec_net.infer(
                inputs={self.config.get('model_input_names'): img
                        })[self.config.get('model_output_names').split(',')[0]]
            pred = log_softmax(logits, axis=2)
            pred = ctc_greedy_search(pred, 0)
            return pred[0]

        enc_res = self.exec_net_encoder.infer(
            inputs={
                self.config.get('encoder_input_names', ENCODER_INPUTS).split(',')[0]:
                img
            })
        enc_out_names = self.config.get('encoder_output_names',
                                        ENCODER_OUTPUTS).split(',')
        ir_row_enc_out = enc_res[enc_out_names[0]]
        dec_states_h = enc_res[enc_out_names[1]]
        dec_states_c = enc_res[enc_out_names[2]]
        output = enc_res[enc_out_names[3]]
        dec_in_names = self.config.get('decoder_input_names',
                                       DECODER_INPUTS).split(',')
        dec_out_names = self.config.get('decoder_output_names',
                                        DECODER_OUTPUTS).split(',')
        tgt = np.array([[START_TOKEN]])  # shape (1, 1): a single START token for batch size 1
        logits = []
        # autoregressive greedy decoding, capped at MAX_SEQ_LEN steps
        for _ in range(MAX_SEQ_LEN):
            dec_res = self.exec_net_decoder.infer(
                inputs={
                    dec_in_names[0]: dec_states_h,
                    dec_in_names[1]: dec_states_c,
                    dec_in_names[2]: output,
                    dec_in_names[3]: ir_row_enc_out,
                    dec_in_names[4]: tgt
                })

            dec_states_h = dec_res[dec_out_names[0]]
            dec_states_c = dec_res[dec_out_names[1]]
            output = dec_res[dec_out_names[2]]
            logit = dec_res[dec_out_names[3]]
            logits.append(logit)

            # greedy step: feed the argmax token back into the decoder
            tgt = np.reshape(np.argmax(logit, axis=1), (1, 1)).astype(np.int64)
            if tgt[0][0] == END_TOKEN:
                break
        return np.argmax(np.array(logits).squeeze(1), axis=1)
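A heavily hedged usage sketch; demo stands for an instance of the surrounding class, and vocab.construct_phrase is the helper seen in Example #7 (both names are assumptions here):

token_ids = demo.run_model(preprocessed_img)   # array of predicted token ids
phrase = vocab.construct_phrase(token_ids)     # map ids back to a formula string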