コード例 #1
0
ファイル: main.py プロジェクト: yhg0112/TVQAplus
def validate(opt,
             dset,
             model,
             criterion,
             mode="valid",
             use_hard_negatives=False):
    """Evaluate ``model`` on ``dset`` and report accuracy and loss.

    Side effects: switches ``dset`` to ``mode``, disables autograd globally
    via ``torch.set_grad_enabled(False)`` and puts ``model`` in eval mode.
    None of these are restored here; the caller has to undo them.

    Args:
        opt: Options object. Reads ``test_bsz``, ``num_workers``, ``device``,
            ``att_weight``, ``ts_weight``, ``debug`` and the ``max_*_l``
            length limits.
        dset: Dataset exposing ``set_mode``.
        model: Callable returning a 5-tuple whose 1st, 2nd and 4th items are
            the answer scores, the attention loss and the temporal loss.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        mode: Dataset split name passed to ``dset.set_mode``.
        use_hard_negatives: Flag copied onto every batch of model inputs.

    Returns:
        tuple: ``(valid_acc, valid_loss, qid_corrects)``. ``valid_acc`` is the
        fraction of correct predictions. ``valid_loss`` is the sum of the
        per-batch loss values divided by the number of evaluated examples
        (a per-example mean only if the criterion sums over the batch --
        TODO confirm). ``qid_corrects`` is a list of ``"<qid>\\t<correct>"``
        strings.

    Raises:
        ZeroDivisionError: If no example was evaluated.
    """
    dset.set_mode(mode)
    torch.set_grad_enabled(False)
    model.eval()
    loader = DataLoader(
        dset,
        batch_size=opt.test_bsz,
        shuffle=False,
        collate_fn=pad_collate,
        num_workers=opt.num_workers,
        pin_memory=True,
    )

    seen_qids, batch_losses, hits = [], [], []
    length_limits = {
        "max_sub_l": opt.max_sub_l,
        "max_vid_l": opt.max_vid_l,
        "max_vcpt_l": opt.max_vcpt_l,
        "max_qa_l": opt.max_qa_l,
    }
    for step, batch in enumerate(loader):
        model_inputs, targets, qids = prepare_inputs(
            batch, max_len_dict=length_limits, device=opt.device)
        model_inputs.use_hard_negatives = use_hard_negatives
        outputs, att_loss, _, temporal_loss, _ = model(model_inputs)

        # Total loss = classification + weighted attention + weighted temporal.
        cls_loss = criterion(outputs, targets)
        loss = cls_loss + opt.att_weight * att_loss + opt.ts_weight * temporal_loss

        # Record question ids, the batch loss and per-example correctness.
        seen_qids.extend(int(q) for q in qids)
        batch_losses.append(loss.data.item())
        _, pred_ids = outputs.data.max(1)
        hits.extend(pred_ids.eq(targets.data).tolist())

        # Debug runs stop after the batch with index 20.
        if opt.debug and step == 20:
            break

    num_examples = float(len(hits))
    valid_acc = sum(hits) / num_examples
    valid_loss = sum(batch_losses) / num_examples
    qid_corrects = ["%d\t%d" % pair for pair in zip(seen_qids, hits)]
    return valid_acc, valid_loss, qid_corrects
コード例 #2
0
def inference(opt, dset, model):
    """Run ``model`` over ``dset`` in ``opt.mode`` and print answer accuracy.

    NOTE(review): this function neither calls ``model.eval()`` nor disables
    autograd; presumably the caller does both -- TODO confirm.

    Args:
        opt: Options object. Reads ``mode``, ``test_bsz``, ``device`` and the
            ``max_*_l`` length limits (including ``max_dc_l``).
        dset: Dataset exposing ``set_mode``.
        model: Callable returning a pair whose first item holds the answer
            scores (one row per example).

    Returns:
        dict: ``{"ts_answer": {}, "raw_bbox": []}``. Nothing in this function
        fills it in, so it is always returned as this empty skeleton.
    """
    dset.set_mode(opt.mode)
    data_loader = DataLoader(dset,
                             batch_size=opt.test_bsz,
                             shuffle=False,
                             collate_fn=pad_collate)

    train_corrects = []
    predictions = dict(ts_answer={}, raw_bbox=[])
    max_len_dict = dict(max_sub_l=opt.max_sub_l,
                        max_vid_l=opt.max_vid_l,
                        max_vcpt_l=opt.max_vcpt_l,
                        max_qa_l=opt.max_qa_l,
                        max_dc_l=opt.max_dc_l)
    # Wrap the loader (not the enumerate object) so tqdm can show a total.
    for valid_idx, batch in enumerate(tqdm(data_loader)):
        model_inputs, targets, _ = prepare_inputs(batch,
                                                  max_len_dict=max_len_dict,
                                                  device=opt.device)

        inference_outputs, _ = model(model_inputs)

        # Predicted answer = index of the highest score in each row.
        pred_ids = inference_outputs.data.max(1)[1]
        train_corrects += pred_ids.eq(targets.data).tolist()

    # Summarise once after the loop; an empty loader leaves nothing to report
    # (and would otherwise leave ``valid_idx`` unbound).
    if train_corrects:
        train_acc = sum(train_corrects) / float(len(train_corrects))
        print("Idx {:02d} [Train] acc {:.4f}".format(valid_idx, train_acc))

    # NOTE: region ("raw_bbox") and temporal ("ts_answer") predictions are
    # not extracted from the model outputs here.
    return predictions
