# Example 1
# 0
def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1, scheduler=None):
    """Run one training (optimizer given) or evaluation (optimizer=None) epoch.

    Args:
        model: network to run; batches are moved to GPU with ``.cuda()``.
        loader: iterable of ``(data, label)`` batches.
        loss_fn: criterion mapping ``(preds, label)`` to a scalar loss.
        optimizer: optimizer to step, or ``None`` for evaluation-only epochs.
        desc_default: tag used in progress descriptions and log lines.
        epoch: 1-based epoch index, used for display and scheduler stepping.
        writer: optional summary writer; epoch metrics are logged when given.
        verbose: when truthy, wrap ``loader`` in tqdm and show live metrics.
        scheduler: optional LR scheduler, stepped fractionally each batch.

    Returns:
        Accumulator with per-sample averages of 'loss', 'top1', 'top5'
        (plus 'lr' when an optimizer was supplied).
    """
    tqdm_disable = bool(os.environ.get('TASK_NAME', ''))    # KakaoBrain Environment
    if verbose:
        loader = tqdm(loader, disable=tqdm_disable)
        loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch']))

    metrics = Accumulator()
    cnt = 0
    total_steps = len(loader)
    steps = 0
    for data, label in loader:
        steps += 1
        data, label = data.cuda(), label.cuda()

        if optimizer:
            optimizer.zero_grad()

        preds = model(data)
        loss = loss_fn(preds, label)

        if optimizer:
            loss.backward()
            if getattr(optimizer, "synchronize", None):
                optimizer.synchronize()     # for horovod
            # Read the clip threshold once instead of twice per step.
            grad_clip = C.get()['optimizer'].get('clip', 5)
            if grad_clip > 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        top1, top5 = accuracy(preds, label, (1, 5))
        # Accumulate batch-size-weighted sums; normalized by cnt at the end so
        # an uneven final batch is averaged correctly.
        metrics.add_dict({
            'loss': loss.item() * len(data),
            'top1': top1.item() * len(data),
            'top5': top5.item() * len(data),
        })
        cnt += len(data)
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                postfix['lr'] = optimizer.param_groups[0]['lr']
            loader.set_postfix(postfix)

        if scheduler is not None:
            # Fractional epoch position: epoch is 1-based, so this sweeps [epoch-1, epoch).
            scheduler.step(epoch - 1 + float(steps) / total_steps)

        del preds, loss, top1, top5, data, label

    if tqdm_disable:
        if optimizer:
            logger.info('[%s %03d/%03d] %s lr=%.6f', desc_default, epoch, C.get()['epoch'], metrics / cnt, optimizer.param_groups[0]['lr'])
        else:
            logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt)

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    # Fix: writer defaults to None, so verbose alone used to crash here.
    if verbose and writer is not None:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    return metrics
# Example 2
# 0
def run_epoch(model,
              loader,
              loss_fn,
              optimizer,
              desc_default='',
              epoch=0,
              writer=None,
              verbose=1):
    """Run one training (optimizer given) or evaluation (optimizer=None) epoch.

    Args:
        model: network to run; batches are moved to GPU with ``.cuda()``.
        loader: iterable of ``(data, label)`` batches.
        loss_fn: criterion mapping ``(preds, label)`` to a scalar loss.
        optimizer: optimizer to step, or ``None`` for evaluation-only epochs.
        desc_default: tag used in the tqdm description.
        epoch: epoch index shown in the description and used by the writer.
        writer: optional summary writer; epoch metrics are logged when given.
        verbose: when truthy, wrap ``loader`` in tqdm and show live metrics.

    Returns:
        Accumulator with per-sample averages of 'loss', 'top1', 'top5'
        (plus 'lr' when an optimizer was supplied).
    """
    if verbose:
        loader = tqdm(loader)
        if optimizer:
            curr_lr = optimizer.param_groups[0]['lr']
            loader.set_description(
                '[%s %04d/%04d] lr=%.4f' %
                (desc_default, epoch, C.get()['epoch'], curr_lr))
        else:
            loader.set_description('[%s %04d/%04d]' %
                                   (desc_default, epoch, C.get()['epoch']))

    metrics = Accumulator()
    cnt = 0
    for data, label in loader:
        data, label = data.cuda(), label.cuda()

        if optimizer:
            optimizer.zero_grad()

        preds = model(data)
        loss = loss_fn(preds, label)

        if optimizer:
            # Fix: clip AFTER backward. Clipping before backward operates on
            # stale (or zero) gradients and leaves the fresh ones unclipped.
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

        top1, top5 = accuracy(preds, label, (1, 5))

        # Accumulate batch-size-weighted sums; normalized by cnt at the end.
        metrics.add_dict({
            'loss': loss.item() * len(data),
            'top1': top1.item() * len(data),
            'top5': top5.item() * len(data),
        })
        cnt += len(data)
        if verbose:
            loader.set_postfix(metrics / cnt)

        del preds, loss, top1, top5, data, label

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    # Fix: writer defaults to None, so verbose alone used to crash here.
    if verbose and writer is not None:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    return metrics
