Example #1
def valid_model(dataLoader, epoch_number, model, cfg, criterion, logger,
                device, **kwargs):
    model.eval()
    num_classes = dataLoader.dataset.get_num_classes()
    fusion_matrix = FusionMatrix(num_classes)

    with torch.no_grad():
        all_loss = AverageMeter()
        acc = AverageMeter()
        func = torch.nn.Softmax(dim=1)
        for i, (image, label, meta) in enumerate(dataLoader):
            image, label = image.to(device), label.to(device)

            feature = model(image, feature_flag=True)

            output = model(feature, classifier_flag=True)
            loss = criterion(output, label)
            score_result = func(output)

            now_result = torch.argmax(score_result, 1)
            all_loss.update(loss.data.item(), label.shape[0])
            fusion_matrix.update(now_result.cpu().numpy(), label.cpu().numpy())
            now_acc, cnt = accuracy(now_result.cpu().numpy(),
                                    label.cpu().numpy())
            acc.update(now_acc, cnt)

        pbar_str = "------- Valid: Epoch:{:>3d}  Valid_Loss:{:>5.3f}   Valid_Acc:{:>5.2f}%-------".format(
            epoch_number, all_loss.avg, acc.avg * 100)
        logger.info(pbar_str)
    return acc.avg, all_loss.avg
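These snippets lean on a few helpers that are not shown here: AverageMeter, FusionMatrix, accuracy (and, in the distributed example, reduce_tensor). Below is a minimal sketch of how AverageMeter and accuracy are assumed to behave, inferred only from how they are called (update(value, count), .val/.avg/.count, and accuracy returning an (acc, count) pair); the real implementations may differ.

import numpy as np


class AverageMeter:
    """Running, count-weighted average (sketch of the assumed helper)."""

    def __init__(self):
        self.val = 0.0    # last value passed to update()
        self.sum = 0.0    # count-weighted sum of values
        self.count = 0    # total samples seen so far
        self.avg = 0.0    # running weighted average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)


def accuracy(pred, target):
    """Fraction of correct predictions and the sample count (sketch)."""
    pred = np.asarray(pred)
    target = np.asarray(target)
    cnt = target.shape[0]
    acc = float((pred == target).sum()) / cnt if cnt > 0 else 0.0
    return acc, cnt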
Example #2
def valid_model(dataLoader, epoch_number, model, cfg, criterion, logger,
                device, rank, distributed, **kwargs):
    model.eval()

    with torch.no_grad():
        all_loss = AverageMeter()
        acc_avg = AverageMeter()

        func = (torch.nn.Sigmoid()
                if cfg.LOSS.LOSS_TYPE in ['FocalLoss', 'ClassBalanceFocal']
                else torch.nn.Softmax(dim=1))

        for i, (image, label, meta) in enumerate(dataLoader):
            image, label = image.to(device), label.to(device)

            feature = model(image, feature_flag=True)

            output = model(feature, classifier_flag=True, label=label)
            loss = criterion(output, label, feature=feature)
            score_result = func(output)

            now_result = torch.argmax(score_result, 1)
            acc, cnt = accuracy(now_result.cpu().numpy(), label.cpu().numpy())

            if distributed:
                world_size = float(os.environ.get("WORLD_SIZE", 1))
                reduced_loss = reduce_tensor(loss.data, world_size)
                reduced_acc = reduce_tensor(
                    torch.from_numpy(np.array([acc])).cuda(), world_size)
                loss = reduced_loss.cpu().data
                acc = reduced_acc.cpu().data

            all_loss.update(loss.data.item(), label.shape[0])
            if distributed:
                acc_avg.update(acc.data.item(), cnt * world_size)
            else:
                acc_avg.update(acc, cnt)

        pbar_str = "------- Valid: Epoch:{:>3d}  Valid_Loss:{:>5.3f}   Valid_Acc:{:>5.2f}%-------".format(
            epoch_number, all_loss.avg, acc_avg.avg * 100)
        if rank == 0:
            logger.info(pbar_str)
    return acc_avg.avg, all_loss.avg
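reduce_tensor is also not shown. Under the usual torch.distributed pattern it all-reduces a metric across processes and averages it by the world size; the following is a plausible sketch under that assumption, not necessarily the original implementation.

import torch
import torch.distributed as dist


def reduce_tensor(tensor, world_size):
    # Sum the value across all ranks, then average; assumes the default
    # process group has already been initialised via init_process_group().
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    reduced = reduced / world_size
    return reduced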
Example #3
def train_model(trainLoader,
                model,
                epoch,
                epoch_number,
                optimizer,
                combiner,
                criterion,
                cfg,
                logger,
                rank=0,
                **kwargs):
    if cfg.EVAL_MODE:
        model.eval()
    else:
        model.train()

    trainLoader.dataset.update(epoch)
    combiner.update(epoch)
    criterion.update(epoch)

    start_time = time.time()
    number_batch = len(trainLoader)

    all_loss = AverageMeter()
    acc = AverageMeter()
    for i, (image, label, meta) in enumerate(trainLoader):
        cnt = label.shape[0]
        loss, now_acc = combiner.forward(model, criterion, image, label, meta)

        optimizer.zero_grad()

        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        optimizer.step()
        all_loss.update(loss.data.item(), cnt)
        acc.update(now_acc, cnt)

        if i % cfg.SHOW_STEP == 0 and rank == 0:
            pbar_str = "Epoch:{:>3d}  Batch:{:>3d}/{}  Batch_Loss:{:>5.3f}  Batch_Accuracy:{:>5.2f}%     ".format(
                epoch, i, number_batch, all_loss.val, acc.val * 100)
            logger.info(pbar_str)
    end_time = time.time()
    pbar_str = "---Epoch:{:>3d}/{}   Avg_Loss:{:>5.3f}   Epoch_Accuracy:{:>5.2f}%   Epoch_Time:{:>5.2f}min---".format(
        epoch, epoch_number, all_loss.avg, acc.avg * 100,
        (end_time - start_time) / 60)
    if rank == 0:
        logger.info(pbar_str)
    return acc.avg, all_loss.avg
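This variant backpropagates through amp.scale_loss, so it assumes NVIDIA Apex mixed precision has been set up beforehand. A minimal setup sketch under that assumption follows; the model and optimizer below are placeholders for whatever the surrounding code actually builds, and opt_level="O1" is just an example.

import torch
from apex import amp  # NVIDIA Apex, assumed to be installed

# Placeholder model/optimizer; the real ones come from the surrounding code.
model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# After this call, backprop goes through amp.scale_loss(loss, optimizer),
# exactly as in the training loop above.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")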
Example #4
def train_model(trainLoader, model, epoch, epoch_number, device, optimizer,
                criterion, cfg, logger, **kwargs):
    if cfg.EVAL_MODE:
        model.eval()
    else:
        model.train()

    start_time = time.time()
    number_batch = len(trainLoader)
    func = torch.nn.Softmax(dim=1)
    all_loss = AverageMeter()
    acc = AverageMeter()
    for i, (image, label) in enumerate(trainLoader):

        cnt = label.shape[0]
        image, label = image.to(device), label.to(device)
        output = model(image)
        loss = criterion(output, label)
        now_result = torch.argmax(func(output), 1)

        now_acc = accuracy(now_result.cpu().numpy(), label.cpu().numpy())[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        all_loss.update(loss.data.item(), cnt)
        acc.update(now_acc, cnt)

        if i % cfg.SHOW_STEP == 0:
            pbar_str = "Epoch:{:>3d}  Batch:{:>3d}/{}  Batch_Loss:{:>5.3f}  Batch_Accuracy:{:>5.2f}%     ".format(
                epoch, i, number_batch, all_loss.val, acc.val * 100
            )
            logger.info(pbar_str)

    end_time = time.time()
    pbar_str = "---Epoch:{:>3d}/{}   Avg_Loss:{:>5.3f}   Epoch_Accuracy:{:>5.2f}%   Epoch_Time:{:>5.2f}min---".format(
        epoch, epoch_number, all_loss.avg, acc.avg * 100, (end_time - start_time) / 60
    )
    logger.info(pbar_str)
    return acc.avg, all_loss.avg
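A hypothetical driver for this plain (non-distributed) train_model, assuming the AverageMeter/accuracy helpers sketched after Example #1; the toy dataset, model, and cfg namespace below are made up purely for illustration.

import logging
from types import SimpleNamespace

import torch
from torch.utils.data import DataLoader, TensorDataset

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("train")

# Toy data and model, just to exercise the loop.
images = torch.randn(64, 16)
labels = torch.randint(0, 4, (64,))
trainLoader = DataLoader(TensorDataset(images, labels), batch_size=8)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
cfg = SimpleNamespace(EVAL_MODE=False, SHOW_STEP=4)  # hypothetical config

for epoch in range(1, 3):
    train_acc, train_loss = train_model(
        trainLoader, model, epoch, epoch_number=2, device=device,
        optimizer=optimizer, criterion=criterion, cfg=cfg, logger=logger)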
Example #5
def valid_model(para_dict, dataLoader, epoch_number, model, cfg, criterion,
                logger):

    model.eval()
    num_classes = dataLoader.dataset.get_num_classes()
    fusion_matrix = FusionMatrix(num_classes)

    num_class_list = para_dict["num_class_list"]
    device = para_dict["device"]
    num_class_list = [i**cfg.LOSS.RATIO for i in num_class_list]
    prior_prob = num_class_list / np.sum(num_class_list)
    prior_prob = torch.FloatTensor(prior_prob).to(device)


    with torch.no_grad():
        all_loss = AverageMeter()
        acc = AverageMeter()
        func = torch.nn.Softmax(dim=1)
        for i, (image, label) in enumerate(dataLoader):
            image, label = image.to(device), label.to(device)

            output = model(image)
            loss = criterion(output, label)
            score_result = func(output)
            score_result = score_result / prior_prob
            now_result = torch.argmax(score_result, 1)

            all_loss.update(loss.data.item(), label.shape[0])

            fusion_matrix.update(now_result.cpu().numpy(), label.cpu().numpy())

            now_acc, cnt = accuracy(now_result.cpu().numpy(), label.cpu().numpy())
            acc.update(now_acc, cnt)

        pbar_str = "------- Valid: Epoch:{:>3d}  Valid_Loss:{:>5.3f}   Valid_Acc:{:>5.2f}%-------".format(
            epoch_number, all_loss.avg, acc.avg * 100
        )
        logger.info(pbar_str)
        # print(fusion_matrix.get_rec_per_class())
        # print(fusion_matrix.get_pre_per_class())
    return acc.avg, all_loss.avg
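Example #5 rescales the softmax output by a class prior built from the per-class sample counts raised to cfg.LOSS.RATIO, which shifts predictions away from frequent (head) classes at evaluation time. Here is a small numeric sketch of just that rescaling step; the counts and RATIO are invented.

import numpy as np
import torch

num_class_list = [1000, 100, 10]   # hypothetical per-class training counts
RATIO = 1.0                        # stands in for cfg.LOSS.RATIO

prior = np.array([c ** RATIO for c in num_class_list], dtype=np.float32)
prior_prob = torch.from_numpy(prior / prior.sum())

score = torch.tensor([[0.60, 0.25, 0.15]])   # softmax scores for one sample
adjusted = score / prior_prob                # divide by the class prior
print(torch.argmax(score, 1).item())         # 0: the head class wins on raw scores
print(torch.argmax(adjusted, 1).item())      # 2: the tail class wins after adjustment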
Example #6
def valid_model(dataLoader, model, device, label_map, level_label_maps):

    model.eval()
    num_levels = label_map.shape[1]
    num_classes = dataLoader.dataset.get_num_classes()

    l1_cls_num = label_map[:, 0].max().item() + 1
    l2_cls_num = label_map[:, 1].max().item() + 1
    virtual_cls_num = l1_cls_num + l2_cls_num - num_classes
    l1_raw_cls_num = l1_cls_num - virtual_cls_num

    fusion_matrix1 = FusionMatrix(num_classes)
    fusion_matrix2 = FusionMatrix(l1_cls_num)
    func = torch.nn.Softmax(dim=1)

    # 20 + 80: accuracy over all raw classes (head + tail)
    acc1 = AverageMeter()
    p1_acc1 = AverageMeter()
    p2_acc1 = AverageMeter()
    level_accs = [AverageMeter() for _ in range(num_levels)]
    all_labels1 = []
    all_result1 = []

    # 20 + 2: level-0 accuracy (head classes + virtual class)
    acc2 = AverageMeter()
    p1_acc2 = AverageMeter()
    p2_acc2 = AverageMeter()
    all_labels2 = []
    all_result2 = []

    with torch.no_grad():
        for i, (image, label, meta) in enumerate(dataLoader):
            image, label = image.to(device), label.to(device)

            batch_size = label.shape[0]
            level_scores = []
            level_probs = []
            for level in range(num_levels):

                level_score = model(image, level)
                level_scores.append(level_score)

                # for each level
                level_label = label_map[label, level]
                level_mask = level_label >= 0
                level_label1 = level_label[level_mask]
                level_score1 = level_score[level_mask]
                level_top1 = torch.argmax(level_score1, 1)
                level_acc1, level_cnt1 = accuracy(level_top1.cpu().numpy(),
                                                  level_label1.cpu().numpy())
                level_accs[level].update(level_acc1, level_cnt1)

                if level == 0:
                    level_prob = func(level_score)
                    level_probs.append(level_prob)
                else:
                    high_lcid_to_curr_lcid = level_label_maps[level - 1]
                    level_prob = torch.zeros(level_score.shape).cuda()
                    for high_lcid in range(high_lcid_to_curr_lcid.shape[0]):
                        curr_lcid_mask = high_lcid_to_curr_lcid[high_lcid]
                        if curr_lcid_mask.sum().item() > 0:
                            level_prob[:, curr_lcid_mask] = func(
                                level_score[:, curr_lcid_mask])
                    level_probs.append(level_prob)

            # =================== 20 + 80 ============================
            all_probs = torch.ones((batch_size, num_classes)).cuda()
            for level in range(num_levels):
                level_prob = level_probs[level]
                related_lcids = label_map[:, level]
                related_lcids = related_lcids[related_lcids >= 0]
                unrelated_class_num1 = (label_map[:, level] < 0).sum().item()
                unrelated_class_num2 = label_map.shape[0] - related_lcids.shape[0]
                assert unrelated_class_num1 == unrelated_class_num2
                # accumulate this level's probabilities into the (batch, num_classes) scores
                all_probs[:, unrelated_class_num1:] *= level_prob[:, related_lcids]

            l1_mask1 = label < l1_raw_cls_num
            l1_scores1 = all_probs[l1_mask1]
            l1_labels1 = label[l1_mask1]
            l1_result1 = torch.argmax(l1_scores1, 1)
            l1_now_acc1, l1_cnt1 = accuracy(l1_result1.cpu().numpy(),
                                            l1_labels1.cpu().numpy())
            p1_acc1.update(l1_now_acc1, l1_cnt1)

            l2_mask1 = label >= l1_raw_cls_num
            l2_scores1 = all_probs[l2_mask1]
            l2_labels1 = label[l2_mask1]
            l2_result1 = torch.argmax(l2_scores1, 1)
            l2_now_acc1, l2_cnt1 = accuracy(l2_result1.cpu().numpy(),
                                            l2_labels1.cpu().numpy())
            p2_acc1.update(l2_now_acc1, l2_cnt1)

            now_result = torch.argmax(all_probs, 1)
            now_acc, cnt = accuracy(now_result.cpu().numpy(),
                                    label.cpu().numpy())
            acc1.update(now_acc, cnt)
            fusion_matrix1.update(now_result.cpu().numpy(),
                                  label.cpu().numpy())
            all_labels1.extend(label.cpu().numpy().tolist())
            all_result1.extend(now_result.cpu().numpy().tolist())
            # ====================================================================

            # ===================20 + 2 =================================
            l1v_scores = level_probs[0]
            l1v_labels = label_map[label, 0]

            l1_mask2 = l1v_labels < l1_raw_cls_num
            l1_scores2 = l1v_scores[l1_mask2]
            l1_labels2 = l1v_labels[l1_mask2]
            l1_result2 = torch.argmax(l1_scores2, 1)
            l1_now_acc2, l1_cnt2 = accuracy(l1_result2.cpu().numpy(),
                                            l1_labels2.cpu().numpy())
            p1_acc2.update(l1_now_acc2, l1_cnt2)

            l2_mask2 = l1v_labels >= l1_raw_cls_num
            l2_scores2 = l1v_scores[l2_mask2]
            l2_labels2 = l1v_labels[l2_mask2]
            l2_result2 = torch.argmax(l2_scores2, 1)
            l2_now_acc2, l2_cnt2 = accuracy(l2_result2.cpu().numpy(),
                                            l2_labels2.cpu().numpy())
            p2_acc2.update(l2_now_acc2, l2_cnt2)

            l1v_result = torch.argmax(l1v_scores, 1)
            l1v_now_acc, l1v_cnt = accuracy(l1v_result.cpu().numpy(),
                                            l1v_labels.cpu().numpy())
            acc2.update(l1v_now_acc, l1v_cnt)
            fusion_matrix2.update(l1v_result.cpu().numpy(),
                                  l1v_labels.cpu().numpy())
            all_labels2.extend(l1v_labels.cpu().numpy().tolist())
            all_result2.extend(l1v_result.cpu().numpy().tolist())
            # ====================================================================

    print('Acc (head+tail): %.4f %d' % (acc1.avg, acc1.count))
    print('Acc P1         : %.4f %d' % (p1_acc1.avg, p1_acc1.count))
    print('Acc P2         : %.4f %d' % (p2_acc1.avg, p2_acc1.count))
    print('Acc L1         : %.4f %d' %
          (level_accs[0].avg, level_accs[0].count))
    print('Acc L2         : %.4f %d' %
          (level_accs[1].avg, level_accs[1].count))
    print('=' * 23)
    print('Acc (head+v)   : %.4f %d' % (acc2.avg, acc2.count))
    print('Acc P1         : %.4f %d' % (p1_acc2.avg, p1_acc2.count))
    print('Acc Pv         : %.4f %d' % (p2_acc2.avg, p2_acc2.count))
    print('=' * 23)
    return fusion_matrix1, fusion_matrix2, all_labels1, all_result1, all_labels2, all_result2
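label_map and level_label_maps are produced elsewhere. Judging from how they are indexed above, label_map looks like a (num_classes, num_levels) integer tensor giving each raw class its id at every hierarchy level (-1 where a class has no label at that level), and level_label_maps[k] a boolean matrix mapping each level-k class to the mask of its level-(k+1) children. The following toy two-level example is built on those assumptions.

import torch

# 4 raw classes; level 0 has 3 classes (2 real heads + 1 virtual node),
# level 1 has the 2 tail classes grouped under the virtual node.
label_map = torch.tensor([
    [0, -1],   # raw class 0 -> level-0 class 0, no level-1 label
    [1, -1],   # raw class 1 -> level-0 class 1, no level-1 label
    [2,  0],   # raw class 2 -> virtual level-0 class 2, level-1 class 0
    [2,  1],   # raw class 3 -> virtual level-0 class 2, level-1 class 1
])

# level_label_maps[0][h] is the boolean mask over level-1 classes that sit
# under level-0 class h (only the virtual class 2 has children here).
level_label_maps = [torch.tensor([
    [False, False],
    [False, False],
    [True,  True],
])]

num_classes = label_map.shape[0]                          # 4
l1_cls_num = label_map[:, 0].max().item() + 1             # 3
l2_cls_num = label_map[:, 1].max().item() + 1             # 2
virtual_cls_num = l1_cls_num + l2_cls_num - num_classes   # 1
l1_raw_cls_num = l1_cls_num - virtual_cls_num             # 2 "head" classes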
Example #7
def valid_model(dataLoader,
                epoch_number,
                model,
                logger,
                device,
                label_map,
                level_label_maps,
                stage=0):

    model.eval()
    num_levels = label_map.shape[1]
    num_classes = dataLoader.dataset.get_num_classes()

    l1_cls_num = label_map[:, 0].max().item() + 1
    l2_cls_num = label_map[:, 1].max().item() + 1
    virtual_cls_num = l1_cls_num + l2_cls_num - num_classes
    l1_raw_cls_num = l1_cls_num - virtual_cls_num
    l2_raw_cls_num = l2_cls_num

    fusion_matrix = FusionMatrix(num_classes)
    func = torch.nn.Softmax(dim=1)
    acc = AverageMeter()
    l1_acc = AverageMeter()
    l2_acc = AverageMeter()

    with torch.no_grad():
        for i, (image, label, meta) in enumerate(dataLoader):
            image, label = image.to(device), label.to(device)

            batch_size = label.shape[0]

            if stage == 1 or stage == 0:
                level_scores = []
                level_probs = []
                for level in range(num_levels):
                    level_score = model(image, level)
                    level_scores.append(level_score)

                    if level == 0:
                        level_prob = func(level_score)
                        level_probs.append(level_prob)
                    else:
                        high_lcid_to_curr_lcid = level_label_maps[level - 1]
                        level_prob = torch.zeros(level_score.shape).cuda()
                        for high_lcid in range(
                                high_lcid_to_curr_lcid.shape[0]):
                            curr_lcid_mask = high_lcid_to_curr_lcid[high_lcid]
                            if curr_lcid_mask.sum().item() > 0:
                                level_prob[:, curr_lcid_mask] = func(
                                    level_score[:, curr_lcid_mask])
                        level_probs.append(level_prob)

                all_probs = torch.ones((batch_size, num_classes)).cuda()
                for level in range(num_levels):
                    level_prob = level_probs[level]
                    related_lcids = label_map[:, level]
                    related_lcids = related_lcids[related_lcids >= 0]
                    unrelated_class_num1 = (label_map[:, level] < 0).sum().item()
                    unrelated_class_num2 = label_map.shape[0] - related_lcids.shape[0]
                    assert unrelated_class_num1 == unrelated_class_num2
                    # accumulate this level's probabilities into the (batch, num_classes) scores
                    all_probs[:, unrelated_class_num1:] *= level_prob[:, related_lcids]

            elif stage == 2:
                all_probs = model(image)

            else:
                raise ValueError('unsupported stage: %d' % stage)

            # if stage == 1:
            #     all_probs = level_probs[0]
            #     label = label_map[label, 0]

            l1_mask = label < l1_raw_cls_num
            l1_scores = all_probs[l1_mask]
            l1_labels = label[l1_mask]
            l1_result = torch.argmax(l1_scores, 1)
            l1_now_acc, l1_cnt = accuracy(l1_result.cpu().numpy(),
                                          l1_labels.cpu().numpy())
            l1_acc.update(l1_now_acc, l1_cnt)

            l2_mask = label >= l1_raw_cls_num
            l2_scores = all_probs[l2_mask]
            l2_labels = label[l2_mask]
            l2_result = torch.argmax(l2_scores, 1)
            l2_now_acc, l2_cnt = accuracy(l2_result.cpu().numpy(),
                                          l2_labels.cpu().numpy())
            l2_acc.update(l2_now_acc, l2_cnt)

            now_result = torch.argmax(all_probs, 1)
            fusion_matrix.update(now_result.cpu().numpy(), label.cpu().numpy())
            now_acc, cnt = accuracy(now_result.cpu().numpy(),
                                    label.cpu().numpy())
            acc.update(now_acc, cnt)

        pbar_str = "------- Valid: Epoch:{:>3d}  Valid_Acc:{:>5.4f}  P1_Acc:{:>5.4f}  P2_Acc:{:>5.4f}-------".format(
            epoch_number, acc.avg, l1_acc.avg, l2_acc.avg)
        logger.info(pbar_str)
    return acc.avg