コード例 #3
0
def validate(opt, dset, model, criterion, mode="valid"):
    """Evaluate ``model`` on ``dset`` and collect predictions for submission.

    Side effects: switches ``dset`` to ``mode``, disables autograd globally
    via ``torch.set_grad_enabled(False)`` and puts ``model`` in eval mode.
    None of these are restored here; the caller has to undo them.

    Args:
        opt: Options object. Reads ``test_bsz``, ``num_workers``, ``device``
            and the ``max_*_l`` length limits (including ``max_dc_l``).
        dset: Dataset exposing ``set_mode``.
        model: Callable returning a pair whose first item holds the answer
            scores (one row per example).
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        mode: Dataset split name passed to ``dset.set_mode``.

    Returns:
        tuple: ``(valid_acc, valid_loss, qid_corrects, acc_1st, acc_2nd,
        submit_json_val)``. ``valid_loss`` is the sum of per-batch loss values
        divided by the number of evaluated examples. ``qid_corrects`` is a
        list of ``"<qid>\\t<correct>"`` strings. ``acc_1st`` and ``acc_2nd``
        are always ``0.``. ``submit_json_val`` maps ``str(qid)`` to the
        predicted answer index.

    Raises:
        ZeroDivisionError: If no example was evaluated.
    """
    dset.set_mode(mode)
    torch.set_grad_enabled(False)
    model.eval()
    loader = DataLoader(
        dset,
        batch_size=opt.test_bsz,
        shuffle=False,
        collate_fn=pad_collate,
        num_workers=opt.num_workers,
        pin_memory=True,
    )

    submit_json_val = {}
    seen_qids, batch_losses, hits = [], [], []
    length_limits = {
        "max_sub_l": opt.max_sub_l,
        "max_vid_l": opt.max_vid_l,
        "max_vcpt_l": opt.max_vcpt_l,
        "max_qa_l": opt.max_qa_l,
        "max_dc_l": opt.max_dc_l,
    }
    for batch in loader:
        model_inputs, targets, qids = prepare_inputs(
            batch, max_len_dict=length_limits, device=opt.device)
        outputs, _ = model(model_inputs)
        loss = criterion(outputs, targets)

        seen_qids.extend(int(q) for q in qids)
        batch_losses.append(loss.data.item())

        # Predicted answer = index of the highest score in each row.
        _, pred_ids = outputs.data.max(1)
        answers = pred_ids.tolist()

        # Record the prediction for every question id carried by the inputs.
        for row, q_id in enumerate(model_inputs['qid']):
            submit_json_val[str(q_id)] = int(answers[row])

        hits.extend(pred_ids.eq(targets.data).tolist())

    # Placeholder metrics: always zero, kept so the return shape is stable.
    acc_1st = acc_2nd = 0.
    num_examples = float(len(hits))
    valid_acc = sum(hits) / num_examples
    valid_loss = sum(batch_losses) / num_examples
    qid_corrects = ["%d\t%d" % pair for pair in zip(seen_qids, hits)]
    return valid_acc, valid_loss, qid_corrects, acc_1st, acc_2nd, submit_json_val