# Example 3
# 0
def run_epoch(model,
              loader,
              loss_fn,
              optimizer,
              desc_default='',
              epoch=0,
              writer=None,
              verbose=1,
              scheduler=None,
              is_master=True,
              ema=None,
              wd=0.0,
              tqdm_disabled=False):
    """Run one epoch with optional mixup, manual weight decay, and model EMA.

    Args:
        model: network to run; batches are moved to GPU with ``.cuda()``.
        loader: iterable of ``(data, label)`` batches.
        loss_fn: criterion; with mixup it takes ``(preds, targets,
            shuffled_targets, lam)``, otherwise ``(preds, label)``.
        optimizer: optimizer to step, or ``None`` for evaluation-only epochs.
        desc_default: tag used in progress descriptions and log lines.
        epoch: 1-based epoch index, used for display, EMA step, and scheduler.
        writer: optional summary writer; epoch metrics are logged when given.
        verbose: when truthy, wrap ``loader`` in tqdm and show live metrics.
        scheduler: optional LR scheduler, stepped fractionally each batch.
        is_master: kept for interface compatibility; unused in the body.
        ema: optional callable applying an exponential moving average to the
            model weights, invoked with ``(model, global_step)``.
        wd: manual L2 weight-decay coefficient applied to non-BN parameters.
        tqdm_disabled: suppress the tqdm bar; epoch summary is logged instead.

    Returns:
        Accumulator with per-sample averages of 'loss', 'top1', 'top5'
        (plus 'lr' when an optimizer was supplied).
    """
    if verbose:
        loader = tqdm(loader, disable=tqdm_disabled)
        loader.set_description('[%s %04d/%04d]' %
                               (desc_default, epoch, C.get()['epoch']))

    # Manual weight decay is applied to everything except BatchNorm params
    # (identified by '_bn'/'.bn' in the parameter name).
    params_without_bn = [
        params for name, params in model.named_parameters()
        if not ('_bn' in name or '.bn' in name)
    ]

    loss_ema = None
    metrics = Accumulator()
    cnt = 0
    total_steps = len(loader)
    steps = 0
    for data, label in loader:
        steps += 1
        data, label = data.cuda(), label.cuda()

        if C.get().conf.get('mixup', 0.0) <= 0.0 or optimizer is None:
            preds = model(data)
            loss = loss_fn(preds, label)
        else:  # mixup
            data, targets, shuffled_targets, lam = mixup(
                data, label,
                C.get()['mixup'])
            preds = model(data)
            loss = loss_fn(preds, targets, shuffled_targets, lam)
            del shuffled_targets, lam

        if optimizer:
            # Manual L2 penalty (0.5 * wd * ||w||^2) over non-BN parameters.
            loss += wd * (1. / 2.) * sum(
                [torch.sum(p**2) for p in params_without_bn])
            loss.backward()
            grad_clip = C.get()['optimizer'].get('clip', 5.0)
            if grad_clip > 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()

            if ema is not None:
                # Global step for the EMA schedule (epoch is 1-based).
                ema(model, (epoch - 1) * total_steps + steps)

        top1, top5 = accuracy(preds, label, (1, 5))
        metrics.add_dict({
            'loss': loss.item() * len(data),
            'top1': top1.item() * len(data),
            'top5': top5.item() * len(data),
        })
        cnt += len(data)
        # Fix: compare against None, not truthiness — a legitimate EMA of 0.0
        # would otherwise be treated as "uninitialized" and reset.
        if loss_ema is not None:
            loss_ema = loss_ema * 0.9 + loss.item() * 0.1
        else:
            loss_ema = loss.item()
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                postfix['lr'] = optimizer.param_groups[0]['lr']
            postfix['loss_ema'] = loss_ema
            loader.set_postfix(postfix)

        if scheduler is not None:
            scheduler.step(epoch - 1 + float(steps) / total_steps)

        del preds, loss, top1, top5, data, label

    if tqdm_disabled and verbose:
        if optimizer:
            logger.info('[%s %03d/%03d] %s lr=%.6f', desc_default, epoch,
                        C.get()['epoch'], metrics / cnt,
                        optimizer.param_groups[0]['lr'])
        else:
            logger.info('[%s %03d/%03d] %s', desc_default, epoch,
                        C.get()['epoch'], metrics / cnt)

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    # Fix: writer defaults to None, so verbose alone used to crash here.
    if verbose and writer is not None:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    return metrics
