Example #1
    def test_epoch_end(self, outputs):
        if "preds" in outputs[0]:
            output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions_" +
                                                        str(self.count_valid_epoch) + ".txt")
            output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets_" +
                                                        str(self.count_valid_epoch) + ".txt")
            # write predictions and targets for later rouge evaluation.
            with open(output_test_predictions_file, "w") as p_writer, open(output_test_targets_file, "w") as t_writer:
                for output_batch in outputs:
                    p_writer.writelines(convert_text(s) + "\n" for s in output_batch["preds"])
                    t_writer.writelines(convert_text(s) + "\n" for s in output_batch["target"])

            #bleu_info = eval_bleu_sents(output_test_targets_file, output_test_predictions_file)
            bleu_info = eval_sacre_bleu(output_test_targets_file, output_test_predictions_file)
            #bleu_info = eval_bleu(output_test_targets_file, output_test_predictions_file)
            moverScore = eval_mover_score(output_test_targets_file, output_test_predictions_file)


            logger.info("valid epoch: %s", self.count_valid_epoch)
            logger.info("%s bleu_info: %s", self.count_valid_epoch, bleu_info)
            logger.info("%s mover score: %s", self.count_valid_epoch, moverScore)

            self.count_valid_epoch += 1

        else:
            logger.info('not in')

        return self.check_validation_end(outputs)
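Here, eval_sacre_bleu and eval_mover_score are project-specific helpers that are not shown on this page. As a rough sketch of the BLEU side only, assuming one detokenized sentence per line in each file and the sacrebleu package, such a helper could look like:

import sacrebleu

def eval_sacre_bleu(targets_file, predictions_file):
    # hypothetical sketch, not the project's actual implementation:
    # read aligned reference/hypothesis lines and score them with corpus-level sacreBLEU
    with open(targets_file) as t_reader, open(predictions_file) as p_reader:
        refs = [line.strip() for line in t_reader]
        hyps = [line.strip() for line in p_reader]
    return sacrebleu.corpus_bleu(hyps, [refs]).score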
Example #2
    def validation_epoch_end(self, outputs, prefix="val"):
        self.step_count += 1

        if "preds" in outputs[0]:
            output_test_predictions_file = os.path.join(
                self.hparams.output_dir, "validation_predictions_" +
                str(self.count_valid_epoch) + ".txt")
            output_test_targets_file = os.path.join(
                self.hparams.output_dir,
                "validation_targets_" + str(self.count_valid_epoch) + ".txt")
            # write predictions and targets for later rouge evaluation.
            with open(output_test_predictions_file,
                      "w") as p_writer, open(output_test_targets_file,
                                             "w") as t_writer:
                for output_batch in outputs:
                    p_writer.writelines(
                        convert_text(s) + "\n" for s in output_batch["preds"])
                    t_writer.writelines(
                        convert_text(s) + "\n" for s in output_batch["target"])

            #bleu_info = eval_bleu_sents(output_test_targets_file, output_test_predictions_file)
            if self.count_valid_epoch >= 0:
                bleu_info = eval_sacre_bleu(output_test_targets_file,
                                            output_test_predictions_file)
                moverScore = eval_mover_score(output_test_targets_file,
                                              output_test_predictions_file)
            else:
                bleu_info = 0
                moverScore = [0, 0]

            metrics = {}
            metrics["{}_avg_bleu".format(prefix)] = bleu_info
            metrics["{}_mover_mean1".format(prefix)] = moverScore[0]
            metrics["{}_mover_median1".format(prefix)] = moverScore[1]
            metrics["step_count"] = self.step_count

            logger.info("valid epoch: %s", self.count_valid_epoch)
            logger.info("%s bleu_info: %s", self.count_valid_epoch, bleu_info)
            logger.info("%s mover score: %s", self.count_valid_epoch,
                        moverScore)

            avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()

            mover_tensor: torch.FloatTensor = torch.tensor(
                moverScore[0]).type_as(avg_loss)

            self.count_valid_epoch += 1

        else:
            logger.info('not in')
            avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
            # fall back to empty metrics and a zero mover score so the return below stays defined
            metrics = {"step_count": self.step_count}
            mover_tensor = torch.tensor(0.0).type_as(avg_loss)
        #tensorboard_logs = {"val_loss": avg_loss}

        return {
            "avg_val_loss": avg_loss,
            "log": metrics,
            "{}_mover".format(prefix): mover_tensor
        }
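In both epoch-end hooks above, convert_text is the project's own cleanup helper applied to each generated and reference string before it is written one per line. A minimal hypothetical stand-in, assuming it only strips special tokens and normalizes whitespace, could be:

import re

def convert_text(text):
    # hypothetical stand-in for the project's helper: drop common special tokens
    # and collapse whitespace so each string fits on a single output line
    for token in ("<pad>", "<s>", "</s>"):
        text = text.replace(token, "")
    return re.sub(r"\s+", " ", text).strip()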
Example #3
def upload_file():
    f = request.files['file']
    mimetype = f.content_type
    print(mimetype)
    if mimetype == "image/jpeg" or mimetype == "image/png":
        image = open_image(f)
        extracted_text = read_image(image)
        document = write_docx(extracted_text)
        print("okay")
        return send_file(document,
                         as_attachment=True,
                         attachment_filename="report.docx")
    elif mimetype == "audio/mpeg" or mimetype == "audio/wav":
        f.filename = "sound.wav"
        f.save(f.filename)
        stream = os.popen('python3 test_ffmpeg.py ' + f.filename)
        output = stream.read()
        print(output)
        text = ''
        with open('res.json') as json_file:
            data = json.load(json_file)
            text = data['text']
            print(data['text'])
        extracted_text = text
        processed_text = convert_text(extracted_text)
        print(extracted_text + ';' + processed_text)
        response = make_response(extracted_text + ';' + processed_text, 200)
        response.mimetype = "text/plain"
        return response

    elif mimetype == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        document = open_docx(f)
        extracted_text = read_docx(document)
        processed_text = convert_text(extracted_text)
        new_document = write_docx(processed_text)
        stream = save_docx(new_document)
        print("okay")
        response = make_response(
            send_file(stream,
                      as_attachment=True,
                      attachment_filename="report1to3.docx"))
        response.headers['word'] = 'yes'
        return response
    elif mimetype == "application/vnd.oasis.opendocument.text":
        extracted_text = open_odt(f)
        processed_text = convert_text(extracted_text)
        stream = write_odt(processed_text)
        print("okay")
        response = make_response(
            send_file(stream,
                      as_attachment=True,
                      attachment_filename="report1to3.odt"))
        response.headers['word'] = 'no'
        return response
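The snippet does not include the @app.route registration, so the route path below is an assumption; a client call against this handler could look roughly like:

import requests

DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"

# hypothetical URL and route; the example above does not show them
with open("input.docx", "rb") as fh:
    resp = requests.post(
        "http://localhost:5000/upload",
        files={"file": ("input.docx", fh, DOCX_MIME)},
    )
print(resp.status_code, resp.headers.get("word"))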
Example #4
    def infer(self, raw_text):
        output = []
        if self.conditional:
            # while True:
            #     raw_text = input("Model prompt >>> ")
            #     while not raw_text:
            #         print('Prompt should not be empty!')
            #         raw_text = input("Model prompt >>> ")
            if self.en:
                context_tokens = self.enc.encode(raw_text)
            else:
                context_tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(raw_text))
            for sample_id in range(self.nsamples // self.batch_size):
                out = self.sess.run(self.output, feed_dict={
                    self.context: [context_tokens for _ in range(1)]})[:, len(context_tokens):]
                for batch_id in range(self.batch_size):
                    # if self.en:
                    #     text = self.enc.decode(out[batch_id])
                    # else:
                    #     text = self.tokenizer.convert_ids_to_tokens(out[batch_id])
                    #     text = refine_punc.refine_punc(text)
                    # text = text.split('.')[:self.sent_length]
                    # if text[-1] != '':
                    #     text = text + ['']
                    # text = '. '.join([i.strip() for i in text]).strip()
                    # output.append(text)

                    if self.en:
                        text = self.enc.decode(out[batch_id])
                        text = text.split('.')[:self.sent_length]
                        if text[-1] != '':
                            text = text + ['']
                        text = '. '.join([i.strip() for i in text]).strip()
                        output.append(text)

                    else:
                        text = self.tokenizer.convert_ids_to_tokens(out[batch_id])
                        text = refine_punc.refine_punc(text)
                        text = ' '.join(utils.rm_sp(utils.convert_text(text).split('.. '))[:self.sent_length])
                        output.append(text)

        else:
            generated = 0
            while self.nsamples == 0 or generated < self.nsamples:
                out = self.sess.run(self.output)
                for batch_id in range(self.batch_size):
                    generated += 1  # count one sample per decoded batch element
                    if self.en:
                        text = self.enc.decode(out[batch_id])
                    else:
                        text = self.tokenizer.convert_ids_to_tokens(out[batch_id])
                        text = refine_punc.refine_punc(text)
                    text = text.split('.')[:self.sent_length]
                    if text[-1] != '':
                        text = text + ['']
                    text = '. '.join([i.strip() for i in text]).strip()
                    output.append(text)
        return output
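The post-processing applied to each decoded English sample (keep the first sent_length period-delimited sentences, then re-join them) can also be read as a small standalone helper:

def truncate_sentences(text, sent_length):
    # same logic as in the example: keep the first `sent_length` sentences
    parts = text.split('.')[:sent_length]
    if parts and parts[-1] != '':
        parts = parts + ['']  # so the re-joined string ends with a period
    return '. '.join(p.strip() for p in parts).strip()

# truncate_sentences("One. Two. Three. Four.", 2) returns "One. Two."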
Example #5
def hello():
    start_time = time.time()
    text_input = request.json
    texts = convert_text(text_input).tolist()

    if len(text_input) == 0:
        result = []
    else:
        result = rest_infer(texts).tolist()

    return json.dumps({
        "code": 200,
        "time": time.time() - start_time,
        "response": result
    })
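As with the other Flask handlers on this page, the route decorator is not part of the snippet; a client request against this endpoint could look roughly like:

import requests

# hypothetical URL; the @app.route path for hello() is not shown above
resp = requests.post("http://localhost:5000/infer",
                     json=["first input sentence", "second input sentence"])
body = resp.json()
print(body["code"], body["time"], body["response"])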
Example #6
def submit():
    extracted_text = request.form["text"]
    processed_text = highighter(extracted_text, convert_text(extracted_text))
    response = make_response(processed_text, 200)
    response.mimetype = "text/plain"
    return response
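And the corresponding form submission, again with a hypothetical route path:

import requests

# hypothetical URL; the snippet does not show where submit() is routed
resp = requests.post("http://localhost:5000/submit",
                     data={"text": "text to run through convert_text"})
print(resp.text)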