コード例 #4
0
ファイル: inference.py プロジェクト: yhg0112/TVQAplus
def inference(opt, dset, model):
    """Collect answer, temporal-span and region predictions for ``dset``.

    Args:
        opt: Options object. Reads ``mode``, ``test_bsz``, ``device``,
            ``debug`` and the ``max_*_l`` length limits.
        dset: Dataset exposing ``set_mode`` and ``eval_object_word_ids``.
        model: Callable returning a mapping with the keys ``"answer"``,
            ``"att_predictions"`` and ``"t_scores"``.

    Returns:
        dict: ``{"ts_answer": {...}, "raw_bbox": [...]}``. ``ts_answer`` maps
        ``str(qid)`` to ``[[st, ed], predicted_answer_index]``; ``raw_bbox``
        accumulates the model's ``"att_predictions"`` across batches.
    """
    dset.set_mode(opt.mode)
    data_loader = DataLoader(
        dset,
        batch_size=opt.test_bsz,
        shuffle=False,
        collate_fn=pad_collate,
    )

    predictions = {"ts_answer": {}, "raw_bbox": []}
    length_limits = {
        "max_sub_l": opt.max_sub_l,
        "max_vid_l": opt.max_vid_l,
        "max_vcpt_l": opt.max_vcpt_l,
        "max_qa_l": opt.max_qa_l,
    }
    for _step, batch in tqdm(enumerate(data_loader)):
        model_inputs, _targets, qids = prepare_inputs(
            batch, max_len_dict=length_limits, device=opt.device)
        model_inputs.use_hard_negatives = False
        # Tells the model which words need boxes.
        model_inputs.eval_object_word_ids = dset.eval_object_word_ids

        inference_outputs = model(model_inputs)

        # Predicted answer = index of the highest score in each row.
        answer_ids = inference_outputs["answer"].data.max(1)[1]

        # Predicted regions (skipped when the model returned none).
        att_predictions = inference_outputs["att_predictions"]
        if att_predictions:
            predictions["raw_bbox"].extend(att_predictions)

        # The last axis of the scores holds start scores at index 0 and end
        # scores at index 1 (going by how they are consumed below).
        t_scores = inference_outputs["t_scores"]
        start_scores = t_scores[:, :, :, 0]
        end_scores = t_scores[:, :, :, 1]
        per_question = zip(qids,
                           answer_ids.tolist(),
                           start_scores,
                           end_scores,
                           model_inputs["image_indices"])
        for qid, pred_a_idx, st_scores, ed_scores, img_indices in per_question:
            # NOTE(review): "/" is true division on Python 3, so the offset
            # may be fractional -- confirm that this is intended.
            offset = (img_indices[0] % 6) / 3
            st_candidates = st_scores[pred_a_idx].cpu().numpy().tolist()
            ed_candidates = ed_scores[pred_a_idx].cpu().numpy().tolist()
            (st, ed), _ = find_max_pair(st_candidates, ed_candidates)
            # The [st, ed] span belongs to the predicted answer.
            span = [st * 2 + offset, (ed + 1) * 2 + offset]
            predictions["ts_answer"][str(qid)] = [span, int(pred_a_idx)]

        # Debug runs process a single batch only.
        if opt.debug:
            break
    return predictions
