Example #1
def compute_loss(input, target, real):
    # real: (batch,) boolean mask; unsqueeze to (batch, 1) so it broadcasts
    # against the (batch, classes) targets below.
    real = real.unsqueeze(1)

    # Real examples keep hard one-hot targets; the rest are label-smoothed.
    target = utils.one_hot(target, NUM_CLASSES)
    target = torch.where(real, target, utils.label_smoothing(target, LABEL_SMOOTHING))

    # Per-sample cross entropy against the (possibly smoothed) soft targets.
    loss = softmax_cross_entropy(input=input, target=target)

    return loss
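
The helpers softmax_cross_entropy and utils.label_smoothing are used above but not shown. A minimal sketch of what the call sites imply (the bodies below are assumptions, not this repository's implementations):

import torch
import torch.nn.functional as F

def softmax_cross_entropy(input, target):
    # Cross entropy against soft targets, one value per sample.
    # input: (batch, classes) logits; target: (batch, classes) probabilities.
    return -(target * F.log_softmax(input, dim=1)).sum(dim=1)

def label_smoothing(target, eps):
    # Blend one-hot targets with the uniform distribution over classes.
    num_classes = target.size(1)
    return target * (1 - eps) + eps / num_classes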
Example #2
def train_epoch(model, optimizer, scheduler, data_loader, fold, epoch):
    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'train'))

    metrics = {
        'loss': utils.Mean(),
    }

    # Ramp the augmentation strength linearly from 0 to 1 over training.
    update_transforms(np.linspace(0, 1, config.epochs)[epoch - 1].item())
    model.train()
    optimizer.zero_grad()
    for i, (images, feats, _, labels, _) in enumerate(
            tqdm(data_loader, desc='epoch {} train'.format(epoch)), 1):
        images, feats, labels = images.to(DEVICE), feats.to(DEVICE), labels.to(
            DEVICE)
        labels = utils.one_hot(labels, NUM_CLASSES)
        images, labels = cutmix(images, labels)
        # The second and third model arguments are unused placeholders here.
        logits = model(images, object(), object())

        loss = compute_loss(input=logits, target=labels)
        metrics['loss'].update(loss.data.cpu().numpy())
        labels = labels.argmax(1)

        lr = np.squeeze(scheduler.get_lr())  # get_lr() returns a list of group lrs
        # Gradient accumulation: scale the loss so that acc_steps micro-batches
        # add up to one effective optimizer step.
        (loss.mean() / config.opt.acc_steps).backward()

        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        images = images_to_rgb(images)[:16]
        print('[FOLD {}][EPOCH {}][TRAIN] {}'.format(
            fold, epoch,
            ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_scalar('learning_rate', lr, global_step=epoch)
        writer.add_image('images',
                         torchvision.utils.make_grid(
                             images,
                             nrow=math.ceil(math.sqrt(images.size(0))),
                             normalize=True),
                         global_step=epoch)
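
cutmix is called above but not defined in the snippet. A common implementation for one-hot labels, matching the call site's signature, looks like the sketch below (the body follows the CutMix paper and is an assumption, not this repository's code):

import numpy as np
import torch

def cutmix(images, labels, alpha=1.0):
    # images: (batch, channels, height, width); labels: (batch, classes) one-hot.
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(images.size(0), device=images.device)

    h, w = images.size(2), images.size(3)
    cut_h, cut_w = int(h * np.sqrt(1 - lam)), int(w * np.sqrt(1 - lam))
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
    x1, x2 = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)

    # Paste a random box from shuffled images, then weight the labels by the
    # fraction of pixels each source contributed.
    images[:, :, y1:y2, x1:x2] = images[index, :, y1:y2, x1:x2]
    lam = 1 - (y2 - y1) * (x2 - x1) / (h * w)
    return images, lam * labels + (1 - lam) * labels[index]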
Example #3
def lsep_loss(input, target, exp):
    # Log-sum-exp pairwise (LSEP) ranking loss, computed per experiment group
    # so positives are only ranked against negatives from the same group.
    target = utils.one_hot(target, NUM_CLASSES)
    pos_mask = target > 0.5
    neg_mask = target <= 0.5

    loss = []
    for e in np.unique(exp):
        e_mask = torch.tensor(exp == e, dtype=pos_mask.dtype, device=input.device)
        e_mask = e_mask.unsqueeze(1)

        pos_examples = input[pos_mask & e_mask]
        neg_examples = input[neg_mask & e_mask]

        # Broadcast to all (positive, negative) pairs within the group.
        pos_examples = pos_examples.unsqueeze(1)
        neg_examples = neg_examples.unsqueeze(0)

        # loss_j = log(1 + sum_k exp(neg_k - pos_j))
        loss.append(torch.log(1 + torch.sum(torch.exp(neg_examples - pos_examples), 1)))

    loss = torch.cat(loss, 0)

    return loss
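
The inner expression is the log-sum-exp pairwise ranking term: for each positive score p_j and the group's negative scores n_k, the loss is log(1 + sum_k exp(n_k - p_j)). A self-contained numerical check of that term (the scores are made up for illustration):

import torch

pos = torch.tensor([2.0, 1.5])        # scores at positive label positions
neg = torch.tensor([0.5, 0.0, -1.0])  # scores at negative label positions
# Broadcast to all (positive, negative) pairs, as in lsep_loss above.
loss = torch.log(1 + torch.exp(neg.unsqueeze(0) - pos.unsqueeze(1)).sum(1))
print(loss)  # one loss value per positive example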
Example #4
def eval_epoch(model, data_loader, fold, epoch):
    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'eval'))

    metrics = {
        'loss': utils.Mean(),
    }

    model.eval()
    with torch.no_grad():
        fold_labels = []
        fold_logits = []
        fold_exps = []

        for images, exps, labels, _ in tqdm(
                data_loader, desc='epoch {} evaluation'.format(epoch)):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            labels = utils.one_hot(labels, NUM_CLASSES)
            logits = model(images, None)

            loss = compute_loss(input=logits, target=labels, unsup=False)
            metrics['loss'].update(loss.data.cpu().numpy())
            labels = labels.argmax(1)

            fold_labels.append(labels)
            fold_logits.append(logits)
            fold_exps.extend(exps)

        fold_labels = torch.cat(fold_labels, 0)
        fold_logits = torch.cat(fold_logits, 0)

        # Every 10 epochs, search for the best softmax temperature on this fold
        # and log the resulting metric and temperature curve.
        if epoch % 10 == 0:
            temp, metric, fig = find_temp_global(input=fold_logits,
                                                 target=fold_labels,
                                                 exps=fold_exps)
            writer.add_scalar('temp', temp, global_step=epoch)
            writer.add_scalar('metric_final', metric, global_step=epoch)
            writer.add_figure('temps', fig, global_step=epoch)

        temp = 1.  # the evaluation below always uses the default temperature
        fold_preds = assign_classes(probs=to_prob(fold_logits,
                                                  temp).data.cpu().numpy(),
                                    exps=fold_exps)
        fold_preds = torch.tensor(fold_preds).to(fold_logits.device)
        metric = compute_metric(input=fold_preds,
                                target=fold_labels,
                                exps=fold_exps)

        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        for k in metric:
            metrics[k] = metric[k].mean().data.cpu().numpy()
        images = images_to_rgb(images)[:16]
        print('[FOLD {}][EPOCH {}][EVAL] {}'.format(
            fold, epoch,
            ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_image('images',
                         torchvision.utils.make_grid(
                             images,
                             nrow=math.ceil(math.sqrt(images.size(0))),
                             normalize=True),
                         global_step=epoch)

        return metrics
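
to_prob is not shown in the snippet. Given that it takes logits and a scalar temperature and feeds probabilities to assign_classes, a plausible reading is a temperature-scaled softmax, sketched here as an assumption:

import torch

def to_prob(input, temp):
    # Temperature-scaled softmax: temp > 1 flattens the distribution,
    # temp < 1 sharpens it; temp = 1 is a plain softmax.
    return (input / temp).softmax(dim=1)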
Example #5
def train_epoch(model, optimizer, scheduler, data_loader, unsup_data_loader,
                fold, epoch):
    assert len(data_loader) <= len(unsup_data_loader), (len(data_loader),
                                                        len(unsup_data_loader))

    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'train'))

    metrics = {
        'loss': utils.Mean(),
    }

    update_transforms(np.linspace(0, 1, config.epochs)[epoch - 1].item())
    data = zip(data_loader, unsup_data_loader)
    total = min(len(data_loader), len(unsup_data_loader))
    model.train()
    optimizer.zero_grad()
    for i, ((images_s, _, labels_s, _), (images_u, _, _)) \
            in enumerate(tqdm(data, desc='epoch {} train'.format(epoch), total=total), 1):
        images_s, labels_s, images_u = images_s.to(DEVICE), labels_s.to(
            DEVICE), images_u.to(DEVICE)
        labels_s = utils.one_hot(labels_s, NUM_CLASSES)

        # Pseudo-label the unlabeled batch: average the model's predictions over
        # the n augmented views of each example, then sharpen the distribution.
        with torch.no_grad():
            b, n, c, h, w = images_u.size()
            images_u = images_u.view(b * n, c, h, w)
            logits_u = model(images_u, None, True)
            logits_u = logits_u.view(b, n, NUM_CLASSES)
            labels_u = logits_u.softmax(2).mean(1, keepdim=True)
            labels_u = labels_u.repeat(1, n, 1).view(b * n, NUM_CLASSES)
            labels_u = dist_sharpen(labels_u, temp=SHARPEN_TEMP)

        assert images_s.size() == images_u.size()
        assert labels_s.size() == labels_u.size()

        images, labels = torch.cat([images_s, images_u],
                                   0), torch.cat([labels_s, labels_u], 0)
        images, labels = mixup(images, labels)
        assert images.size(0) == config.batch_size * 2
        logits = model(images, None, True)

        loss = compute_loss(input=logits, target=labels, unsup=True)
        metrics['loss'].update(loss.data.cpu().numpy())
        labels = labels.argmax(1)

        lr = np.squeeze(scheduler.get_lr())  # get_lr() returns a list of group lrs
        (loss.mean() / config.opt.acc_steps).backward()

        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        images = images_to_rgb(images)[:16]
        print('[FOLD {}][EPOCH {}][TRAIN] {}'.format(
            fold, epoch,
            ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_scalar('learning_rate', lr, global_step=epoch)
        writer.add_image('images',
                         torchvision.utils.make_grid(
                             images,
                             nrow=math.ceil(math.sqrt(images.size(0))),
                             normalize=True),
                         global_step=epoch)
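
dist_sharpen is not defined in the snippet. The pseudo-labeling block above mirrors MixMatch, whose sharpening step raises the averaged predictions to the power 1/temp and renormalizes; a sketch under that assumption:

import torch

def dist_sharpen(probs, temp):
    # Sharpen a (batch, classes) distribution: temp < 1 pushes mass
    # toward the mode, approaching one-hot as temp -> 0.
    sharpened = probs**(1. / temp)
    return sharpened / sharpened.sum(dim=1, keepdim=True)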
Example #6
def lr_search(train_eval_data):
    train_eval_dataset = TrainEvalDataset(train_eval_data,
                                          transform=train_transform)
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_eval_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    # Exponential sweep: one pass over the loader takes the lr from min_lr to
    # max_lr, multiplying by gamma after every batch.
    min_lr = 1e-7
    max_lr = 10.
    gamma = (max_lr / min_lr)**(1 / len(train_eval_data_loader))

    lrs = []
    losses = []
    lim = None

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)

    optimizer = build_optimizer(config.opt, model.parameters())
    for param_group in optimizer.param_groups:
        param_group['lr'] = min_lr
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    update_transforms(1.)
    model.train()
    optimizer.zero_grad()
    for i, (images, _, labels,
            _) in enumerate(tqdm(train_eval_data_loader, desc='lr search'), 1):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        labels = utils.one_hot(labels, NUM_CLASSES)
        images, labels = mixup(images, labels)
        logits = model(images, None, True)

        loss = compute_loss(input=logits, target=labels)
        labels = labels.argmax(1)

        lrs.append(np.squeeze(scheduler.get_lr()))
        losses.append(loss.data.cpu().numpy().mean())

        # Stop the sweep once the loss exceeds 110% of its starting value.
        if lim is None:
            lim = losses[0] * 1.1

        if lim < losses[-1]:
            break

        (loss.mean() / config.opt.acc_steps).backward()

        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

    writer = SummaryWriter(os.path.join(args.experiment_path, 'lr_search'))

    with torch.no_grad():
        losses = np.clip(losses, 0, lim)
        minima_loss = losses[np.argmin(utils.smooth(losses))]
        minima_lr = lrs[np.argmin(utils.smooth(losses))]

        step = 0
        for loss, loss_sm in zip(losses, utils.smooth(losses)):
            writer.add_scalar('search_loss', loss, global_step=step)
            writer.add_scalar('search_loss_sm', loss_sm, global_step=step)
            step += config.batch_size

        fig = plt.figure()
        plt.plot(lrs, losses)
        plt.plot(lrs, utils.smooth(losses))
        plt.axvline(minima_lr)
        plt.xscale('log')
        plt.title('loss: {:.8f}, lr: {:.8f}'.format(minima_loss, minima_lr))
        writer.add_figure('search', fig, global_step=0)

        return minima_lr
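
utils.smooth is used above to pick the minimum of the loss curve. LR range tests commonly smooth the raw losses with a bias-corrected exponential moving average; a sketch of such a helper (beta is an assumed default):

import numpy as np

def smooth(values, beta=0.9):
    # Bias-corrected exponential moving average over a sequence of losses.
    avg = 0.
    out = []
    for i, v in enumerate(values, 1):
        avg = beta * avg + (1 - beta) * v
        out.append(avg / (1 - beta**i))
    return np.array(out)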
Example #7
def one_hot(input):
    # (batch, height, width) label map -> (batch, classes, height, width).
    return utils.one_hot(input, num_classes=NUM_CLASSES).permute((0, 3, 1, 2))
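
The permute implies that the input is a dense (batch, height, width) label map and that utils.one_hot appends the class dimension last. torch.nn.functional.one_hot behaves the same way, so the shapes can be checked standalone (NUM_CLASSES is set to an arbitrary value for the demo):

import torch
import torch.nn.functional as F

NUM_CLASSES = 3

labels = torch.randint(0, NUM_CLASSES, (2, 4, 4))           # (batch, h, w)
channels_last = F.one_hot(labels, num_classes=NUM_CLASSES)  # (batch, h, w, classes)
channels_first = channels_last.permute((0, 3, 1, 2))        # (batch, classes, h, w)
print(channels_first.shape)  # torch.Size([2, 3, 4, 4])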