# Example 1
import logging
from collections import defaultdict

import torch
import torch.nn as nn
from tqdm import tqdm

logger = logging.getLogger(__name__)

# prepare_batch_inputs, cal_performance, sort_res, and the RCDataset class are
# project-level helpers; their import paths depend on the repo layout and are
# not shown here.
def eval_epoch(model, validation_data, device, opt):
    """The same setting as training, where ground-truth word x_{t-1}
    is used to predict next word x_{t}, not realistic for real inference"""
    model.eval()

    total_loss = 0
    n_word_total = 0
    n_word_correct = 0

    with torch.no_grad():
        for batch in tqdm(validation_data,
                          mininterval=2,
                          desc="  Validation =>"):
            if opt.recurrent:
                # prepare data
                batched_data = [
                    prepare_batch_inputs(step_data,
                                         device=device,
                                         non_blocking=opt.pin_memory)
                    for step_data in batch[0]
                ]
                input_ids_list = [e["input_ids"] for e in batched_data]
                video_features_list = [
                    e["video_feature"] for e in batched_data
                ]
                input_masks_list = [e["input_mask"] for e in batched_data]
                token_type_ids_list = [
                    e["token_type_ids"] for e in batched_data
                ]
                input_labels_list = [e["input_labels"] for e in batched_data]

                loss, pred_scores_list = model(input_ids_list,
                                               video_features_list,
                                               input_masks_list,
                                               token_type_ids_list,
                                               input_labels_list)
            else:  # single sentence
                if opt.untied or opt.mtrans:
                    # prepare data
                    batched_data = prepare_batch_inputs(
                        batch[0], device=device, non_blocking=opt.pin_memory)
                    video_feature = batched_data["video_feature"]
                    video_mask = batched_data["video_mask"]
                    text_ids = batched_data["text_ids"]
                    text_mask = batched_data["text_mask"]
                    text_labels = batched_data["text_labels"]

                    loss, pred_scores = model(video_feature, video_mask,
                                              text_ids, text_mask, text_labels)
                    pred_scores_list = [pred_scores]
                    input_labels_list = [text_labels]
                else:
                    # prepare data
                    batched_data = prepare_batch_inputs(
                        batch[0], device=device, non_blocking=opt.pin_memory)
                    input_ids = batched_data["input_ids"]
                    video_features = batched_data["video_feature"]
                    input_masks = batched_data["input_mask"]
                    token_type_ids = batched_data["token_type_ids"]
                    input_labels = batched_data["input_labels"]

                    loss, pred_scores = model(input_ids, video_features,
                                              input_masks, token_type_ids,
                                              input_labels)
                    pred_scores_list = [pred_scores]
                    input_labels_list = [input_labels]

            # keep logs
            n_correct = 0
            n_word = 0
            for pred, gold in zip(pred_scores_list, input_labels_list):
                n_correct += cal_performance(pred, gold)
                valid_label_mask = gold.ne(RCDataset.IGNORE)
                n_word += valid_label_mask.sum().item()

            n_word_total += n_word
            n_word_correct += n_correct
            total_loss += loss.item()

            if opt.debug:
                break

    loss_per_word = total_loss / n_word_total
    accuracy = n_word_correct / n_word_total
    return loss_per_word, accuracy
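# prepare_batch_inputs and cal_performance are project helpers that are not
# shown in this excerpt. Below is a minimal sketch of what they plausibly do,
# assuming batches are dicts of tensors and predictions are vocabulary logits;
# the bodies are illustrative guesses, not the project's actual code.

def prepare_batch_inputs_sketch(batch, device, non_blocking=False):
    """Move every tensor in a batch dict onto the target device."""
    return {k: v.to(device, non_blocking=non_blocking)
            if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()}


def cal_performance_sketch(pred_scores, gold_labels):
    """Count correctly predicted words, skipping ignored label positions.

    pred_scores: (N, L, vocab_size) logits; gold_labels: (N, L) token ids,
    where RCDataset.IGNORE marks positions excluded from scoring.
    """
    pred_ids = pred_scores.max(dim=-1)[1]  # greedy token choice per position
    valid_mask = gold_labels.ne(RCDataset.IGNORE)
    return pred_ids.eq(gold_labels).masked_select(valid_mask).sum().item()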
# Example 2
def train_epoch(model, training_data, optimizer, ema, device, opt, writer,
                epoch):
    model.train()

    total_loss = 0
    n_word_total = 0
    n_word_correct = 0

    # anomaly detection surfaces NaN/inf in backward at a significant speed cost
    torch.autograd.set_detect_anomaly(True)
    for batch_idx, batch in tqdm(enumerate(training_data),
                                 mininterval=2,
                                 desc="  Training =>",
                                 total=len(training_data)):
        niter = epoch * len(training_data) + batch_idx
        writer.add_scalar("Train/LearningRate",
                          float(optimizer.param_groups[0]["lr"]), niter)
        if opt.recurrent:
            # prepare data
            batched_data = [
                prepare_batch_inputs(step_data,
                                     device=device,
                                     non_blocking=opt.pin_memory)
                for step_data in batch[0]
            ]
            input_ids_list = [e["input_ids"] for e in batched_data]
            video_features_list = [e["video_feature"] for e in batched_data]
            input_masks_list = [e["input_mask"] for e in batched_data]
            token_type_ids_list = [e["token_type_ids"] for e in batched_data]
            input_labels_list = [e["input_labels"] for e in batched_data]

            if opt.debug:

                def print_info(batched_data, step_idx, batch_idx):
                    cur_data = batched_data[step_idx]
                    logger.info("input_ids \n{}".format(
                        cur_data["input_ids"][batch_idx]))
                    logger.info("input_mask \n{}".format(
                        cur_data["input_mask"][batch_idx]))
                    logger.info("input_labels \n{}".format(
                        cur_data["input_labels"][batch_idx]))
                    logger.info("token_type_ids \n{}".format(
                        cur_data["token_type_ids"][batch_idx]))

                print_info(batched_data, 0, 0)

            # forward & backward
            optimizer.zero_grad()
            loss, pred_scores_list = model(input_ids_list, video_features_list,
                                           input_masks_list,
                                           token_type_ids_list,
                                           input_labels_list)
        else:  # single sentence
            if opt.untied or opt.mtrans:
                # prepare data
                batched_data = prepare_batch_inputs(
                    batch[0], device=device, non_blocking=opt.pin_memory)
                video_feature = batched_data["video_feature"]
                video_mask = batched_data["video_mask"]
                text_ids = batched_data["text_ids"]
                text_mask = batched_data["text_mask"]
                text_labels = batched_data["text_labels"]

                if opt.debug:

                    def print_info(cur_data, batch_idx):
                        logger.info("text_ids \n{}".format(
                            cur_data["text_ids"][batch_idx]))
                        logger.info("text_mask \n{}".format(
                            cur_data["text_mask"][batch_idx]))
                        logger.info("text_labels \n{}".format(
                            cur_data["text_labels"][batch_idx]))

                    print_info(batched_data, 0)

                # forward & backward
                optimizer.zero_grad()
                loss, pred_scores = model(video_feature, video_mask, text_ids,
                                          text_mask, text_labels)

                # make it consistent with other configs
                pred_scores_list = [pred_scores]
                input_labels_list = [text_labels]
            else:
                # prepare data
                batched_data = prepare_batch_inputs(
                    batch[0], device=device, non_blocking=opt.pin_memory)
                input_ids = batched_data["input_ids"]
                video_features = batched_data["video_feature"]
                input_masks = batched_data["input_mask"]
                token_type_ids = batched_data["token_type_ids"]
                input_labels = batched_data["input_labels"]

                if opt.debug:

                    def print_info(cur_data, batch_idx):
                        logger.info("input_ids \n{}".format(
                            cur_data["input_ids"][batch_idx]))
                        logger.info("input_mask \n{}".format(
                            cur_data["input_mask"][batch_idx]))
                        logger.info("input_labels \n{}".format(
                            cur_data["input_labels"][batch_idx]))
                        logger.info("token_type_ids \n{}".format(
                            cur_data["token_type_ids"][batch_idx]))

                    print_info(batched_data, 0)

                # forward & backward
                optimizer.zero_grad()
                loss, pred_scores = model(input_ids, video_features,
                                          input_masks, token_type_ids,
                                          input_labels)

                # make it consistent with other configs
                pred_scores_list = [pred_scores]
                input_labels_list = [input_labels]

        loss.backward()
        if opt.grad_clip != -1:  # -1 disables gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()

        # update the EMA shadow weights from the current model parameters
        if ema is not None:
            ema(model, niter)

        # keep logs
        n_correct = 0
        n_word = 0
        for pred, gold in zip(pred_scores_list, input_labels_list):
            n_correct += cal_performance(pred, gold)
            valid_label_mask = gold.ne(RCDataset.IGNORE)
            n_word += valid_label_mask.sum().item()

        n_word_total += n_word
        n_word_correct += n_correct
        total_loss += loss.item()

        if opt.debug:
            break
    torch.autograd.set_detect_anomaly(False)

    loss_per_word = total_loss / n_word_total
    accuracy = n_word_correct / n_word_total
    return loss_per_word, accuracy
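# train_epoch calls ema(model, niter) to maintain an exponential moving
# average of the weights. Below is a minimal EMA callable matching that call
# signature; the decay schedule is an assumption, not the project's exact code.

class EMASketch(object):
    """Keep shadow copies of parameters as an exponential moving average."""

    def __init__(self, decay=0.9999):
        self.decay = decay
        self.shadow = {}

    def __call__(self, model, step):
        # ramp the effective decay up during early steps for stability
        decay = min(self.decay, (1 + step) / (10 + step))
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name not in self.shadow:
                    self.shadow[name] = param.data.clone()
                else:
                    self.shadow[name] = ((1.0 - decay) * param.data
                                         + decay * self.shadow[name]).clone()


# A bare-bones driver for the two epoch functions above; `n_epoch` and
# `res_dir` are assumed option names, everything else is taken from this file.

def run_training_sketch(model, train_loader, val_loader, optimizer, ema,
                        device, opt):
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(opt.res_dir)
    for epoch in range(opt.n_epoch):
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, ema, device, opt, writer, epoch)
        val_loss, val_acc = eval_epoch(model, val_loader, device, opt)
        writer.add_scalar("Train/Acc", train_acc, epoch)
        writer.add_scalar("Valid/Acc", val_acc, epoch)
    writer.close()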
def run_translate(eval_data_loader, translator, opt):
    # submission template
    batch_res = {"version": "VERSION 1.0",
                 "results": defaultdict(list),
                 "external_data": {"used": "true", "details": "ay"}}
    for raw_batch in tqdm(eval_data_loader, mininterval=2, desc="  - (Translate)"):
        if opt.recurrent:
            # prepare data
            step_sizes = raw_batch[1]  # list(int), len == bsz
            meta = raw_batch[2]  # list(dict), len == bsz
            batch = [prepare_batch_inputs(step_data, device=translator.device)
                     for step_data in raw_batch[0]]
            model_inputs = [
                [e["input_ids"] for e in batch],
                [e["video_feature"] for e in batch],
                [e["input_mask"] for e in batch],
                [e["token_type_ids"] for e in batch]
            ]

            dec_seq_list = translator.translate_batch(
                model_inputs, use_beam=opt.use_beam, recurrent=True, untied=False, xl=opt.xl)

            # example_idx indicates which example is in the batch
            for example_idx, (step_size, cur_meta) in enumerate(zip(step_sizes, meta)):
                # step_idx or we can also call it sen_idx
                for step_idx, step_batch in enumerate(dec_seq_list[:step_size]):
                    batch_res["results"][cur_meta["name"]].append({
                        "sentence": eval_data_loader.dataset.convert_ids_to_sentence(
                            step_batch[example_idx].cpu().tolist()).encode("ascii", "ignore"),
                        "timestamp": cur_meta["timestamp"][step_idx],
                        "gt_sentence": cur_meta["gt_sentence"][step_idx]
                    })
        else:  # single sentence
            meta = raw_batch[2]  # list(dict), len == bsz
            batched_data = prepare_batch_inputs(raw_batch[0], device=translator.device)
            if opt.untied or opt.mtrans:
                model_inputs = [
                    batched_data["video_feature"],
                    batched_data["video_mask"],
                    batched_data["text_ids"],
                    batched_data["text_mask"],
                    batched_data["text_labels"]
                ]
            else:
                model_inputs = [
                    batched_data["input_ids"],
                    batched_data["video_feature"],
                    batched_data["input_mask"],
                    batched_data["token_type_ids"]
                ]

            dec_seq = translator.translate_batch(
                model_inputs, use_beam=opt.use_beam, recurrent=False, untied=opt.untied or opt.mtrans)

            # example_idx indicates which example is in the batch
            for example_idx, (cur_gen_sen, cur_meta) in enumerate(zip(dec_seq, meta)):
                cur_data = {
                    "sentence": eval_data_loader.dataset.convert_ids_to_sentence(
                        cur_gen_sen.cpu().tolist()).encode("ascii", "ignore"),
                    "timestamp": cur_meta["timestamp"],
                    "gt_sentence": cur_meta["gt_sentence"]
                }
                batch_res["results"][cur_meta["name"]].append(cur_data)

        if opt.debug:
            break

    batch_res["results"] = sort_res(batch_res["results"])
    return batch_res