コード例 #5
0
ファイル: main.py プロジェクト: yhg0112/TVQAplus
def train(opt,
          dset,
          model,
          criterion,
          optimizer,
          epoch,
          previous_best_acc,
          use_hard_negatives=False):
    """Train for one pass over ``dset``, validating every ``opt.log_freq`` batches.

    At each logging point the running training statistics are averaged over
    the examples seen since the previous logging point, ``validate`` is run on
    the "valid" split, scalars are written to ``opt.writer`` and, if the
    validation accuracy beats ``previous_best_acc``, the model ``state_dict``
    is saved to ``<opt.results_dir>/best_valid.pth``. The per-logging-point
    validation accuracies are appended to ``<opt.results_dir>/valid_acc.log``
    when the pass ends.

    Args:
        opt: Options object. Reads ``bsz``, ``num_workers``, ``device``,
            ``att_weight``, ``ts_weight``, ``clip``, ``log_freq``, ``debug``,
            ``writer``, ``results_dir`` and the ``max_*_l`` length limits.
        dset: Dataset exposing ``set_mode``.
        model: Callable returning a 5-tuple; its first item is an
            ``(outputs, targets)`` pair, its 2nd and 4th items are the
            attention and temporal losses.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Index of the current epoch (used for the global step and logs).
        previous_best_acc: Best validation accuracy seen so far.
        use_hard_negatives: Flag copied onto every batch of model inputs.

    Returns:
        The best validation accuracy after this pass.

    Notes:
        A ``RuntimeError`` containing "out of memory" skips the batch; any
        other ``RuntimeError`` is printed and terminates the process with
        ``sys.exit(1)``.
    """
    dset.set_mode("train")
    model.train()
    train_loader = DataLoader(dset,
                              batch_size=opt.bsz,
                              shuffle=True,
                              collate_fn=pad_collate,
                              num_workers=opt.num_workers,
                              pin_memory=True)

    # Running statistics for the current logging window.
    train_loss = []
    train_loss_att = []
    train_loss_ts = []
    train_loss_cls = []
    valid_acc_log = ["batch_idx\tacc"]
    train_corrects = []
    torch.set_grad_enabled(True)
    max_len_dict = dict(
        max_sub_l=opt.max_sub_l,
        max_vid_l=opt.max_vid_l,
        max_vcpt_l=opt.max_vcpt_l,
        max_qa_l=opt.max_qa_l,
    )

    # init meters
    dataloading_time = AverageMeter()
    prepare_inputs_time = AverageMeter()
    model_forward_time = AverageMeter()
    model_backward_time = AverageMeter()

    timer_dataloading = time.time()
    # Wrap the loader (not the enumerate object) so tqdm can show a total.
    for batch_idx, batch in enumerate(tqdm(train_loader)):
        dataloading_time.update(time.time() - timer_dataloading)
        timer_start = time.time()
        model_inputs, _, qids = prepare_inputs(batch,
                                               max_len_dict=max_len_dict,
                                               device=opt.device)
        prepare_inputs_time.update(time.time() - timer_start)
        model_inputs.use_hard_negatives = use_hard_negatives
        try:
            timer_start = time.time()
            outputs, att_loss, _, temporal_loss, _ = model(model_inputs)
            # The model hands back the targets that match its outputs.
            outputs, targets = outputs
            att_loss = opt.att_weight * att_loss
            temporal_loss = opt.ts_weight * temporal_loss
            cls_loss = criterion(outputs, targets)
            # keep the cls_loss at the same magnitude as only classifying batch_size objects
            cls_loss = cls_loss * (1.0 * len(qids) / len(targets))
            loss = cls_loss + att_loss + temporal_loss
            model_forward_time.update(time.time() - timer_start)
            timer_start = time.time()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
            optimizer.step()
            model_backward_time.update(time.time() - timer_start)
            # scheduler.step()
            train_loss.append(loss.data.item())
            train_loss_att.append(float(att_loss))
            train_loss_ts.append(float(temporal_loss))
            train_loss_cls.append(cls_loss.item())
            pred_ids = outputs.data.max(1)[1]
            train_corrects += pred_ids.eq(targets.data).tolist()
        except RuntimeError as e:
            if "out of memory" in str(e):
                print("WARNING: ran out of memory, skipping batch")
            else:
                print("RuntimeError {}".format(e))
                sys.exit(1)
        if batch_idx % opt.log_freq == 0:
            niter = epoch * len(train_loader) + batch_idx
            num_seen = len(train_corrects)
            # Report zeros (and skip the training curves) at the first batch,
            # where the window holds at most one batch, and when every batch
            # of the window was skipped after running out of memory (nothing
            # to average; dividing would raise ZeroDivisionError).
            if batch_idx == 0 or num_seen == 0:
                train_acc = 0
                avg_loss = 0
                avg_loss_att = 0
                avg_loss_ts = 0
                avg_loss_cls = 0
            else:
                denom = float(num_seen)
                train_acc = sum(train_corrects) / denom
                avg_loss = sum(train_loss) / denom
                avg_loss_att = sum(train_loss_att) / denom
                avg_loss_cls = sum(train_loss_cls) / denom
                avg_loss_ts = sum(train_loss_ts) / denom
                opt.writer.add_scalar("Train/Acc", train_acc, niter)
                opt.writer.add_scalar("Train/Loss", avg_loss, niter)
                opt.writer.add_scalar("Train/Loss_att", avg_loss_att, niter)
                opt.writer.add_scalar("Train/Loss_cls", avg_loss_cls, niter)
                opt.writer.add_scalar("Train/Loss_ts", avg_loss_ts, niter)
            # Test
            valid_acc, valid_loss, _ = \
                validate(opt, dset, model, criterion, mode="valid", use_hard_negatives=use_hard_negatives)
            opt.writer.add_scalar("Valid/Acc", valid_acc, niter)
            opt.writer.add_scalar("Valid/Loss", valid_loss, niter)

            valid_log_str = "%02d\t%.4f" % (batch_idx, valid_acc)
            valid_acc_log.append(valid_log_str)

            # remember the best acc.
            if valid_acc > previous_best_acc:
                previous_best_acc = valid_acc
                torch.save(model.state_dict(),
                           os.path.join(opt.results_dir, "best_valid.pth"))

            print(
                "Epoch {:02d} [Train] acc {:.4f} loss {:.4f} loss_att {:.4f} loss_ts {:.4f} loss_cls {:.4f}"
                .format(epoch, train_acc, avg_loss, avg_loss_att,
                        avg_loss_ts, avg_loss_cls))

            print("Epoch {:02d} [Val] acc {:.4f} loss {:.4f}".format(
                epoch, valid_acc, valid_loss))

            # reset to train: validate() disabled autograd, switched the model
            # to eval mode and the dataset to the validation split.
            torch.set_grad_enabled(True)
            model.train()
            dset.set_mode("train")
            train_corrects = []
            train_loss = []
            train_loss_att = []
            train_loss_ts = []
            train_loss_cls = []

        timer_dataloading = time.time()
        # Debug runs print the timing meters and stop after batch index 5.
        if opt.debug and batch_idx == 5:
            print(
                "dataloading_time: max {dataloading_time.max} "
                "min {dataloading_time.min} avg {dataloading_time.avg}\n"
                "prepare_inputs_time: max {prepare_inputs_time.max} "
                "min {prepare_inputs_time.min} avg {prepare_inputs_time.avg}\n"
                "model_forward_time: max {model_forward_time.max} "
                "min {model_forward_time.min} avg {model_forward_time.avg}\n"
                "model_backward_time: max {model_backward_time.max} "
                "min {model_backward_time.min} avg {model_backward_time.avg}\n"
                "".format(dataloading_time=dataloading_time,
                          prepare_inputs_time=prepare_inputs_time,
                          model_forward_time=model_forward_time,
                          model_backward_time=model_backward_time))
            break

    # additional log
    with open(os.path.join(opt.results_dir, "valid_acc.log"), "a") as f:
        f.write("\n".join(valid_acc_log) + "\n")

    return previous_best_acc
