Example #1
    def predict(self, sample=None):
        pred_sents = []
        actual_sents = []
        img_files = []

        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)

            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['img'], self.model)
            else:
                translated_sentence = translate(batch['img'], self.model)

            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())

            img_files.extend(batch['filenames'])

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files
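
A quick way to consume the triple returned above: a hypothetical usage sketch, assuming `trainer` is an instance of the class defining this predict(), that scores full-sequence accuracy. Nothing here is from the listing itself.

    # Hypothetical usage: `sample` caps how many validation items are decoded,
    # matching the early-break logic inside the loop above.
    pred_sents, actual_sents, img_files = trainer.predict(sample=200)
    acc = sum(p == a for p, a in zip(pred_sents, actual_sents)) / len(pred_sents)
    print('full-sequence accuracy: {:.4f}'.format(acc))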
Example #2
    def predict(self, img):
        img = process_input(img, self.config['dataset']['image_height'],
                            self.config['dataset']['image_min_width'],
                            self.config['dataset']['image_max_width'])
        img = img.to(self.config['device'])

        if self.config['predictor']['beamsearch']:
            s = translate_beam_search(img, self.model)
        else:
            # call translate once; the original ran it twice on the same input
            s = translate(img, self.model)[0].tolist()

        s = self.vocab.decode(s)

        return s
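
The single-image variant above is typically fed one cropped text line. A minimal usage sketch, assuming `predictor` is an already-constructed instance of the class and that process_input accepts a PIL image (the listing shows neither):

    from PIL import Image

    # Hypothetical usage of the predict() method above.
    img = Image.open('line_crop.png').convert('RGB')
    text = predictor.predict(img)
    print(text)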
Example #3
    def gen_pseudo_labels(self, outfile=None):
        pred_sents = []
        img_files = []
        probs_sents = []

        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            pred_sents.extend(pred_sent)
            img_files.extend(batch['filenames'])
            if prob is None:
                # beam search yields no confidence scores; use a placeholder
                # (an assumption) so the per-item lists stay aligned
                prob = [0.0] * len(pred_sent)
            probs_sents.extend(prob)
        assert len(pred_sents) == len(img_files) == len(probs_sents)
        with open(outfile, 'w', encoding='utf-8') as f:
            for anno in zip(img_files, pred_sents, probs_sents):
                f.write('||||'.join([anno[0], anno[1],
                                     str(float(anno[2]))]) + '\n')
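
The output file above joins filename, predicted text, and confidence with '||||' as the separator. A matching reader, grounded in that write format (the helper name read_pseudo_labels is hypothetical):

    def read_pseudo_labels(path):
        # Parse the '||||'-delimited lines written by gen_pseudo_labels:
        # filename||||predicted text||||confidence
        annos = []
        with open(path, encoding='utf-8') as f:
            for line in f:
                fname, text, prob = line.rstrip('\n').split('||||')
                annos.append((fname, text, float(prob)))
        return annos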
Example #4
    def batch_predict(self, images):
        """
        param: images : list of ndarray
        """
        batch_dict, indices = self.batch_process(images)
        # keep only the widths that actually received images
        list_keys = [k for k, v in batch_dict.items() if v]
        result = []

        for width in list_keys:
            batch = batch_dict[width]
            batch = np.asarray(batch)
            batch = torch.FloatTensor(batch)
            batch = batch.to(self.config['device'])

            if self.config['predictor']['beamsearch']:
                sent = batch_translate_beam_search(batch, model=self.model)
            else:
                sent = translate(batch, self.model).tolist()

            batch_text = self.vocab.batch_decode(sent)
            result.extend(batch_text)

        # restore the predictions to the original input order
        sorted_result = sorted(zip(result, indices), key=lambda pair: pair[1])
        result, _ = zip(*sorted_result)

        return result
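
batch_predict above buckets inputs by processed width so each bucket stacks into one fixed-shape tensor, then restores the original order before returning. A hypothetical usage sketch, assuming the inputs are HxWxC ndarrays such as cropped text lines:

    import cv2

    # Hypothetical usage: cv2.imread returns BGR; the channel flip to RGB
    # is an assumption about what the preprocessing expects.
    paths = ['line1.png', 'line2.png', 'line3.png']
    crops = [cv2.imread(p)[:, :, ::-1] for p in paths]
    texts = predictor.batch_predict(crops)
    for p, t in zip(paths, texts):
        print(p, '->', t)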
Example #5
    def predict(self, img):
        img = process_input(img)
        img = img.to(self.config['device'])

        s = translate(img, self.model)[0].tolist()
        s = self.vocab.decode(s)

        return s
Example #6
    def predict_batch(self, batch_img, standard_size):
        batch_img = process_batch_input(batch_img,
                                        self.config['dataset']['image_height'],
                                        standard_size)
        batch_img = batch_img.to(self.config['device'])
        s = translate(batch_img, self.model).tolist()

        s_decoded = [self.vocab.decode(seq) for seq in s]
        return s_decoded
Example #7
    def predict(self, img):
        img = process_input(img, self.config['dataset']['image_height'],
                            self.config['dataset']['image_min_width'],
                            self.config['dataset']['image_max_width'])

        img = img.to(self.config['device'])
        s, _ = translate(img, self.model)
        s = s[0].tolist()
        s = self.vocab.decode(s)
        return s
Example #8
    def batch_predict(self, images):
        """
        param: images : list of ndarray

        """
        batch = self.batch_process(images)
        batch = batch.to(self.config['device'])
        if self.config['predictor']['beamsearch']:
            sents = translate_beam_search(batch, self.model)
        else:
            sents = translate(batch, self.model).tolist()

        sequences = self.vocab.batch_decode(sents)
        return sequences
Example #9
    def predict(self, img):
        img = self.preprocess_input(img)
        img = np.expand_dims(img, axis=0)
        img = torch.FloatTensor(img)
        img = img.to(self.config['device'])

        if self.config['predictor']['beamsearch']:
            s = translate_beam_search(img, self.model)
        else:
            s = translate(img, self.model)[0].tolist()

        s = self.vocab.decode(s)

        return s
Example #10
    def predict(self, img, return_prob=False):
        img = process_input(img, self.config['dataset']['image_height'],
                            self.config['dataset']['image_min_width'],
                            self.config['dataset']['image_max_width'])
        img = img.to(self.config['device'])

        if self.config['predictor']['beamsearch']:
            s = translate_beam_search(img, self.model)
            prob = None
        else:
            s, prob = translate(img, self.model)
            s = s[0].tolist()
            prob = prob[0]

        s = self.vocab.decode(s)

        if return_prob:
            return s, prob
        else:
            return s
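
With return_prob=True the greedy path above also returns the model's confidence for the decoded sequence, while the beam search path yields None. A minimal usage sketch, again assuming a `predictor` instance:

    # Hypothetical usage: flag low-confidence predictions for review.
    text, prob = predictor.predict(img, return_prob=True)
    if prob is not None and prob < 0.9:
        print('low confidence:', prob, text)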
Example #11
    def predict(self, sample=None):
        pred_sents = []
        actual_sents = []
        img_files = []
        
        n = 0
        for batch in self.valid_gen.gen(self.batch_size):
            translated_sentence = translate(batch['img'], self.model)
            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt_input'].T.tolist())

            img_files.extend(batch['filenames'])

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            n += len(actual_sent)  # count this batch, not the running list

            if sample is not None and n > sample:
                break

        return pred_sents, actual_sents, img_files
Example #12
    def predict(self, img, return_prob=False):
        img = self.preprocess_input(img)
        img = np.expand_dims(img, axis=0)
        img = torch.FloatTensor(img)
        img = img.to(self.config['device'])

        if self.config['predictor']['beamsearch']:
            s = translate_beam_search(img, self.model)
            prob = None
        else:
            s, prob = translate(img, self.model)
            s = s[0].tolist()
            prob = prob[0]

        s = self.vocab.decode(s)

        if return_prob:
            return s, prob
        else:
            return s
Example #13
    def predict(self, sample=None):
        pred_sents = []
        actual_sents = []
        img_files = []
        probs_sents = []
        imgs_sents = []

        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())
            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)

            imgs_sents.extend(batch['img'])
            img_files.extend(batch['filenames'])
            if prob is None:
                # beam search yields no confidence scores; use a placeholder
                # (an assumption) so the visualization below can still format it
                prob = [0.0] * len(pred_sent)
            probs_sents.extend(prob)

            # Visualize in tensorboard
            if idx == 0:
                try:
                    num_samples = self.config['monitor']['num_samples']
                    fig = plt.figure(figsize=(12, 15))
                    imgs_samples = imgs_sents[:num_samples]
                    preds_samples = pred_sents[:num_samples]
                    actuals_samples = actual_sents[:num_samples]
                    probs_samples = probs_sents[:num_samples]
                    for id_img in range(len(imgs_samples)):
                        img = imgs_samples[id_img]
                        img = img.permute(1, 2, 0)
                        img = img.cpu().detach().numpy()
                        ax = fig.add_subplot(num_samples,
                                             1,
                                             id_img + 1,
                                             xticks=[],
                                             yticks=[])
                        plt.imshow(img)
                        ax.set_title(
                            "LB: {} \n Pred: {:.4f}-{}".format(
                                actuals_samples[id_img], probs_samples[id_img],
                                preds_samples[id_img]),
                            color=('green' if actuals_samples[id_img]
                                   == preds_samples[id_img] else 'red'),
                            fontdict={
                                'fontsize': 18,
                                'fontweight': 'medium'
                            })

                    self.writer.add_figure('predictions vs. actuals',
                                           fig,
                                           global_step=self.iter)
                except Exception as error:
                    print(error)
                    continue

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files, probs_sents, imgs_sents