def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1, scheduler=None):
    """Run one epoch, with extra bookkeeping for the 'pyramid_skip' model.

    For 'pyramid_skip' the model returns ``(preds, masks, gprobs)``; the loss
    is the task loss plus ``alpha * mean(mask)`` sparsity penalty, and during
    '*test' the per-layer skip ratios are tracked to report a computation
    percentage.

    Args:
        model: network to run; batches are moved to GPU with ``.cuda()``.
        loader: iterable of ``(data, label)`` batches.
        loss_fn: criterion mapping ``(preds, label)`` to a scalar loss.
        optimizer: optimizer to step, or ``None`` for evaluation-only epochs.
        desc_default: tag for progress/log lines; '*test' enables skip stats.
        epoch: epoch index used for display and the writer.
        writer: optional summary writer; epoch metrics are logged when given.
        verbose: when truthy, wrap ``loader`` in tqdm and show live metrics.
        scheduler: accepted for interface compatibility; stepping is disabled.

    Returns:
        Accumulator with per-sample averages of the loss terms and top-1/5
        accuracy (plus 'lr' when an optimizer was supplied).
    """
    model_name = C.get()['model']['type']
    alpha = C.get()['alpha']        # weight of the mask-sparsity penalty
    skip_ratios = ListAverageMeter()
    tqdm_disable = bool(os.environ.get('TASK_NAME', ''))
    if verbose:
        loader = tqdm(loader, disable=tqdm_disable)
        loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch']))

    metrics = Accumulator()
    cnt = 0
    total_steps = len(loader)
    steps = 0
    for data, label in loader:
        steps += 1
        data, label = data.cuda(), label.cuda()

        if optimizer:
            optimizer.zero_grad()

        if model_name == 'pyramid_skip':
            if desc_default == '*test':
                with torch.no_grad():
                    preds, masks, gprobs = model(data)
                # Fraction of gate masks <= 0.5, i.e. layers that were skipped.
                skips = [mask.data.le(0.5).float().mean() for mask in masks]
                if skip_ratios.len != len(skips):
                    skip_ratios.set_len(len(skips))
                skip_ratios.update(skips, data.size(0))
            else:
                preds, masks, gprobs = model(data)

            sparsity_loss = 0
            for mask in masks:
                sparsity_loss += mask.mean()
            loss1 = loss_fn(preds, label)
            loss2 = alpha * sparsity_loss
            loss = loss1 + loss2
        else:
            preds = model(data)
            loss = loss_fn(preds, label)

        if optimizer:
            loss.backward()
            # Fix: the original called optimizer.skip_synchronize() as a bare
            # statement; in horovod that returns a context manager, so it was
            # a no-op and gradients were clipped before being averaged. Call
            # synchronize() instead, matching the sibling run_epoch variant.
            if getattr(optimizer, "synchronize", None):
                optimizer.synchronize()     # for horovod
            grad_clip = C.get()['optimizer'].get('clip', 5)
            if grad_clip > 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

        top1, top5 = accuracy(preds, label, (1, 5))

        if model_name == 'pyramid_skip':
            metrics.add_dict({
                'loss1': loss1.item() * len(data),
                'loss2': loss2.item() * len(data),
                'top1': top1.item() * len(data),
                'top5': top5.item() * len(data),
            })
        else:
            metrics.add_dict({
                'loss': loss.item() * len(data),
                'top1': top1.item() * len(data),
                'top5': top5.item() * len(data),
            })
        cnt += len(data)
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                postfix['lr'] = optimizer.param_groups[0]['lr']
            loader.set_postfix(postfix)

        # Per-batch scheduler stepping intentionally disabled here.
        # if scheduler is not None:
        #     scheduler.step(epoch - 1 + float(steps) / total_steps)

        if model_name == 'pyramid_skip':
            del masks[:], gprobs[:]
        del preds, loss, top1, top5, data, label

    if model_name == 'pyramid_skip':
        if desc_default == '*test':
            # Computation percentage: average fraction of layers executed
            # (1 - skip ratio), with +1 smoothing in numerator and denominator.
            skip_summaries = []
            for idx in range(skip_ratios.len):
                skip_summaries.append(1 - skip_ratios.avg[idx])
            cp = ((sum(skip_summaries) + 1) / (len(skip_summaries) + 1)) * 100

    if tqdm_disable:
        logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt)

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    # Fix: writer defaults to None, so verbose alone used to crash here.
    if verbose and writer is not None:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
        if model_name == 'pyramid_skip':
            if desc_default == '*test':
                writer.add_scalar('Computation Percentage', cp, epoch)
    return metrics