コード例 #6
0
def train(opt, dset, model, criterion, optimizer, epoch, previous_best_acc):
    """Train for one pass over ``dset``, validating every ``opt.log_freq`` batches.

    The optimised loss is the sum of the classification loss, the balanced
    binary cross-entropy temporal loss and the IOFSM loss. At each logging
    point the running statistics are averaged over the examples seen since the
    previous logging point and ``validate`` is run on the "valid" split. When
    the validation accuracy beats ``previous_best_acc`` the predictions are
    dumped to ``best_github.json`` (current working directory) and, from epoch
    10 on, the model ``state_dict`` is saved under
    ``./results/best_valid_to_keep``. The per-logging-point validation
    accuracies are appended to ``<opt.results_dir>/valid_acc.log`` when the
    pass ends.

    Args:
        opt: Options object. Reads ``bsz``, ``num_workers``, ``device``,
            ``clip``, ``log_freq``, ``results_dir`` and the ``max_*_l`` length
            limits (including ``max_dc_l``).
        dset: Dataset exposing ``set_mode``.
        model: Callable returning ``(outputs, (scores, selection_scores))``;
            the two inner items feed the temporal and IOFSM losses.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Index of the current epoch.
        previous_best_acc: Best validation accuracy seen so far.

    Returns:
        The best validation accuracy after this pass.

    Notes:
        A ``RuntimeError`` containing "out of memory" skips the batch; any
        other ``RuntimeError`` is printed and terminates the process with
        ``sys.exit(1)``.
    """
    dset.set_mode("train")
    model.train()
    train_loader = DataLoader(dset, batch_size=opt.bsz, shuffle=True,
                              collate_fn=pad_collate, num_workers=opt.num_workers, pin_memory=True)

    # Running statistics for the current logging window.
    train_loss = []
    train_loss_iofsm = []
    train_loss_ts = []
    train_loss_cls = []
    valid_acc_log = ["batch_idx\tacc\tacc1\tacc2"]
    train_corrects = []
    torch.set_grad_enabled(True)
    max_len_dict = dict(
        max_sub_l=opt.max_sub_l,
        max_vid_l=opt.max_vid_l,
        max_vcpt_l=opt.max_vcpt_l,
        max_qa_l=opt.max_qa_l,
        max_dc_l=opt.max_dc_l,
    )

    # Wrap the loader (not the enumerate object) so tqdm can show a total.
    for batch_idx, batch in enumerate(tqdm(train_loader)):
        model_inputs, targets, qids = prepare_inputs(batch, max_len_dict=max_len_dict, device=opt.device)
        try:
            outputs, max_statement_sm_sigmoid_ = model(model_inputs)

            max_statement_sm_sigmoid, max_statement_sm_sigmoid_selection = max_statement_sm_sigmoid_

            temporal_loss = balanced_binaryCrossEntropy(max_statement_sm_sigmoid, targets, model_inputs["ts_label"], model_inputs["ts_label_mask"])

            cls_loss = criterion(outputs, targets)

            iofsm_loss, _, _ = IOFSM(max_statement_sm_sigmoid_selection, targets, model_inputs["ts_label"], model_inputs["ts_label_mask"])

            loss = cls_loss + temporal_loss + iofsm_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
            optimizer.step()
            optimizer.zero_grad()

            train_loss.append(loss.data.item())
            train_loss_iofsm.append(float(iofsm_loss))
            train_loss_ts.append(float(temporal_loss))

            train_loss_cls.append(cls_loss.item())
            pred_ids = outputs.data.max(1)[1]
            train_corrects += pred_ids.eq(targets.data).tolist()

        except RuntimeError as e:
            if "out of memory" in str(e):
                print("WARNING: ran out of memory, skipping batch")
                # Gradients are only cleared after a successful step, so a
                # failure inside backward() can leave partial gradients that
                # would be added to the next batch's update. Drop them.
                optimizer.zero_grad()
            else:
                print("RuntimeError {}".format(e))
                sys.exit(1)

        if batch_idx % opt.log_freq == 0:
            num_seen = len(train_corrects)
            # Report zeros at the first batch, where the window holds at most
            # one batch, and when every batch of the window was skipped after
            # running out of memory (nothing to average; dividing would raise
            # ZeroDivisionError).
            if batch_idx == 0 or num_seen == 0:
                train_acc = 0
                avg_loss = 0
                avg_loss_iofsm = 0
                avg_loss_ts = 0
                avg_loss_cls = 0
            else:
                denom = float(num_seen)
                train_acc = sum(train_corrects) / denom
                avg_loss = sum(train_loss) / denom
                avg_loss_iofsm = sum(train_loss_iofsm) / denom
                avg_loss_cls = sum(train_loss_cls) / denom
                avg_loss_ts = sum(train_loss_ts) / denom

            valid_acc, valid_loss, _, valid_acc1, valid_acc2, submit_json_val = \
                validate(opt, dset, model, criterion, mode="valid")

            valid_log_str = "%02d\t%.4f\t%.4f\t%.4f" % (batch_idx, valid_acc, valid_acc1, valid_acc2)
            valid_acc_log.append(valid_log_str)

            if valid_acc > previous_best_acc:
                with open("best_github.json", 'w') as cqf:
                    json.dump(submit_json_val, cqf)
                previous_best_acc = valid_acc
                if epoch >= 10:
                    # torch.save does not create missing directories.
                    ckpt_dir = "./results/best_valid_to_keep"
                    if not os.path.isdir(ckpt_dir):
                        os.makedirs(ckpt_dir)
                    torch.save(model.state_dict(), os.path.join(ckpt_dir, "best_github_7420.pth"))

            print("Epoch {:02d} [Train] acc {:.4f} loss {:.4f} loss_iofsm {:.4f} loss_ts {:.4f} loss_cls {:.4f}"
                  "[Val] acc {:.4f} loss {:.4f}"
                  .format(epoch, train_acc, avg_loss, avg_loss_iofsm, avg_loss_ts, avg_loss_cls,
                          valid_acc, valid_loss))

            # Back to training: validate() disabled autograd, switched the
            # model to eval mode and the dataset to the validation split.
            torch.set_grad_enabled(True)
            model.train()
            dset.set_mode("train")
            train_corrects = []
            train_loss = []
            train_loss_iofsm = []
            train_loss_ts = []
            train_loss_cls = []

    with open(os.path.join(opt.results_dir, "valid_acc.log"), "a") as f:
        f.write("\n".join(valid_acc_log) + "\n")

    return previous_best_acc