Example No. 1
def get_affinity(aug, aff_bases, config, augment):
    C.get()  # ensure the global config singleton is initialized
    C.get().conf = config
    # setup - provided augmentation rules
    C.get()['aug'] = aug
    load_paths = augment['load_paths']
    cv_num = augment["cv_num"]

    aug_loaders = []
    for cv_id in range(cv_num):
        _, tl, validloader, tl2 = get_dataloaders(C.get()['dataset'],
                                                  C.get()['batch'],
                                                  augment['dataroot'],
                                                  augment['cv_ratio_test'],
                                                  split_idx=cv_id)
        aug_loaders.append(validloader)
        del tl, tl2

    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    aug_accs = []
    for cv_id, loader in enumerate(aug_loaders):
        # eval
        model = get_model(C.get()['model'], num_class(C.get()['dataset']))
        ckpt = torch.load(load_paths[cv_id])
        if 'model' in ckpt:
            model.load_state_dict(ckpt['model'])
        else:
            model.load_state_dict(ckpt)
        model.eval()

        metrics = Accumulator()
        for data, label in loader:
            data = data.cuda()
            label = label.cuda()

            pred = model(data)
            loss = loss_fn(pred, label)  # (N)

            _, pred = pred.topk(1, 1, True, True)
            pred = pred.t()
            correct = pred.eq(label.view(
                1, -1).expand_as(pred)).detach().cpu().numpy()  # (1,N)

            metrics.add_dict({
                'minus_loss': -1 * np.sum(loss.detach().cpu().numpy()),
                'correct': np.sum(correct),
                'cnt': len(data)
            })
            del loss, correct, pred, data, label
        aug_accs.append(metrics['correct'] / metrics['cnt'])
    del model
    affs = []
    for aug_valid, clean_valid in zip(aug_accs, aff_bases):
        affs.append(aug_valid - clean_valid)
    return affs
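A minimal driver sketch for get_affinity, assuming the FastAutoAugment-style helpers used above (C, get_dataloaders, get_model) are importable; every concrete value below is a hypothetical placeholder, not taken from the source.

# Hypothetical usage; paths, accuracies, and the policy name are placeholders.
aff_bases = [0.912, 0.908]  # clean-validation top-1 accuracy per CV fold
augment = {
    'load_paths': ['cv0_ckpt.pth', 'cv1_ckpt.pth'],  # per-fold checkpoints
    'cv_num': 2,
    'dataroot': './data',
    'cv_ratio_test': 0.4,
}
affs = get_affinity('my_policy', aff_bases, C.get().conf, augment)
# affs[i] = (augmented valid accuracy) - (clean baseline accuracy) for fold i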
Example No. 2
def eval_tta3(config, augment, reporter):
    C.get()
    C.get().conf = config
    save_path = augment['save_path']
    cv_id, gr_id = augment["cv_id"], augment["gr_id"]
    gr_ids = augment["gr_ids"]

    # setup - provided augmentation rules
    C.get()['aug'] = policy_decoder(augment, augment['num_policy'],
                                    augment['num_op'])

    # eval
    model = get_model(C.get()['model'], num_class(C.get()['dataset']))
    ckpt = torch.load(save_path)
    if 'model' in ckpt:
        model.load_state_dict(ckpt['model'])
    else:
        model.load_state_dict(ckpt)
    del ckpt
    model.eval()

    loader = get_post_dataloader(C.get()["dataset"],
                                 C.get()['batch'], augment["dataroot"],
                                 augment['cv_ratio_test'], cv_id, gr_id,
                                 gr_ids)

    start_t = time.time()
    metrics = Accumulator()
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    for data, label in loader:
        data = data.cuda()
        label = label.cuda()

        pred = model(data)
        loss = loss_fn(pred, label)  # (N)

        _, pred = pred.topk(1, 1, True, True)
        pred = pred.t()
        correct = pred.eq(label.view(
            1, -1).expand_as(pred)).detach().cpu().numpy()  # (1,N)

        metrics.add_dict({
            'loss': np.sum(loss.detach().cpu().numpy()),
            'correct': np.sum(correct),
            'cnt': len(data)
        })
        del loss, correct, pred, data, label
    del model, loader
    metrics = metrics / 'cnt'
    gpu_secs = (time.time() - start_t) * torch.cuda.device_count()
    reporter(loss=metrics['loss'],
             top1_valid=metrics['correct'],
             elapsed_time=gpu_secs,
             done=True)
    return metrics['correct']
Example No. 3
def eval_tta2(config, augment, reporter):
    C.get()
    C.get().conf = config
    cv_ratio_test, cv_id, save_path = augment['cv_ratio_test'], augment['cv_id'], augment['save_path']
    gr_id = augment["gr_id"]
    num_repeat = 1

    # setup - provided augmentation rules
    C.get()['aug'] = policy_decoder(augment, augment['num_policy'], augment['num_op'])

    # eval
    model = get_model(C.get()['model'], num_class(C.get()['dataset']))
    ckpt = torch.load(save_path)
    if 'model' in ckpt:
        model.load_state_dict(ckpt['model'])
    else:
        model.load_state_dict(ckpt)
    model.eval()

    loaders = []
    for i in range(num_repeat):
        _, tl, validloader, tl2 = get_dataloaders(C.get()['dataset'], C.get()['batch'],
                                                  augment['dataroot'], cv_ratio_test,
                                                  split_idx=cv_id,
                                                  gr_assign=augment["gr_assign"],
                                                  gr_id=gr_id)
        loaders.append(validloader)
        del tl, tl2

    start_t = time.time()
    metrics = Accumulator()
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    for loader in loaders:
        for data, label in loader:
            data = data.cuda()
            label = label.cuda()

            pred = model(data)
            loss = loss_fn(pred, label) # (N)

            _, pred = pred.topk(1, 1, True, True)
            pred = pred.t()
            correct = pred.eq(label.view(1, -1).expand_as(pred)).detach().cpu().numpy() # (1,N)

            metrics.add_dict({
                'minus_loss': -1 * np.sum(loss.detach().cpu().numpy()),
                'correct': np.sum(correct),
                'cnt': len(data)
            })
            del loss, correct, pred, data, label
    del model
    metrics = metrics / 'cnt'
    gpu_secs = (time.time() - start_t) * torch.cuda.device_count()
    reporter(minus_loss=metrics['minus_loss'], top1_valid=metrics['correct'], elapsed_time=gpu_secs, done=True)
    return metrics['correct']
Example No. 4
def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1, scheduler=None):
    tqdm_disable = bool(os.environ.get('TASK_NAME', ''))    # KakaoBrain Environment
    if verbose:
        loader = tqdm(loader, disable=tqdm_disable)
        loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch']))

    metrics = Accumulator()
    cnt = 0
    total_steps = len(loader)
    steps = 0
    for data, label in loader:
        steps += 1
        data, label = data.cuda(), label.cuda()

        if optimizer:
            optimizer.zero_grad()

        preds = model(data)
        loss = loss_fn(preds, label)

        if optimizer:
            loss.backward()
            if getattr(optimizer, "synchronize", None):
                optimizer.synchronize()     # for horovod
            if C.get()['optimizer'].get('clip', 5) > 0:
                nn.utils.clip_grad_norm_(model.parameters(), C.get()['optimizer'].get('clip', 5))
            optimizer.step()

        top1, top5 = accuracy(preds, label, (1, 5))
        metrics.add_dict({
            'loss': loss.item() * len(data),
            'top1': top1.item() * len(data),
            'top5': top5.item() * len(data),
        })
        cnt += len(data)
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                postfix['lr'] = optimizer.param_groups[0]['lr']
            loader.set_postfix(postfix)

        if scheduler is not None:
            scheduler.step(epoch - 1 + float(steps) / total_steps)

        del preds, loss, top1, top5, data, label

    if tqdm_disable:
        if optimizer:
            logger.info('[%s %03d/%03d] %s lr=%.6f', desc_default, epoch, C.get()['epoch'], metrics / cnt, optimizer.param_groups[0]['lr'])
        else:
            logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt)

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    if verbose:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    return metrics
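A sketch of how this run_epoch is typically driven from an outer training loop; the model, loaders, writers, and hyperparameters below are illustrative assumptions, not the source's training script.

# Hypothetical outer loop; every name below is a placeholder.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=C.get()['epoch'])
loss_fn = torch.nn.CrossEntropyLoss()
for epoch in range(1, C.get()['epoch'] + 1):
    model.train()
    run_epoch(model, trainloader, loss_fn, optimizer, desc_default='train',
              epoch=epoch, writer=writer_train, scheduler=scheduler)
    model.eval()
    with torch.no_grad():
        run_epoch(model, validloader, loss_fn, None, desc_default='*valid',
                  epoch=epoch, writer=writer_valid)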
Example No. 5
def run_epoch(model,
              loader,
              loss_fn,
              optimizer,
              desc_default='',
              epoch=0,
              writer=None,
              verbose=1):
    if verbose:
        loader = tqdm(loader)
        if optimizer:
            curr_lr = optimizer.param_groups[0]['lr']
            loader.set_description(
                '[%s %04d/%04d] lr=%.4f' %
                (desc_default, epoch, C.get()['epoch'], curr_lr))
        else:
            loader.set_description('[%s %04d/%04d]' %
                                   (desc_default, epoch, C.get()['epoch']))

    metrics = Accumulator()
    cnt = 0
    for data, label in loader:
        data, label = data.cuda(), label.cuda()

        if optimizer:
            optimizer.zero_grad()

        preds = model(data)
        loss = loss_fn(preds, label)

        if optimizer:
            loss.backward()
            # clip only after backward(), once the gradients actually exist
            nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

        top1, top5 = accuracy(preds, label, (1, 5))

        metrics.add_dict({
            'loss': loss.item() * len(data),
            'top1': top1.item() * len(data),
            'top5': top5.item() * len(data),
        })
        cnt += len(data)
        if verbose:
            loader.set_postfix(metrics / cnt)

        del preds, loss, top1, top5, data, label

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    if verbose:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    return metrics
Example No. 6
def eval_tta(config, augment, reporter, num_class, get_model, get_dataloaders):
    C.get()
    C.get().conf = config
    cv_ratio_test, cv_fold, save_path = (
        augment["cv_ratio_test"],
        augment["cv_fold"],
        augment["save_path"],
    )

    # setup - provided augmentation rules
    C.get()["aug"] = policy_decoder(augment, augment["num_policy"], augment["num_op"])

    # eval
    model = get_model(C.get()["model"], num_class(C.get()["dataset"]))
    ckpt = torch.load(save_path)
    if "model" in ckpt:
        model.load_state_dict(ckpt["model"])
    else:
        model.load_state_dict(ckpt)
    model.eval()

    loaders = []
    for _ in range(augment["num_policy"]):  # TODO
        _, tl, validloader, tl2 = get_dataloaders(
            C.get()["dataset"],
            C.get()["batch"],
            augment["dataroot"],
            cv_ratio_test,
            split_idx=cv_fold,
        )
        loaders.append(iter(validloader))
        del tl, tl2

    start_t = time.time()
    metrics = Accumulator()
    loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    try:
        while True:
            losses = []
            corrects = []
            for loader in loaders:
                data, label = next(loader)
                data = data.cuda()
                label = label.cuda()

                pred = model(data)

                loss = loss_fn(pred, label)
                losses.append(loss.detach().cpu().numpy().reshape(1, -1))  # (1, N) so the min below is per-sample

                _, pred = pred.topk(1, 1, True, True)
                pred = pred.t()
                correct = (
                    pred.eq(label.view(1, -1).expand_as(pred)).detach().cpu().numpy()
                )
                corrects.append(correct)
                del loss, correct, pred, data, label

            losses = np.concatenate(losses)
            losses_min = np.min(losses, axis=0).squeeze()

            corrects = np.concatenate(corrects)
            corrects_max = np.max(corrects, axis=0).squeeze()
            metrics.add_dict(
                {
                    "minus_loss": -1 * np.sum(losses_min),
                    "correct": np.sum(corrects_max),
                    "cnt": len(corrects_max),
                }
            )
            del corrects, corrects_max
    except StopIteration:
        pass

    del model
    metrics = metrics / "cnt"
    gpu_secs = (time.time() - start_t) * torch.cuda.device_count()
    reporter(
        minus_loss=metrics["minus_loss"],
        top1_valid=metrics["correct"],
        elapsed_time=gpu_secs,
        done=True,
    )
    return metrics["correct"]
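The loop above keeps, for each validation sample, the best outcome across the num_policy augmented copies: the per-sample minimum loss and the per-sample maximum (any-policy-correct) top-1 flag. A self-contained numpy illustration of that reduction, with made-up numbers:

import numpy as np
# 2 policies x 3 samples; each per-policy array is shaped (1, N) as above
losses = [np.array([[0.9, 0.2, 1.5]]), np.array([[0.4, 0.7, 1.1]])]
corrects = [np.array([[0, 1, 0]]), np.array([[1, 1, 0]])]
losses_min = np.min(np.concatenate(losses), axis=0)      # [0.4, 0.2, 1.1]
corrects_max = np.max(np.concatenate(corrects), axis=0)  # [1, 1, 0]
# the Accumulator then receives -losses_min.sum(), corrects_max.sum(), and N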
Example No. 7
def run_epoch(model,
              loader,
              loss_fn,
              optimizer,
              desc_default='',
              epoch=0,
              writer=None,
              verbose=1,
              scheduler=None,
              is_master=True,
              ema=None,
              wd=0.0,
              tqdm_disabled=False):
    if verbose:
        loader = tqdm(loader, disable=tqdm_disabled)
        loader.set_description('[%s %04d/%04d]' %
                               (desc_default, epoch, C.get()['epoch']))

    params_without_bn = [
        params for name, params in model.named_parameters()
        if not ('_bn' in name or '.bn' in name)
    ]

    loss_ema = None
    metrics = Accumulator()
    cnt = 0
    total_steps = len(loader)
    steps = 0
    for data, label in loader:
        steps += 1
        data, label = data.cuda(), label.cuda()

        if C.get().conf.get('mixup', 0.0) <= 0.0 or optimizer is None:
            preds = model(data)
            loss = loss_fn(preds, label)
        else:  # mixup
            data, targets, shuffled_targets, lam = mixup(
                data, label,
                C.get()['mixup'])
            preds = model(data)
            loss = loss_fn(preds, targets, shuffled_targets, lam)
            del shuffled_targets, lam

        if optimizer:
            loss += wd * (1. / 2.) * sum(
                [torch.sum(p**2) for p in params_without_bn])
            loss.backward()
            grad_clip = C.get()['optimizer'].get('clip', 5.0)
            if grad_clip > 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()

            if ema is not None:
                ema(model, (epoch - 1) * total_steps + steps)

        top1, top5 = accuracy(preds, label, (1, 5))
        metrics.add_dict({
            'loss': loss.item() * len(data),
            'top1': top1.item() * len(data),
            'top5': top5.item() * len(data),
        })
        cnt += len(data)
        if loss_ema is not None:
            loss_ema = loss_ema * 0.9 + loss.item() * 0.1
        else:
            loss_ema = loss.item()
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                postfix['lr'] = optimizer.param_groups[0]['lr']
            postfix['loss_ema'] = loss_ema
            loader.set_postfix(postfix)

        if scheduler is not None:
            scheduler.step(epoch - 1 + float(steps) / total_steps)

        del preds, loss, top1, top5, data, label

    if tqdm_disabled and verbose:
        if optimizer:
            logger.info('[%s %03d/%03d] %s lr=%.6f', desc_default, epoch,
                        C.get()['epoch'], metrics / cnt,
                        optimizer.param_groups[0]['lr'])
        else:
            logger.info('[%s %03d/%03d] %s', desc_default, epoch,
                        C.get()['epoch'], metrics / cnt)

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    if verbose:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    return metrics
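Example No. 7 applies weight decay by hand inside the loss and exempts BatchNorm parameters. For SGD the same policy can be expressed through optimizer parameter groups, since weight_decay=wd adds wd*p to the gradient, which is exactly the gradient of the explicit wd * (1/2) * sum(p**2) term above; a sketch, not the author's code:

# Equivalent L2 handling via parameter groups (placeholder lr/momentum values).
decay, no_decay = [], []
for name, p in model.named_parameters():
    (no_decay if ('_bn' in name or '.bn' in name) else decay).append(p)
optimizer = torch.optim.SGD(
    [{'params': decay, 'weight_decay': wd},  # wd as in run_epoch's argument
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=0.1, momentum=0.9)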
Example No. 8
def eval_tta(config, augment):
    C.get()
    C.get().conf = config
    cv_ratio_test, cv_fold, save_path = (augment['cv_ratio_test'],
                                         augment['cv_fold'],
                                         augment['save_path'])
    print(augment)
    # setup - provided augmentation rules
    C.get()['aug'] = policy_decoder(augment, augment['num_policy'],
                                    augment['num_op'])

    # eval
    model = get_model(C.get()['model'], num_class(C.get()['dataset']))
    ckpt = torch.load(save_path)
    if 'model' in ckpt:
        model.load_state_dict(ckpt['model'])
    else:
        model.load_state_dict(ckpt)
    model.eval()

    loaders = []
    for _ in range(augment['num_policy']):  # TODO
        _, tl, validloader, tl2 = get_dataloaders(C.get()['dataset'],
                                                  C.get()['batch'],
                                                  augment['dataroot'],
                                                  cv_ratio_test,
                                                  split_idx=cv_fold)
        loaders.append(iter(validloader))
        del tl, tl2

    start_t = time.time()
    metrics = Accumulator()
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    try:
        while True:
            losses = []
            corrects = []
            for loader in loaders:
                data, label = next(loader)
                data = data.cuda()
                label = label.cuda()

                pred = model(data)

                loss = loss_fn(pred, label)
                losses.append(loss.detach().cpu().numpy().reshape(1, -1))  # (1, N) so the min below is per-sample

                _, pred = pred.topk(1, 1, True, True)
                pred = pred.t()
                correct = pred.eq(label.view(
                    1, -1).expand_as(pred)).detach().cpu().numpy()
                corrects.append(correct)
                del loss, correct, pred, data, label

            losses = np.concatenate(losses)
            losses_min = np.min(losses, axis=0).squeeze()

            corrects = np.concatenate(corrects)
            corrects_max = np.max(corrects, axis=0).squeeze()
            metrics.add_dict({
                'minus_loss': -1 * np.sum(losses_min),
                'correct': np.sum(corrects_max),
                'cnt': len(corrects_max)
            })
            del corrects, corrects_max
    except StopIteration:
        pass

    del model
    metrics = metrics / 'cnt'
    gpu_secs = (time.time() - start_t) * torch.cuda.device_count()
    # reporter(minus_loss=metrics['minus_loss'], top1_valid=metrics['correct'], elapsed_time=gpu_secs, done=True)
    tune.track.log(minus_loss=metrics['minus_loss'],
                   top1_valid=metrics['correct'],
                   elapsed_time=gpu_secs,
                   done=True)
    return metrics['correct']
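Example No. 8 reports through tune.track.log, the pre-1.0 Ray Tune API; on Ray 1.x the same metrics would be sent with tune.report (a version note and sketch, not part of the source):

# Ray 1.x equivalent of the tune.track.log call above (assumes Ray >= 1.0).
from ray import tune
tune.report(minus_loss=metrics['minus_loss'],
            top1_valid=metrics['correct'],
            elapsed_time=gpu_secs)
# Tune marks the trial as done when the trainable function returns.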
Example No. 9
def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1, scheduler=None):
    model_name = C.get()['model']['type']
    alpha = C.get()['alpha']
    skip_ratios = ListAverageMeter()
    tqdm_disable = bool(os.environ.get('TASK_NAME', ''))
    if verbose:
        loader = tqdm(loader, disable=tqdm_disable)
        loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch']))

    metrics = Accumulator()
    cnt = 0
    total_steps = len(loader)
    steps = 0
    for data, label in loader:
        steps += 1
        data, label = data.cuda(), label.cuda()

        if optimizer:
            optimizer.zero_grad()

        if model_name == 'pyramid_skip':
            if desc_default == '*test':
                with torch.no_grad():
                    preds, masks, gprobs = model(data)
                skips = [mask.data.le(0.5).float().mean() for mask in masks]
                if skip_ratios.len != len(skips):
                    skip_ratios.set_len(len(skips))
                skip_ratios.update(skips, data.size(0))
            else:
                preds, masks, gprobs = model(data)

            sparsity_loss = 0
            for mask in masks:
                sparsity_loss += mask.mean()
            loss1 = loss_fn(preds, label)
            loss2 = alpha * sparsity_loss
            loss = loss1 + loss2
        else:
            preds = model(data)
            loss = loss_fn(preds, label)

        if optimizer:
            loss.backward()
            if getattr(optimizer, "synchronize", None):
                optimizer.synchronize()     # for horovod, as in Example No. 4
            if C.get()['optimizer'].get('clip', 5) > 0:
                nn.utils.clip_grad_norm_(model.parameters(), C.get()['optimizer'].get('clip', 5))

            optimizer.step()

        top1, top5 = accuracy(preds, label, (1, 5))

        if model_name == 'pyramid_skip':
            metrics.add_dict({
                'loss1': loss1.item() * len(data),
                'loss2': loss2.item() * len(data),
                'top1': top1.item() * len(data),
                'top5': top5.item() * len(data),
            })
        else:
            metrics.add_dict({
                'loss': loss.item() * len(data),
                'top1': top1.item() * len(data),
                'top5': top5.item() * len(data),
            })
        cnt += len(data)
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                postfix['lr'] = optimizer.param_groups[0]['lr']
            loader.set_postfix(postfix)

        # if scheduler is not None:
        #     scheduler.step(epoch - 1 + float(steps) / total_steps)

        if model_name == 'pyramid_skip':
            del masks[:], gprobs[:]
        del preds, loss, top1, top5, data, label

    if model_name == 'pyramid_skip':
        if desc_default == '*test':
            skip_summaries = []
            for idx in range(skip_ratios.len):
                skip_summaries.append(1 - skip_ratios.avg[idx])
            cp = ((sum(skip_summaries) + 1) / (len(skip_summaries) + 1)) * 100

    if tqdm_disable:
        logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt)

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    if verbose:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
        if model_name == 'pyramid_skip':
            if desc_default == '*test':
                writer.add_scalar('Computation Percentage', cp, epoch)
    return metrics
Example No. 10
def train_controller(controller, dataloaders, save_path, ctl_save_path):
    dataset = C.get()['test_dataset']
    ctl_train_steps = 1500
    ctl_num_aggre = 10
    ctl_entropy_w = 1e-5
    ctl_ema_weight = 0.95
    metrics = Accumulator()
    cnt = 0

    controller.train()
    test_ratio = 0.
    _, _, dataloader, _ = dataloaders  # validloader
    optimizer = optim.SGD(controller.parameters(),
                          lr=0.00035,
                          momentum=0.9,
                          weight_decay=0.0,
                          nesterov=True)
    # optimizer = optim.Adam(controller.parameters(), lr = 0.00035)
    # create a model & a criterion
    model = get_model(C.get()['model'], num_class(dataset), local_rank=-1)
    criterion = CrossEntropyLabelSmooth(num_class(dataset),
                                        C.get().conf.get('lb_smooth', 0),
                                        reduction="batched_sum").cuda()
    # load model weights
    data = torch.load(save_path)
    key = 'model' if 'model' in data else 'state_dict'

    if 'epoch' not in data:
        model.load_state_dict(data)
    else:
        logger.info('checkpoint epoch@%d' % data['epoch'])
        if not isinstance(model, (DataParallel, DistributedDataParallel)):
            model.load_state_dict(
                {k.replace('module.', ''): v
                 for k, v in data[key].items()})
        else:
            model.load_state_dict({
                k if 'module.' in k else 'module.' + k: v
                for k, v in data[key].items()
            })
    del data

    model.eval()
    loader_iter = iter(dataloader)  # [(image)->ToTensor->Normalize]
    baseline = None
    if os.path.isfile(ctl_save_path):
        logger.info('------Controller load------')
        checkpoint = torch.load(ctl_save_path)
        controller.load_state_dict(checkpoint['ctl_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        cnt = checkpoint['cnt']
        mean_probs = checkpoint['mean_probs']
        accs = checkpoint['accs']
        metrics_dict = checkpoint['metrics']
        metrics.metrics = metrics_dict
        init_step = checkpoint['step']
    else:
        logger.info('------Train Controller from scratch------')
        mean_probs = []
        accs = []
        init_step = 0
    for step in tqdm(range(init_step + 1,
                           ctl_train_steps * ctl_num_aggre + 1)):
        try:
            inputs, labels = next(loader_iter)
        except StopIteration:  # restart the validation iterator when exhausted
            loader_iter = iter(dataloader)
            inputs, labels = next(loader_iter)
        batch_size = len(labels)
        inputs, labels = inputs.cuda(), labels.cuda()
        log_probs, entropys, sampled_policies = controller(inputs)
        # evaluate model with augmented validation dataset
        with torch.no_grad():
            # compare Accuracy before/after augmentation
            # ori_preds = model(inputs)
            # ori_top1, ori_top5 = accuracy(ori_preds, labels, (1, 5))
            batch_policies = batch_policy_decoder(sampled_policies)  # (list:list:list:tuple) [batch, num_policy, n_op, 3]
            aug_inputs, applied_policy = augment_data(inputs, batch_policies)
            aug_inputs = aug_inputs.cuda()
            # assert type(aug_inputs) == torch.Tensor, "Augmented Input Type Error: {}".format(type(aug_inputs))
            preds = model(aug_inputs)
            model_losses = criterion(preds, labels)  # (tensor)[batch]
            top1, top5 = accuracy(preds, labels, (1, 5))
            # logger.info("Acc B/A Aug, {:.2f}->{:.2f}".format(ori_top1, top1))
        # assert model_losses.shape == entropys.shape == log_probs.shape, \
        #         "[Size miss match] loss: {}, entropy: {}, log_prob: {}".format(model_losses.shape, entropys.shape, log_probs.shape)
        rewards = -model_losses + ctl_entropy_w * entropys  # (tensor)[batch]
        if baseline is None:
            baseline = -model_losses.mean()  # scalar tensor
        else:
            # assert baseline, "len(baseline): {}".format(len(baseline))
            baseline = baseline - (1 - ctl_ema_weight) * (
                baseline - rewards.mean().detach())
        # baseline = 0.
        loss = -1 * (log_probs * (rewards - baseline)).mean()  # scalar tensor
        # Average gradient over controller_num_aggregate samples
        loss = loss / ctl_num_aggre
        loss.backward(retain_graph=True)
        metrics.add_dict({
            'loss': loss.item() * batch_size,
            'top1': top1.item() * batch_size,
            'top5': top5.item() * batch_size,
        })
        cnt += batch_size
        if (step + 1) % ctl_num_aggre == 0:
            torch.nn.utils.clip_grad_norm_(controller.parameters(), 5.0)
            optimizer.step()
            controller.zero_grad()
            # torch.cuda.empty_cache()
            logger.info('\n[Train Controller %03d/%03d] log_prob %02f, %s',
                        step, ctl_train_steps * ctl_num_aggre,
                        log_probs.mean().item(), metrics / cnt)
        if step % 100 == 0 or step == ctl_train_steps * ctl_num_aggre:
            save_pic(inputs, aug_inputs, labels, applied_policy,
                     batch_policies, step)
            ps = []
            for pol in batch_policies:  # (list:list:list:tuple) [batch, num_policy, n_op, 3]
                for ops in pol:
                    for op in ops:
                        p = op[1]
                        ps.append(p)
            mean_prob = np.mean(ps)
            mean_probs.append(mean_prob)
            accs.append(top1.item())
            print("Mean probability: {:.2f}".format(mean_prob))
            torch.save(
                {
                    'step': step,
                    'ctl_state_dict': controller.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'metrics': dict(metrics.metrics),
                    'cnt': cnt,
                    'mean_probs': mean_probs,
                    'accs': accs
                }, ctl_save_path)
    return metrics, None  # baseline.item()
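The controller update above is REINFORCE with an exponential-moving-average baseline: the reward is -model_loss + ctl_entropy_w * entropy, the baseline tracks the mean reward via b <- b - (1 - ctl_ema_weight) * (b - mean(r)), and the policy loss is -(log_prob * (r - b)).mean(). A toy restatement with placeholder tensors:

import torch
ctl_ema_weight = 0.95
log_probs = torch.tensor([-0.5, -1.0, -0.7], requires_grad=True)  # toy values
rewards = torch.tensor([1.2, 0.8, 1.0])     # -model_loss + entropy bonus
baseline = torch.tensor(0.9)                # running EMA of the mean reward
baseline = baseline - (1 - ctl_ema_weight) * (baseline - rewards.mean())
loss = -(log_probs * (rewards - baseline)).mean()
loss.backward()  # ascends expected reward with reduced gradient variance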
Example No. 11
def eval_tta(config, augment, reporter):
    C.get()
    C.get().conf = config
    save_path = augment['save_path']
    cv_id, gr_id = augment["cv_id"], augment["gr_id"]
    gr_ids = augment["gr_ids"]

    # setup - provided augmentation rules
    C.get()['aug'] = policy_decoder(augment, augment['num_policy'],
                                    augment['num_op'])

    # eval
    model = get_model(C.get()['model'], num_class(C.get()['dataset']))
    ckpt = torch.load(save_path)
    if 'model' in ckpt:
        model.load_state_dict(ckpt['model'])
    else:
        model.load_state_dict(ckpt)
    model.eval()

    loaders = []
    for _ in range(augment['num_policy']):  # TODO
        loader = get_post_dataloader(C.get()["dataset"],
                                     C.get()['batch'], augment["dataroot"],
                                     augment['cv_ratio_test'], cv_id, gr_id,
                                     gr_ids)
        loaders.append(iter(loader))

    start_t = time.time()
    metrics = Accumulator()
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    try:
        while True:
            losses = []
            corrects = []
            for loader in loaders:
                data, label = next(loader)
                data = data.cuda()
                label = label.cuda()

                pred = model(data)

                loss = loss_fn(pred, label)
                losses.append(loss.detach().cpu().numpy().reshape(1, -1))  # (1, N)

                _, pred = pred.topk(1, 1, True, True)
                pred = pred.t()
                correct = pred.eq(label.view(
                    1, -1).expand_as(pred)).detach().cpu().numpy()  # (1,N)
                corrects.append(correct)
                del loss, correct, pred, data, label

            losses = np.concatenate(losses)
            losses_min = np.min(losses, axis=0).squeeze()  # (N,)

            corrects = np.concatenate(corrects)
            corrects_max = np.max(corrects, axis=0).squeeze()  # (N,)
            metrics.add_dict({
                'loss': np.sum(losses_min),
                'correct': np.sum(corrects_max),
                'cnt': corrects_max.size
            })
            del corrects, corrects_max
    except StopIteration:
        pass

    del model
    metrics = metrics / 'cnt'
    gpu_secs = (time.time() - start_t) * torch.cuda.device_count()
    reporter(loss=metrics['loss'],
             top1_valid=metrics['correct'],
             elapsed_time=gpu_secs,
             done=True)
    return metrics['correct']
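A minimal way to exercise this eval_tta outside of Ray, with a stub reporter; the augment keys mirror the ones read above, every concrete value is a placeholder, and policy_decoder additionally expects the per-policy operation keys produced by the search space:

# Hypothetical standalone call; all values are placeholders.
def reporter(**kwargs):  # stand-in for Ray's status-reporting callback
    print(kwargs)

augment = {
    'save_path': 'ckpt.pth', 'cv_id': 0, 'gr_id': 0, 'gr_ids': [0],
    'dataroot': './data', 'cv_ratio_test': 0.4,
    'num_policy': 5, 'num_op': 2,
    # ...plus the policy-parameter keys consumed by policy_decoder
}
top1 = eval_tta(C.get().conf, augment, reporter)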