# Example 5
# 0
def train_controller(controller, dataloaders, save_path, ctl_save_path):
    """Train an augmentation-policy controller with REINFORCE.

    A frozen, pre-trained child model scores controller-sampled augmentation
    policies on the validation loader; the negative smoothed loss (plus an
    entropy bonus) is the reward, with an EMA baseline for variance reduction.
    Gradients are aggregated over ``ctl_num_aggre`` samples per optimizer step.
    Training resumes from ``ctl_save_path`` when a checkpoint exists.

    Args:
        controller: policy network; called on a batch it returns
            ``(log_probs, entropys, sampled_policies)``.
        dataloaders: 4-tuple of loaders; the third is used as the valid loader.
        save_path: checkpoint of the pre-trained child model to load.
        ctl_save_path: path for periodic controller checkpoints.

    Returns:
        Tuple ``(metrics, None)`` — accumulated (unnormalized) training
        metrics; the second slot is reserved for a baseline value.
    """
    dataset = C.get()['test_dataset']
    ctl_train_steps = 1500
    ctl_num_aggre = 10      # gradient-aggregation window per optimizer step
    ctl_entropy_w = 1e-5    # entropy bonus weight in the reward
    ctl_ema_weight = 0.95   # EMA decay for the reward baseline
    metrics = Accumulator()
    cnt = 0

    controller.train()
    test_ratio = 0.
    _, _, dataloader, _ = dataloaders  # validloader
    optimizer = optim.SGD(controller.parameters(),
                          lr=0.00035,
                          momentum=0.9,
                          weight_decay=0.0,
                          nesterov=True)
    # optimizer = optim.Adam(controller.parameters(), lr = 0.00035)
    # create a model & a criterion
    model = get_model(C.get()['model'], num_class(dataset), local_rank=-1)
    criterion = CrossEntropyLabelSmooth(num_class(dataset),
                                        C.get().conf.get('lb_smooth', 0),
                                        reduction="batched_sum").cuda()
    # load model weights
    data = torch.load(save_path)
    key = 'model' if 'model' in data else 'state_dict'

    if 'epoch' not in data:
        model.load_state_dict(data)
    else:
        logger.info('checkpoint epoch@%d' % data['epoch'])
        # Align 'module.' prefixes between the checkpoint and the model's
        # (non-)DataParallel wrapping.
        if not isinstance(model, (DataParallel, DistributedDataParallel)):
            model.load_state_dict(
                {k.replace('module.', ''): v
                 for k, v in data[key].items()})
        else:
            model.load_state_dict({
                k if 'module.' in k else 'module.' + k: v
                for k, v in data[key].items()
            })
    del data

    model.eval()
    loader_iter = iter(dataloader)  # [(image)->ToTensor->Normalize]
    baseline = None
    if os.path.isfile(ctl_save_path):
        logger.info('------Controller load------')
        checkpoint = torch.load(ctl_save_path)
        controller.load_state_dict(checkpoint['ctl_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        cnt = checkpoint['cnt']
        mean_probs = checkpoint['mean_probs']
        accs = checkpoint['accs']
        metrics_dict = checkpoint['metrics']
        metrics.metrics = metrics_dict
        init_step = checkpoint['step']
    else:
        logger.info('------Train Controller from scratch------')
        mean_probs = []
        accs = []
        init_step = 0
    for step in tqdm(range(init_step + 1,
                           ctl_train_steps * ctl_num_aggre + 1)):
        try:
            inputs, labels = next(loader_iter)
        except StopIteration:
            # Fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt and real errors. Restart the loader when
            # the validation set is exhausted.
            loader_iter = iter(dataloader)
            inputs, labels = next(loader_iter)
        batch_size = len(labels)
        inputs, labels = inputs.cuda(), labels.cuda()
        log_probs, entropys, sampled_policies = controller(inputs)
        # evaluate model with augmented validation dataset
        with torch.no_grad():
            # compare Accuracy before/after augmentation
            # ori_preds = model(inputs)
            # ori_top1, ori_top5 = accuracy(ori_preds, labels, (1, 5))
            batch_policies = batch_policy_decoder(
                sampled_policies
            )  # (list:list:list:tuple) [batch, num_policy, n_op, 3]
            aug_inputs, applied_policy = augment_data(inputs, batch_policies)
            aug_inputs = aug_inputs.cuda()
            preds = model(aug_inputs)
            model_losses = criterion(preds, labels)  # (tensor)[batch]
            top1, top5 = accuracy(preds, labels, (1, 5))
        # REINFORCE reward: lower child loss and higher policy entropy are both good.
        rewards = -model_losses + ctl_entropy_w * entropys  # (tensor)[batch]
        if baseline is None:
            baseline = -model_losses.mean()  # scalar tensor
        else:
            # EMA baseline tracking the mean reward for variance reduction.
            baseline = baseline - (1 - ctl_ema_weight) * (
                baseline - rewards.mean().detach())
        loss = -1 * (log_probs * (rewards - baseline)).mean()  # scalar tensor
        # Average gradient over controller_num_aggregate samples
        loss = loss / ctl_num_aggre
        loss.backward(retain_graph=True)
        metrics.add_dict({
            'loss': loss.item() * batch_size,
            'top1': top1.item() * batch_size,
            'top5': top5.item() * batch_size,
        })
        cnt += batch_size
        # NOTE(review): step starts at init_step + 1, so `(step + 1) % ...`
        # fires after 9 samples the first time and 10 thereafter — confirm
        # whether `step % ctl_num_aggre == 0` was intended.
        if (step + 1) % ctl_num_aggre == 0:
            torch.nn.utils.clip_grad_norm_(controller.parameters(), 5.0)
            optimizer.step()
            controller.zero_grad()
            logger.info('\n[Train Controller %03d/%03d] log_prob %02f, %s', step, ctl_train_steps*ctl_num_aggre, \
            log_probs.mean().item(), metrics / cnt
            )
        if step % 100 == 0 or step == ctl_train_steps * ctl_num_aggre:
            save_pic(inputs, aug_inputs, labels, applied_policy,
                     batch_policies, step)
            # Mean sampled operation probability — a rough measure of how
            # aggressively the controller augments.
            ps = []
            for pol in batch_policies:  # (list:list:list:tuple) [batch, num_policy, n_op, 3]
                for ops in pol:
                    for op in ops:
                        p = op[1]
                        ps.append(p)
            mean_prob = np.mean(ps)
            mean_probs.append(mean_prob)
            accs.append(top1.item())
            print("Mean probability: {:.2f}".format(mean_prob))
            torch.save(
                {
                    'step': step,
                    'ctl_state_dict': controller.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'metrics': dict(metrics.metrics),
                    'cnt': cnt,
                    'mean_probs': mean_probs,
                    'accs': accs
                }, ctl_save_path)
    return metrics, None  #baseline.item()