Example no. 1
def run_epoch(model,
              loader_s,
              loader_u,
              loss_fn,
              optimizer,
              desc_default='',
              epoch=0,
              writer=None,
              verbose=1,
              unsupervised=False,
              scheduler=None,
              num_classes=10):

    top1_per_class = AverageMeterVector(num_classes)

    tqdm_disable = bool(os.environ.get('TASK_NAME', ''))
    if verbose:
        loader_s = tqdm(loader_s, disable=tqdm_disable)
        loader_s.set_description('[%s %04d/%04d]' %
                                 (desc_default, epoch, C.get()['epoch']))

    iter_u = iter(loader_u)

    metrics = Accumulator()
    cnt = 0
    total_steps = len(loader_s)
    steps = 0
    for data, label in loader_s:
        steps += 1

        if not unsupervised:
            data, label = data.cuda(), label.cuda()
            preds = model(data)
            loss = loss_fn(preds, label)  # loss for supervised learning
        else:
            label = label.cuda()
            try:
                unlabel1, unlabel2 = next(iter_u)
            except StopIteration:
                iter_u = iter(loader_u)
                unlabel1, unlabel2 = next(iter_u)
            data_all = torch.cat([data, unlabel1, unlabel2]).cuda()

            preds_all = model(data_all)
            preds = preds_all[:len(data)]
            loss = loss_fn(preds, label)  # loss for supervised learning

            preds_unsup = preds_all[len(data):]
            preds1, preds2 = torch.chunk(preds_unsup, 2)
            preds1 = softmax(preds1, dim=1).detach()
            preds2 = log_softmax(preds2, dim=1)
            assert len(preds1) == len(preds2) == C.get()['batch_unsup']

            loss_kldiv = kl_div(preds2, preds1,
                                reduction='none')  # loss for unsupervised
            loss_kldiv = torch.sum(loss_kldiv, dim=1)
            assert len(loss_kldiv) == len(unlabel1)
            # loss += (epoch / 200. * C.get()['ratio_unsup']) * torch.mean(loss_kldiv)
            if C.get()['ratio_mode'] == 'constant':
                loss += C.get()['ratio_unsup'] * torch.mean(loss_kldiv)
            elif C.get()['ratio_mode'] == 'gradual':
                loss += (epoch / float(C.get()['epoch'])
                         ) * C.get()['ratio_unsup'] * torch.mean(loss_kldiv)
            else:
                raise ValueError

        if optimizer:
            loss.backward()
            if C.get()['optimizer'].get('clip', 5) > 0:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         C.get()['optimizer'].get('clip', 5))

            optimizer.step()
            optimizer.zero_grad()

        top1, top5 = accuracy(preds, label, (1, 5))
        prec1_per_class, rec_num = accuracy(preds,
                                            label,
                                            topk=(1, ),
                                            per_class=True)
        top1_per_class.update(prec1_per_class.cpu().numpy(),
                              rec_num.cpu().numpy())
        metrics.add_dict({
            'loss': loss.item() * len(data),
            'top1': top1.item() * len(data),
            'top5': top5.item() * len(data),
        })
        cnt += len(data)
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                postfix['lr'] = optimizer.param_groups[0]['lr']
            loader_s.set_postfix(postfix)

        if scheduler is not None:
            scheduler.step(epoch - 1 + float(steps) / total_steps)

        del preds, loss, top1, top5, data, label

    if tqdm_disable:
        logger.info('[%s %03d/%03d] %s', desc_default, epoch,
                    C.get()['epoch'], metrics / cnt)

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    if verbose:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    top1_acc = metrics['top1']
    print(
        f'Epoch {epoch} [{desc_default}] top1_per_class accuracy is: '
        f'{np.round(top1_per_class.avg, 2)}, average: {np.round(top1_acc, 4)}',
        flush=True)
    return metrics
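
The unsupervised branch above is a UDA-style consistency term. Below is a minimal, self-contained sketch of that term, using random CPU tensors and illustrative sizes in place of the real batches:

import torch
from torch.nn.functional import softmax, log_softmax, kl_div

batch_unsup, num_classes = 4, 10
logits1 = torch.randn(batch_unsup, num_classes)   # predictions on the first augmented views
logits2 = torch.randn(batch_unsup, num_classes)   # predictions on the second augmented views

target = softmax(logits1, dim=1).detach()         # fixed target distribution (no gradient)
log_pred = log_softmax(logits2, dim=1)
loss_kldiv = kl_div(log_pred, target, reduction='none').sum(dim=1)  # per-sample KL(target || pred)
print(loss_kldiv.mean())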
Example no. 2
def eval_tta(config, augment, reporter):
    C.get()
    C.get().conf = config
    cv_ratio_test, cv_fold, save_path = augment['cv_ratio_test'], augment[
        'cv_fold'], augment['save_path']

    # setup - provided augmentation rules
    C.get()['aug'] = policy_decoder(augment, augment['num_policy'],
                                    augment['num_op'])

    # eval
    model = get_model(C.get()['model'], num_class(C.get()['dataset']))
    ckpt = torch.load(save_path + '.pth')
    if 'model' in ckpt:
        model.load_state_dict(ckpt['model'])
    else:
        model.load_state_dict(ckpt)
    model = nn.DataParallel(model).cuda()
    model.eval()

    src_loaders = []
    # for _ in range(augment['num_policy']):
    _, src_tl, src_validloader, src_ttl = get_dataloaders(
        C.get()['dataset'],
        C.get()['batch'],
        augment['dataroot'],
        cv_ratio_test,
        cv_num,  # not defined in this snippet; assumed to exist at module scope
        split_idx=cv_fold,
        target=False,
        random_range=C.get()['args'].random_range)

    del src_tl, src_ttl

    start_t = time.time()
    metrics = Accumulator()
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')

    emd_loss = nn.DataParallel(emdModule()).cuda()

    losses = []
    corrects = []
    for data in src_validloader:
        with torch.no_grad():
            point_cloud = data['point_cloud'].cuda()
            label = torch.ones_like(data['label'], dtype=torch.int64).cuda()
            trans_pc = data['transformed']

            pred = model(trans_pc)

            if C.get()['args'].use_emd_false:
                loss_emd = (torch.mean(emd_loss(point_cloud.permute(0, 2, 1),
                                                trans_pc.permute(0, 2, 1), 0.05, 3000)[0])).unsqueeze(0) \
                           * C.get()['args'].emd_coeff
            else:
                loss_emd = torch.tensor([0.0])

            if C.get()['args'].no_dc:
                loss = loss_emd
            else:
                loss = loss_emd + loss_fn(pred, label)
            # print(loss)
            losses.append(loss.detach().cpu().numpy())

            pred = pred.max(dim=1)[1]
            pred = pred.t()
            correct = float(
                torch.sum(pred == label).item()) / pred.size(0) * 100
            corrects.append(correct)
            del loss, correct, pred, data, label, loss_emd

    losses = np.concatenate(losses)
    losses_min = np.min(losses, axis=0).squeeze()
    corrects_max = max(corrects)
    metrics.add_dict({
        'minus_loss': -1 * np.sum(losses_min),
        'correct': np.sum(corrects_max),
        # 'cnt': len(corrects_max)
    })
    del corrects, corrects_max

    del model
    # metrics = metrics / 'cnt'
    gpu_secs = (time.time() - start_t) * torch.cuda.device_count()
    # print(metrics)
    reporter(minus_loss=metrics['minus_loss'],
             top1_valid=metrics['correct'],
             elapsed_time=gpu_secs,
             done=True)
    return metrics['minus_loss']
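
A self-contained sketch of the per-sample loss bookkeeping used above, with random tensors standing in for the real model and validation loader (shapes and values are illustrative only):

import numpy as np
import torch

loss_fn = torch.nn.CrossEntropyLoss(reduction='none')   # per-sample losses, as above
losses = []
for _ in range(3):                                       # stands in for the validation batches
    pred = torch.randn(8, 5)                             # (batch, num_classes) logits
    label = torch.ones(8, dtype=torch.int64)             # the function compares against an all-ones label
    losses.append(loss_fn(pred, label).numpy())
losses = np.concatenate(losses)
print(-np.sum(np.min(losses, axis=0)))                   # the 'minus_loss' value passed to the reporter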
Example no. 3
            ts.insert(corrupt_idx, lambda img: PIL.Image.fromarray(corrupt(np.array(img), corrupt_level, corrupt_type)))
        else:
            ts.insert(corrupt_idx, lambda img: PIL.Image.fromarray(corrupt(np.array(img), corrupt_level, None, int(corrupt_type))))

    transform_test = transforms.Compose(ts)

    testset = ImageNet(root='/data/public/rw/datasets/imagenet-pytorch', split='val', transform=transform_test)
    sss = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
    for _ in range(1):
        sss = sss.split(list(range(len(testset))), testset.targets)
    train_idx, valid_idx = next(sss)
    testset = Subset(testset, valid_idx)

    testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=32, pin_memory=True, drop_last=False)

    metric = Accumulator()
    dl_test = tqdm(testloader)
    data_id = 0
    tta_rule_cnt = [0] * tta_num
    for data, label in dl_test:
        data = data.view(-1, data.shape[-3], data.shape[-2], data.shape[-1])
        data = data.cuda()

        with torch.no_grad():
            preds = model_target(data)
            preds = torch.softmax(preds, dim=1)

        preds = preds.view(len(label), -1, preds.shape[-1])

        preds_merged = torch.mean(preds, dim=1)     # simple averaging
        # TODO : weighted average mean?
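
A self-contained sketch of the averaging step at the end of this fragment: per-view predictions are softmaxed, regrouped per original image, and averaged over the view dimension (sizes are illustrative):

import torch

batch, tta_num, num_classes = 2, 4, 1000
logits = torch.randn(batch * tta_num, num_classes)   # model output on the flattened (image, view) batch
probs = torch.softmax(logits, dim=1)
probs = probs.view(batch, tta_num, num_classes)      # regroup the views belonging to each image
preds_merged = probs.mean(dim=1)                     # simple averaging, as in the code above
print(preds_merged.shape)                            # torch.Size([2, 1000])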
Example no. 4
def run_epoch(
        model,
        loader_s,
        loader_u,
        loss_fn,
        optimizer,
        desc_default="",
        epoch=0,
        writer=None,
        verbose=1,
        unsupervised=False,
        scheduler=None,
        method="UDA",
):
    tqdm_disable = bool(os.environ.get("TASK_NAME", ""))
    if verbose:
        loader_s = tqdm(loader_s, disable=tqdm_disable)
        loader_s.set_description(
            "[%s %04d/%04d]" % (desc_default, epoch, C.get()["epoch"])
        )

    iter_u = iter(loader_u)

    metrics = Accumulator()
    cnt = 0
    total_steps = len(loader_s)
    steps = 0
    for data, label in loader_s:
        steps += 1

        if not unsupervised:
            data, label = data.cuda(), label.cuda()
            preds = model(data)
            loss = loss_fn(preds, label)  # loss for supervised learning
            loss_kldiv = torch.tensor(0).float().cuda()
        else:
            label = label.cuda()
            try:
                unlabel1, unlabel2 = next(iter_u)
            except StopIteration:
                iter_u = iter(loader_u)
                unlabel1, unlabel2 = next(iter_u)
            data_all = torch.cat([data, unlabel1, unlabel2]).cuda()

            preds_all = model(data_all)
            preds = preds_all[: len(data)]
            loss = loss_fn(preds, label)  # loss for supervised learning

            preds_unsup = preds_all[len(data):]
            preds_logit_1, preds_logit_2 = torch.chunk(preds_unsup, 2)

            if method == "UDA":
                preds_softmax_1 = softmax(preds_logit_1, dim=1).detach()
                preds_logsoftmax_2 = log_softmax(preds_logit_2, dim=1)
                assert len(preds_softmax_1) == len(preds_logsoftmax_2) == C.get()["batch_unsup"]

                loss_kldiv = kl_div(
                    preds_logsoftmax_2, preds_softmax_1, reduction="none"
                )  # loss for unsupervised
                loss_kldiv = torch.sum(loss_kldiv, dim=1)
                assert len(loss_kldiv) == len(unlabel1)
            elif method == "IIC":
                loss_kldiv, _ = IIDLoss()(preds_logit_1.softmax(1), preds_logit_2.softmax(1))
            else:
                raise NotImplementedError
            # loss += (epoch / 200. * C.get()['ratio_unsup']) * torch.mean(loss_kldiv)
            loss += args.alpha * torch.mean(loss_kldiv)

        if optimizer:
            loss.backward()
            if C.get()["optimizer"].get("clip", 5) > 0:
                nn.utils.clip_grad_norm_(
                    model.parameters(), C.get()["optimizer"].get("clip", 5)
                )

            optimizer.step()
            optimizer.zero_grad()

        top1, top5 = accuracy(preds, label, (1, 5))

        metrics.add_dict(
            {
                "loss": loss.item() * len(data),
                "top1": top1.item() * len(data),
                "top5": top5.item() * len(data),
                "reg": loss_kldiv.mean().item() * len(data)
            }
        )
        cnt += len(data)
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                postfix["lr"] = optimizer.param_groups[0]["lr"]
            loader_s.set_postfix(postfix)

        if scheduler is not None:
            scheduler.step(epoch - 1 + float(steps) / total_steps)

        del preds, loss, top1, top5, data, label

    if tqdm_disable:
        logger.info(
            "[%s %03d/%03d] %s", desc_default, epoch, C.get()["epoch"], metrics / cnt
        )

    metrics /= cnt
    if optimizer:
        metrics.metrics["lr"] = optimizer.param_groups[0]["lr"]
    if verbose:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    return metrics
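
A self-contained sketch of the optimizer step with gradient-norm clipping used above; the tiny model and data are placeholders, and the clip value 5 mirrors the config default read by the function:

import torch
import torch.nn as nn

model = nn.Linear(10, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = nn.CrossEntropyLoss()(model(torch.randn(4, 10)), torch.tensor([0, 1, 2, 0]))
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 5)   # cap the total gradient norm before the update
optimizer.step()
optimizer.zero_grad()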
Example no. 5
def run_epoch(model,
              loader,
              loss_fn,
              optimizer,
              desc_default='',
              epoch=0,
              writer=None,
              verbose=1,
              scheduler=None,
              is_train=False):
    tqdm_disable = bool(os.environ.get('TASK_NAME', ''))
    if verbose:
        loader = tqdm(loader, disable=tqdm_disable)
        loader.set_description('[%s %04d/%04d]' %
                               (desc_default, epoch, C.get()['epoch']))

    metrics = Accumulator()
    cnt = 0
    total_steps = len(loader)
    steps = 0
    for data, label in loader:
        steps += 1
        data, label = data.cuda(), label.cuda()

        if is_train:
            data, targets_a, targets_b, lam = mixup_data(data,
                                                         label,
                                                         use_cuda=True)
            data, targets_a, targets_b = map(Variable,
                                             (data, targets_a, targets_b))
            preds = model(data)
            loss = mixup_criterion(loss_fn, preds, targets_a, targets_b, lam)
        else:
            preds = model(data)
            loss = loss_fn(preds, label)
        if optimizer:
            optimizer.zero_grad()
            loss.backward()
            if getattr(optimizer, "synchronize", None):
                optimizer.synchronize()
            if C.get()['optimizer'].get('clip', 5) > 0:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         C.get()['optimizer'].get('clip', 5))
            optimizer.step()

        top1, top5 = accuracy(preds, label, (1, 5))

        metrics.add_dict({
            'loss': loss.item() * len(data),
            'top1': top1.item() * len(data),
            'top5': top5.item() * len(data),
        })
        cnt += len(data)
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                postfix['lr'] = optimizer.param_groups[0]['lr']
            loader.set_postfix(postfix)

        del preds, loss, top1, top5, data, label

    if tqdm_disable:
        logger.info('[%s %03d/%03d] %s', desc_default, epoch,
                    C.get()['epoch'], metrics / cnt)

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    if verbose:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    return metrics
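
The mixup_data / mixup_criterion helpers are not shown in this example; the sketch below re-implements them along the lines of the common mixup reference code, purely for illustration:

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    # sample a mixing coefficient and blend each sample with a randomly paired one
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # loss is the same convex combination applied to the two label sets
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

x, y = torch.randn(4, 10), torch.tensor([0, 1, 2, 1])
mixed_x, y_a, y_b, lam = mixup_data(x, y)
pred = torch.nn.Linear(10, 3)(mixed_x)
print(float(mixup_criterion(torch.nn.CrossEntropyLoss(), pred, y_a, y_b, lam)))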
Example no. 6
    def run_epoch(_loader, _model, _optimizer, _tag, _ema, _epoch, _scheduler=None, max_step=100000):
        params_without_bn = [params for name, params in _model.named_parameters() if not ('_bn' in name or '.bn' in name)]

        tta_cnt = [0] * tta_num
        metric = Accumulator()
        batch = []
        total_steps = len(_loader)
        tqdm_loader = tqdm(_loader, desc=f'[{_tag} epoch={_epoch+1:03}/{args.epoch:03}]', total=min(max_step, total_steps))
        try:
            for example_id, (img_orig, lb, losses, corrects) in enumerate(tqdm_loader):
                batch.append((img_orig, lb, losses, corrects))
                if (example_id + 1) % args.batch != 0:
                    continue

                if max_step < example_id:
                    break

                imgs = torch.cat([x[0] for x in batch]).cuda()
                lbs = torch.cat([x[1] for x in batch]).long().cuda()
                losses = torch.cat([x[2] for x in batch]).cuda()
                corrects = torch.cat([x[3] for x in batch]).cuda()
                assert len(imgs) > 0

                imgs = imgs.view(imgs.size(0) * imgs.size(1), imgs.size(2), imgs.size(3), imgs.size(4))
                lbs = lbs.view(lbs.size(0) * lbs.size(1))
                losses = losses.view(losses.size(0) * losses.size(1), -1)
                corrects = corrects.view(corrects.size(0) * corrects.size(1), -1)
                assert losses.shape[1] == tta_num, losses.shape
                assert corrects.shape[1] == tta_num, corrects.shape
                assert torch.isnan(losses).sum() == 0

                softmin_target = torch.nn.functional.softmin(losses / args.tau, dim=1).detach()
                pred = _model(imgs)
                pred_softmax = torch.nn.functional.softmax(pred, dim=1)
                assert torch.isnan(pred).sum() == 0, pred
                assert torch.isnan(pred_softmax).sum() == 0, pred_softmax
                assert torch.isnan(softmin_target).sum() == 0
                assert softmin_target.shape[0] == pred_softmax.shape[0], (softmin_target.shape, pred_softmax.shape)
                assert softmin_target.shape[1] == pred_softmax.shape[1], (softmin_target.shape, pred_softmax.shape)

                pred_final = pred_softmax
                loss = spearman_loss(pred_softmax, softmin_target)

                if _optimizer is not None:
                    loss_total = loss + args.decay * sum([torch.norm(p, p=args.regularization) for p in params_without_bn])
                    loss_total.backward()
                    _optimizer.step()
                    _optimizer.zero_grad()

                if _ema is not None:
                    _ema(_model, _epoch * total_steps + example_id)

                for idx in torch.argmax(pred_softmax, dim=1):
                    tta_cnt[idx] += 1

                pred_correct = torch.Tensor([x[y] for x, y in zip(corrects, torch.argmax(pred_final, dim=1))])
                orac_correct = torch.Tensor([x[y] for x, y in zip(corrects, torch.argmax(softmin_target, dim=1))])
                defa_correct = corrects[:, encoded_tta_default()]

                pred_loss = torch.Tensor([x[y] for x, y in zip(losses, torch.argmax(pred_final, dim=1))])
                defa_loss = losses[:, encoded_tta_default()]
                corr_p = prediction_correlation(pred_final, softmin_target)

                metric.add('loss', loss.item())
                metric.add('l_l2t', torch.mean(pred_loss).item())
                metric.add('l_org', torch.mean(defa_loss).item())
                metric.add('top1_l2t', torch.mean(pred_correct).item())
                metric.add('top1_oracle', torch.mean(orac_correct).item())
                metric.add('top1_org', torch.mean(defa_correct).item())
                metric.add('corr_p', corr_p)
                metric.add('cnt', 1)
                tqdm_loader.set_postfix(
                    lr=_optimizer.param_groups[0]['lr'] if _optimizer is not None else 0,
                    l=metric['loss'] / metric['cnt'],
                    l_l2t=metric['l_l2t'] / metric['cnt'],
                    l_org=metric['l_org'] / metric['cnt'],
                    # l_curr=loss.item(),
                    corr_p=metric['corr_p'] / metric['cnt'],
                    acc_l2t=metric['top1_l2t'] / metric['cnt'],
                    acc_org=metric['top1_org'] / metric['cnt'],
                    acc_d=(metric['top1_l2t'] - metric['top1_org']) / metric['cnt'],
                    acc_O=metric['top1_oracle'] / metric['cnt'],
                    # tta_top=decode_desc(np.argmax(tta_cnt)),
                    # tta_max='%.2f(%d)' % (max(tta_cnt) / float(sum(tta_cnt)), np.argmax(tta_cnt)),
                    ttas=f'{tta_cnt[0]/sum(tta_cnt):.2f},{tta_cnt[-3]/sum(tta_cnt):.2f},{tta_cnt[-2]/sum(tta_cnt):.2f},{tta_cnt[-1]/sum(tta_cnt):.2f}'
                    # tta_min='%.2f' % (min(tta_cnt) / float(sum(tta_cnt))),
                    # grad_l2=metric['grad_l2'] / metric['cnt'],
                )

                batch = []
                if _scheduler is not None:
                    _scheduler.step(_epoch + (float(example_id) / total_steps))
                del pred, loss
        except KeyboardInterrupt as e:
            if 'test' not in _tag:
                raise e
            pass
        finally:
            tqdm_loader.close()

        del tqdm_loader, batch

        if 'test' in _tag:
            if metric['top1_l2t'] >= metric['top1_org']:
                c = 107     # green
            else:
                c = 124     # red

        else:
            if metric['top1_l2t'] >= metric['top1_org']:
                c = 149
            else:
                c = 14      # light_cyan
        logger.info(f'[{_tag} epoch={_epoch + 1}] ' + stylize(
            'loss=%.4f l(l2t=%.4f org=%.4f) top1_O=%.4f top1_l2t=%.4f top1_org=%.4f << corr_p=%.4f delta=%.4f %s(%s)>>' %
            (metric['loss'] / metric['cnt'],
             metric['l_l2t'] / metric['cnt'], metric['l_org'] / metric['cnt'],
             metric['top1_oracle'] / metric['cnt'],
             metric['top1_l2t'] / metric['cnt'],
             metric['top1_org'] / metric['cnt'],
             metric['corr_p'] / metric['cnt'],
             (metric['top1_l2t'] / metric['cnt']) - (metric['top1_org'] / metric['cnt']),
             decode_desc(np.argmax(tta_cnt)), '%.2f(%d)' % (max(tta_cnt) / float(sum(tta_cnt)), np.argmax(tta_cnt)),
             )
        , colored.fg(c)))
        return metric
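
A self-contained sketch of the softmin target built above: each row of per-TTA-rule losses is turned into a probability distribution where lower loss means higher weight, with args.tau acting as a temperature (values are illustrative):

import torch

tau = 1.0
losses = torch.tensor([[0.2, 1.5, 0.9, 0.4],
                       [2.0, 0.1, 0.3, 1.2]])                  # per-sample loss for each TTA rule
softmin_target = torch.nn.functional.softmin(losses / tau, dim=1)
print(softmin_target)                                          # rows sum to 1; the lowest-loss rule gets the largest weight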