def main():
    opts = parse_args()
    init_logging(
        os.path.join(opts.log_dir,
                     '{:s}_win0_win4_log_test.txt'.format(opts.task)))

    if torch.cuda.is_available():
        torch.cuda.set_device(opts.gpu)
        logging.info("Using GPU!")
        device = "cuda"
    else:
        logging.info("Using CPU!")
        device = "cpu"

    logging.info(opts)

    test_datasets = PhoenixVideo(opts.vocab_file,
                                 opts.corpus_dir,
                                 opts.video_path,
                                 phase=opts.task,
                                 DEBUG=opts.DEBUG)
    vocab_size = test_datasets.vocab.num_words
    blank_id = test_datasets.vocab.word2index['<BLANK>']
    vocabulary = Vocabulary(opts.vocab_file)
    #     model = DilatedSLRNet(opts, device, vocab_size, vocabulary,
    #                           dilated_channels=512, num_blocks=5, dilations=[1, 2, 4], dropout=0.0)
    model = MainStream(vocab_size)
    criterion = CtcLoss(opts, blank_id, device, reduction="none")
    trainer = Trainer(opts, model, criterion, vocabulary, vocab_size, blank_id)

    # ctcdecode
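    # each gloss index is mapped to a distinct Unicode character because
    # CTCBeamDecoder expects string labels; only the returned indices are used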
    ctc_decoder_vocab = [chr(x) for x in range(20000, 20000 + vocab_size)]
    ctc_decoder = ctcdecode.CTCBeamDecoder(ctc_decoder_vocab,
                                           beam_width=opts.beam_width,
                                           blank_id=blank_id,
                                           num_processes=10)

    if os.path.exists(opts.check_point):
        logging.info("Loading checkpoint file from {}".format(
            opts.check_point))
        epoch, num_updates, loss = trainer.load_checkpoint(opts.check_point)
    else:
        logging.info("No checkpoint file in found in {}".format(
            opts.check_point))
        epoch, num_updates, loss = 0, 0, 0.0

    test_iter = trainer.get_batch_iterator(test_datasets,
                                           batch_size=opts.batch_size,
                                           shuffle=False)
    decoded_dict = {}
    val_err, val_correct, val_count = np.zeros([4]), 0, 0

    with open("Data/output/hypo_ctc.txt",
              "w") as f, open("Data/output/ref_ctc.txt", "w") as f2:
        with torch.no_grad():
            model.eval()
            criterion.eval()
            for samples in tqdm(test_iter):
                samples = trainer._prepare_sample(samples)
                video = samples["data"]
                len_video = samples["len_data"]
                label = samples["label"]
                len_label = samples["len_label"]
                video_id = samples['id']

                logits, _ = model(video, len_video)
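                # scale frame lengths to the logits' time axis
                # (MainStream appears to downsample time by a factor of 4)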
                len_video /= 4
                logits = F.softmax(logits, dim=-1)
                pred_seq, _, _, out_seq_len = ctc_decoder.decode(
                    logits, len_video)
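                # the batch's reference labels come as one flat tensor;
                # len_label is used below to slice out each sample's reference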
                start = 0
                for i, length in enumerate(len_label):
                    end = start + length
                    ref = label[start:end].tolist()
                    hyp = [
                        x[0] for x in groupby(pred_seq[i][0]
                                              [:out_seq_len[i][0]].tolist())
                    ]
                    ref_sent = " ".join(
                        [vocabulary.index2word[r] for r in ref])
                    hyp_sent = " ".join(
                        [vocabulary.index2word[r] for r in hyp])
                    f.write(hyp_sent + "\n")
                    f2.write(ref_sent + "\n")

                    decoded_dict[video_id[i]] = hyp
                    val_correct += int(ref == hyp)
                    err = get_wer_delsubins(ref, hyp)
                    val_err += np.array(err)
                    val_count += 1
                    start = end
                assert end == label.size(0)
            logging.info('-' * 50)
            logging.info('Epoch: {:d}, DEV ACC: {:.5f}, {:d}/{:d}'.format(
                epoch, val_correct / val_count, val_correct, val_count))
            logging.info(
                'Epoch: {:d}, DEV WER: {:.5f}, SUB: {:.5f}, INS: {:.5f}, DEL: {:.5f}'
                .format(epoch, val_err[0] / val_count, val_err[1] / val_count,
                        val_err[2] / val_count, val_err[3] / val_count))

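            # one line per decoded gloss with synthetic start/end times;
            # the random durations only serve as placeholder timings for the
            # relaxation evaluation file written below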
            list_str_for_test = []
            for k, v in decoded_dict.items():
                start_time = 0
                for wi in v:
                    tl = np.random.random() * 0.1
                    list_str_for_test.append('{} 1 {:.3f} {:.3f} {}\n'.format(
                        k, start_time, start_time + tl,
                        test_datasets.vocab.index2word[wi]))
                    start_time += tl
            tmp_prefix = str(uuid.uuid1())
            txt_file = '{:s}.txt'.format(tmp_prefix)
            result_file = os.path.join('evaluation_relaxation', txt_file)
            with open(result_file, 'w') as fid:
                fid.writelines(list_str_for_test)
            phoenix_eval_err = get_phoenix_wer(txt_file, opts.task, tmp_prefix)
            logging.info(
                '[Relaxation Evaluation] Epoch: {:d}, DEV WER: {:.5f}, SUB: {:.5f}, INS: {:.5f}, DEL: {:.5f}'
                .format(epoch, phoenix_eval_err[0], phoenix_eval_err[1],
                        phoenix_eval_err[2], phoenix_eval_err[3]))
            return phoenix_eval_err
def main_2():
    opts = parse_args()
    init_logging(os.path.join(opts.log_dir, '{:s}_log.txt'.format(opts.task)))

    if torch.cuda.is_available():
        torch.cuda.set_device(opts.gpu)
        logging.info("Using GPU!")
        device = "cuda"
    else:
        logging.info("Using CPU!")
        device = "cpu"

    logging.info(opts)

    test_datasets = PhoenixVideo(opts.vocab_file,
                                 opts.corpus_dir,
                                 opts.video_path,
                                 phase=opts.task,
                                 DEBUG=opts.DEBUG)
    vocab_size = test_datasets.vocab.num_words
    blank_id = test_datasets.vocab.word2index['<BLANK>']
    vocabulary = Vocabulary(opts.vocab_file)
    model = DilatedSLRNet(opts,
                          device,
                          vocab_size,
                          vocabulary,
                          dilated_channels=512,
                          num_blocks=5,
                          dilations=[1, 2, 4],
                          dropout=0.0)
    criterion = CtcLoss(opts, blank_id, device, reduction="none")
    trainer = Trainer(opts, model, criterion, vocabulary, vocab_size, blank_id)

    # iterative decoder
    dec_generator = IterativeGenerate(vocabulary, model)

    if os.path.exists(opts.check_point):
        logging.info("Loading checkpoint file from {}".format(
            opts.check_point))
        epoch, num_updates, loss = trainer.load_checkpoint(opts.check_point)
    else:
        logging.info("No checkpoint file in found in {}".format(
            opts.check_point))
        epoch, num_updates, loss = 0, 0, 0.0

    test_iter = trainer.get_batch_iterator(test_datasets,
                                           batch_size=opts.batch_size,
                                           shuffle=False)
    decoded_dict = {}
    with torch.no_grad():
        model.eval()
        criterion.eval()
        val_err, val_correct, val_count = np.zeros([4]), 0, 0
        for samples in tqdm(test_iter):
            samples = trainer._prepare_sample(samples)
            video = samples["data"]
            len_video = samples["len_data"]
            label = samples["label"]
            len_label = samples["len_label"]
            video_id = samples['id']

            hypos = dec_generator.generate_ctcdecode(video, len_video)
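            # hypos[i] holds the candidate hypotheses for sample i; the first
            # candidate's token sequence is post-processed into the prediction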

            start = 0
            for i, length in enumerate(len_label):
                end = start + length
                ref = label[start:end].tolist()
                # hyp = [x for x in pred_seq[i] if x != 0]
                # hyp = [x[0] for x in groupby(pred_seq[i][0][:out_seq_len[i][0]].tolist())]
                hyp = trainer.post_process_prediction(hypos[i][0]["tokens"])
                # if i == 0:
                #     if len(hyp) == 0:
                #         logging.info("Here hyp is None!!!!")
                #     logging.info("video id: {}".format(video_id[i]))
                #     logging.info("ref: {}".format(" ".join(str(i) for i in ref)))
                #     logging.info("hyp: {}".format(" ".join(str(i) for i in hyp)))
                #
                #     logging.info("\n")
                decoded_dict[video_id[i]] = hyp
                val_correct += int(ref == hyp)
                err = get_wer_delsubins(ref, hyp)
                val_err += np.array(err)
                val_count += 1
                start = end
            assert end == label.size(0)
        logging.info('-' * 50)
        logging.info('Epoch: {:d}, DEV ACC: {:.5f}, {:d}/{:d}'.format(
            epoch, val_correct / val_count, val_correct, val_count))
        logging.info(
            'Epoch: {:d}, DEV WER: {:.5f}, SUB: {:.5f}, INS: {:.5f}, DEL: {:.5f}'
            .format(epoch, val_err[0] / val_count, val_err[1] / val_count,
                    val_err[2] / val_count, val_err[3] / val_count))

        list_str_for_test = []
        for k, v in decoded_dict.items():
            start_time = 0
            for wi in v:
                tl = np.random.random() * 0.1
                list_str_for_test.append('{} 1 {:.3f} {:.3f} {}\n'.format(
                    k, start_time, start_time + tl,
                    test_datasets.vocab.index2word[wi]))
                start_time += tl
        tmp_prefix = str(uuid.uuid1())
        txt_file = '{:s}.txt'.format(tmp_prefix)
        result_file = os.path.join('evaluation_relaxation', txt_file)
        with open(result_file, 'w') as fid:
            fid.writelines(list_str_for_test)
        phoenix_eval_err = get_phoenix_wer(txt_file, opts.task, tmp_prefix)
        logging.info(
            '[Relaxation Evaluation] Epoch: {:d}, DEV WER: {:.5f}, SUB: {:.5f}, INS: {:.5f}, DEL: {:.5f}'
            .format(epoch, phoenix_eval_err[0], phoenix_eval_err[1],
                    phoenix_eval_err[2], phoenix_eval_err[3]))
        return phoenix_eval_err
def main_3():
    opts = parse_args()
    init_logging(os.path.join(opts.log_dir, '{:s}_log.txt'.format(opts.task)))

    if torch.cuda.is_available():
        torch.cuda.set_device(opts.gpu)
        logging.info("Using GPU!")
        device = "cuda"
    else:
        logging.info("Using CPU!")
        device = "cpu"

    logging.info(opts)

    test_datasets = PhoenixVideo(opts.vocab_file,
                                 opts.corpus_dir,
                                 opts.video_path,
                                 phase="train",
                                 DEBUG=opts.DEBUG,
                                 sample=False)
    vocab_size = test_datasets.vocab.num_words
    blank_id = test_datasets.vocab.word2index['<BLANK>']
    pad_id = test_datasets.vocab.pad()
    vocabulary = Vocabulary(opts.vocab_file)
    # model = DilatedSLRNet(opts, device, vocab_size, vocabulary,
    #                       dilated_channels=512, num_blocks=5, dilations=[1, 2, 4], dropout=0.0)
    model = MainStream(vocab_size)
    criterion = CtcLoss(opts, blank_id, device, reduction="none")
    trainer = Trainer(opts, model, criterion, vocabulary, vocab_size, blank_id)

    # ctcdecode
    ctc_decoder_vocab = [chr(x) for x in range(20000, 20000 + vocab_size)]
    ctc_decoder = ctcdecode.CTCBeamDecoder(ctc_decoder_vocab,
                                           beam_width=opts.beam_width,
                                           blank_id=blank_id,
                                           num_processes=10)

    if os.path.exists(opts.check_point):
        logging.info("Loading checkpoint file from {}".format(
            opts.check_point))
        epoch, num_updates, loss = trainer.load_checkpoint(opts.check_point)
    else:
        logging.info("No checkpoint file in found in {}".format(
            opts.check_point))
        epoch, num_updates, loss = 0, 0, 0.0

    test_iter = trainer.get_batch_iterator(test_datasets,
                                           batch_size=opts.batch_size,
                                           shuffle=False)

    with torch.no_grad():
        model.eval()
        criterion.eval()
        prob_results = {}
        for i, samples in enumerate(test_iter):
            if i > 500:
                break
            samples = trainer._prepare_sample(samples)
            video = samples["data"]
            len_video = samples["len_data"]
            label = samples["label"]
            len_label = samples["len_label"]
            video_id = samples['id']
            dec_label = samples["decoder_label"]
            len_dec_label = samples["len_decoder_label"]

            # print("video: ", video.shape)
            logits, _ = model(video, len_video)
            len_video /= 4
            # print("logits: ", logits.shape)
            # print(len_video)

            params = logits[0, :len_video[0], :].transpose(
                1, 0).detach().cpu().numpy()  # [vocab_size, T]
            seq = dec_label[0, :len_dec_label[0]].cpu().numpy()
            alignment = get_alignment(params,
                                      seq,
                                      blank=blank_id,
                                      is_prob=False)  # [length]
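            # get_alignment yields one label index per output frame,
            # presumably a forced alignment of the reference against the logits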
            # print("video_id:", video_id[0])
            # print("gt label:", seq)
            # print("alignment:", alignment)

            probs = logits.softmax(-1)[0]  # [length, vocab_size]
            align_probs = []
            for i in range(alignment.shape[0]):
                align_probs.append(
                    probs[i, alignment[i]].detach().cpu().numpy().tolist())
            # print(align_probs)
            # exit()
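            # zero the probabilities of frames aligned to <BLANK> and report
            # the fraction of blank frames for this video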
            count = 0
            total_cnt = 0
            for i in range(len(align_probs)):
                total_cnt += 1
                if alignment[i] == blank_id:
                    align_probs[i] = 0
                    count += 1
            print(
                "video_id: {}, and blank count / total count: {}/{} = {:.4f}".
                format(video_id[0], count, total_cnt, count / total_cnt))
            prob_results[video_id[0]] = (align_probs, alignment)
            # print(align_probs)
    return prob_results
def main_4():
    opts = parse_args()
    init_logging(os.path.join(opts.log_dir, '{:s}_log.txt'.format(opts.task)))

    if torch.cuda.is_available():
        torch.cuda.set_device(opts.gpu)
        logging.info("Using GPU!")
        device = "cuda"
    else:
        logging.info("Using CPU!")
        device = "cpu"

    logging.info(opts)

    test_datasets = PhoenixVideo(opts.vocab_file,
                                 opts.corpus_dir,
                                 opts.video_path,
                                 phase="train",
                                 DEBUG=opts.DEBUG)
    vocab_size = test_datasets.vocab.num_words
    blank_id = test_datasets.vocab.word2index['<BLANK>']
    vocabulary = Vocabulary(opts.vocab_file)
    #     model = DilatedSLRNet(opts, device, vocab_size, vocabulary,
    #                           dilated_channels=512, num_blocks=5, dilations=[1, 2, 4], dropout=0.0)
    model = MainStream(vocab_size)
    criterion = CtcLoss(opts, blank_id, device, reduction="none")
    trainer = Trainer(opts, model, criterion, vocabulary, vocab_size, blank_id)

    # ctcdecode
    ctc_decoder_vocab = [chr(x) for x in range(20000, 20000 + vocab_size)]
    ctc_decoder = ctcdecode.CTCBeamDecoder(ctc_decoder_vocab,
                                           beam_width=opts.beam_width,
                                           blank_id=blank_id,
                                           num_processes=10)

    if os.path.exists(opts.check_point):
        logging.info("Loading checkpoint file from {}".format(
            opts.check_point))
        epoch, num_updates, loss = trainer.load_checkpoint(opts.check_point)
    else:
        logging.info("No checkpoint file in found in {}".format(
            opts.check_point))
        epoch, num_updates, loss = 0, 0, 0.0

    test_iter = trainer.get_batch_iterator(test_datasets,
                                           batch_size=opts.batch_size,
                                           shuffle=False)

    video_sim = {}

    with torch.no_grad():
        model.eval()
        criterion.eval()
        for i, samples in tqdm(enumerate(test_iter)):
            if i > 50:
                break
            samples = trainer._prepare_sample(samples)
            video = samples["data"]
            len_video = samples["len_data"]
            label = samples["label"]
            len_label = samples["len_label"]
            video_id = samples['id']

            logits, _, scores1, scores2 = model(video, len_video)
            print(scores1)
            ids = scores1.topk(k=16, dim=-1)[1].sort(-1)[0]  # [bs, t, t]
            bs, t, _ = scores1.size()
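            # keep only the top-16 scores per query position; all other
            # entries are pushed to a near-zero value before the softmax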
            for i in range(bs):
                for j in range(t):
                    select_id = ids[i, j, :].cpu().numpy().tolist()
                    for k in range(t):
                        if k not in select_id:
                            scores1[i, j, k] = 1e-9
            print("scores1: ", scores1)
            scores1 = scores1.softmax(-1)

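            # discard attention weights of 0.02 or less via a binary mask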
            mask = scores1 > 0.02
            print(scores1, mask)
            scores1 *= mask.float()
            # sim_matrix = scores1.softmax(-1)
            # print(scores1[0, 0, :20])
            # exit()
            for i in range(len(video_id)):
                video_sim[video_id[i]] = scores1[i].cpu().numpy()
    # print(video_sim)
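    # save the masked per-video similarity matrices for later inspection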
    with open("Data/output/sim_matrix.pkl", "wb") as f:
        pickle.dump(video_sim, f)