Example #1
def train_epoch(model, optimizer, scheduler, data_loader, fold, epoch):
    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'train'))

    metrics = {
        'loss': utils.Mean(),
    }

    update_transforms(np.linspace(0, 1, config.epochs)[epoch - 1].item())
    model.train()
    optimizer.zero_grad()
    for i, (images, feats, _, labels, _) in enumerate(
            tqdm(data_loader, desc='epoch {} train'.format(epoch)), 1):
        images, feats, labels = images.to(DEVICE), feats.to(DEVICE), labels.to(
            DEVICE)
        labels = utils.one_hot(labels, NUM_CLASSES)
        images, labels = cutmix(images, labels)
        logits = model(images, object(), object())

        loss = compute_loss(input=logits, target=labels)
        metrics['loss'].update(loss.data.cpu().numpy())
        labels = labels.argmax(1)

        lr = np.squeeze(scheduler.get_lr())
        (loss.mean() / config.opt.acc_steps).backward()

        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        images = images_to_rgb(images)[:16]
        print('[FOLD {}][EPOCH {}][TRAIN] {}'.format(
            fold, epoch,
            ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_scalar('learning_rate', lr, global_step=epoch)
        writer.add_image('images',
                         torchvision.utils.make_grid(
                             images,
                             nrow=math.ceil(math.sqrt(images.size(0))),
                             normalize=True),
                         global_step=epoch)
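The cutmix helper called above is not part of the snippet; a minimal sketch of CutMix for one-hot targets, assuming a Beta(alpha, alpha) mixing coefficient (the alpha default is hypothetical, not taken from the repository), could look like this:

import numpy as np
import torch


def cutmix(images, targets, alpha=1.):
    # images: (N, C, H, W); targets: one-hot (N, num_classes)
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(images.size(0), device=images.device)

    # cut a random box whose area is roughly (1 - lam) of the image
    h, w = images.shape[2:]
    cut_h, cut_w = int(h * np.sqrt(1. - lam)), int(w * np.sqrt(1. - lam))
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
    x1, x2 = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)

    # paste the box from a permuted batch and mix targets by the pasted area
    images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    lam = 1. - (y2 - y1) * (x2 - x1) / (h * w)
    targets = lam * targets + (1. - lam) * targets[perm]
    return images, targets
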
Example #2
def lr_search(train_eval_data):
    train_eval_dataset = TrainEvalDataset(train_eval_data,
                                          transform=train_transform)
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_eval_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    min_lr = 1e-7
    max_lr = 10.
    gamma = (max_lr / min_lr)**(1 / len(train_eval_data_loader))

    lrs = []
    losses = []
    lim = None

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)

    optimizer = build_optimizer(config.opt, model.parameters())
    for param_group in optimizer.param_groups:
        param_group['lr'] = min_lr
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    update_transforms(1.)
    model.train()
    optimizer.zero_grad()
    for i, (images, feats, _, labels,
            _) in enumerate(tqdm(train_eval_data_loader, desc='lr search'), 1):
        images, feats, labels = images.to(DEVICE), feats.to(DEVICE), labels.to(
            DEVICE)
        labels = utils.one_hot(labels, NUM_CLASSES)
        images, labels = cutmix(images, labels)
        logits = model(images, object(), object())

        loss = compute_loss(input=logits, target=labels)
        labels = labels.argmax(1)

        lrs.append(np.squeeze(scheduler.get_lr()))
        losses.append(loss.data.cpu().numpy().mean())

        if lim is None:
            lim = losses[0] * 1.1

        if lim < losses[-1]:
            break

        (loss.mean() / config.opt.acc_steps).backward()

        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

    writer = SummaryWriter(os.path.join(args.experiment_path, 'lr_search'))

    with torch.no_grad():
        losses = np.clip(losses, 0, lim)
        minima_loss = losses[np.argmin(utils.smooth(losses))]
        minima_lr = lrs[np.argmin(utils.smooth(losses))]

        step = 0
        for loss, loss_sm in zip(losses, utils.smooth(losses)):
            writer.add_scalar('search_loss', loss, global_step=step)
            writer.add_scalar('search_loss_sm', loss_sm, global_step=step)
            step += config.batch_size

        fig = plt.figure()
        plt.plot(lrs, losses)
        plt.plot(lrs, utils.smooth(losses))
        plt.axvline(minima_lr)
        plt.xscale('log')
        plt.title('loss: {:.8f}, lr: {:.8f}'.format(minima_loss, minima_lr))
        writer.add_figure('search', fig, global_step=0)

        return minima_lr
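utils.smooth, used above to pick the minimum of the LR-range curve, is also external; a minimal sketch, assuming a bias-corrected exponential moving average (the 0.9 smoothing factor is an assumption), might be:

import numpy as np


def smooth(values, beta=0.9):
    # exponential moving average with bias correction, as in the usual
    # LR-range-test recipe
    smoothed, avg = [], 0.
    for i, v in enumerate(values, 1):
        avg = beta * avg + (1. - beta) * v
        smoothed.append(avg / (1. - beta**i))
    return np.array(smoothed)
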
Example #3
    init_time = time()

    epoch_train_loss, epoch_train_metric, epoch_val_loss, epoch_val_metric = [], [], [], []
    i1 = trainloader
    if args['verbose'] == 1:
        i1 = tqdm(i1)

    # Training
    model.train()
    for batch in i1:
        x, labels, masks = batch['img'].cuda(), batch['label'].cuda(
        ), batch['mask'].cuda()

        # Cutmix
        pivot = int(x.size(0) * args['cutmix_prob'])
        x[:pivot], targets = cutmix(x[:pivot], labels[:pivot])
        if args['gridmask_ratio'] < 1.:
            x[pivot:] *= masks[pivot:]

        preds = model(x)
        loss_cutmix = mix_criterion(preds[:pivot],
                                    targets,
                                    criterion,
                                    reduction='none')
        loss = criterion(preds[pivot:], labels[pivot:], reduction='none')

        with amp.scale_loss(
                torch.cat([loss_cutmix, loss]).mean() / args['accum_steps'],
                op) as scaled_loss:
            scaled_loss.backward()
        if (global_steps + 1) % args['accum_steps'] == 0:
            # the snippet is cut off here in the source; a typical
            # gradient-accumulation step (an assumption) would be:
            op.step()
            op.zero_grad()
Example #4
    def train(self, epoch: int):
        ''' method to train your model for epoch '''
        losses = AverageMeter()
        accuracy = AverageMeter()
        # switch to train mode and train one epoch
        self.model.train()
        loop = tqdm(enumerate(self.train_loader), total=len(self.train_loader), leave=False)
        for i, (input_, target) in loop:
            if i == self.config.test_steps:
                break
            input_ = input_.to(self.device)
            target = target.to(self.device)
            # compute output and loss
            if self.config.aug.type_aug:
                if self.config.aug.type_aug == 'mixup':
                    aug_output = mixup_target(input_, target, self.config, self.device)
                else:
                    assert self.config.aug.type_aug == 'cutmix'
                    aug_output = cutmix(input_, target, self.config, self.device)
                input_, target_a, target_b, lam = aug_output
                tuple_target = (target_a, target_b, lam)
                if self.config.multi_task_learning:
                    hot_target = lam*F.one_hot(target_a[:,0], 2) + (1-lam)*F.one_hot(target_b[:,0], 2)
                else:
                    hot_target = lam*F.one_hot(target_a, 2) + (1-lam)*F.one_hot(target_b, 2)
                output = self.make_output(input_, hot_target)
                if self.config.multi_task_learning:
                    loss = self.multi_task_criterion(output, tuple_target)
                else:
                    loss = self.mixup_criterion(self.criterion, output,
                                                target_a, target_b, lam, 2)
            else:
                new_target = (F.one_hot(target[:,0], num_classes=2)
                            if self.config.multi_task_learning
                            else F.one_hot(target, num_classes=2))
                output = self.make_output(input_, new_target)
                loss = (self.multi_task_criterion(output, target)
                        if self.config.multi_task_learning
                        else self.criterion(output, new_target))

            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure accuracy
            s = self.config.loss.amsoftmax.s
            acc = (precision(output[0], target[:,0].reshape(-1), s)
                  if self.config.multi_task_learning
                  else precision(output, target, s))
            # record loss
            losses.update(loss.item(), input_.size(0))
            accuracy.update(acc, input_.size(0))

            # write to writer for tensorboard
            self.writer.add_scalar('Train/loss', loss, global_step=self.train_step)
            self.writer.add_scalar('Train/accuracy', accuracy.avg, global_step=self.train_step)
            self.train_step += 1

            # update progress bar
            max_epochs = self.config.epochs.max_epoch
            loop.set_description(f'Epoch [{epoch}/{max_epochs}]')
            loop.set_postfix(loss=loss.item(), avr_loss = losses.avg,
                             acc=acc, avr_acc=accuracy.avg,
                             lr=self.optimizer.param_groups[0]['lr'])
        return losses.avg, accuracy.avg
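mixup_criterion above is defined elsewhere in the class; a minimal standalone sketch, assuming the criterion accepts one-hot targets (as in the non-augmented branch), takes the lam-weighted combination of the loss against both original targets:

import torch.nn.functional as F


def mixup_criterion(criterion, output, target_a, target_b, lam, num_classes):
    # mix the loss rather than the labels: weight the criterion evaluated
    # against each of the two original targets by lam and (1 - lam)
    target_a = F.one_hot(target_a, num_classes).float()
    target_b = F.one_hot(target_b, num_classes).float()
    return (lam * criterion(output, target_a)
            + (1. - lam) * criterion(output, target_b))
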
Example #5
def lr_search(train_data, train_transform, config):
    train_dataset = LabeledDataset(train_data, transform=train_transform)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=config.workers,
        worker_init_fn=worker_init_fn)

    min_lr = 1e-5
    max_lr = 1.
    gamma = (max_lr / min_lr)**(1 / len(train_data_loader))

    lrs = []
    losses = []
    lim = None

    model = Model(config.model,
                  num_classes=CLASS_META['num_classes'].sum()).to(DEVICE)
    optimizer = build_optimizer(model.parameters(), config.train.optimizer)
    for param_group in optimizer.param_groups:
        param_group['lr'] = min_lr
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    # update_transforms(1.)
    model.train()
    for images, targets, _ in tqdm(train_data_loader, desc='lr_search'):
        images, targets = images.to(DEVICE), targets.to(DEVICE)

        if config.train.cutmix is not None:
            images, targets = utils.cutmix(images, targets,
                                           config.train.cutmix)

        logits, etc = model(images)

        loss = compute_loss(input=logits, target=targets, config=config.train)

        lrs.append(np.squeeze(scheduler.get_lr()))
        losses.append(loss.data.cpu().numpy().mean())

        if lim is None:
            lim = losses[0] * 1.1
        if lim < losses[-1]:
            break

        loss.mean().backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    writer = SummaryWriter(os.path.join(config.experiment_path, 'lr_search'))

    with torch.no_grad():
        losses = np.clip(losses, 0, lim)
        minima_loss = losses[np.argmin(utils.smooth(losses))]
        minima_lr = lrs[np.argmin(utils.smooth(losses))]

        step = 0
        for loss, loss_sm in zip(losses, utils.smooth(losses)):
            writer.add_scalar('search_loss', loss, global_step=step)
            writer.add_scalar('search_loss_sm', loss_sm, global_step=step)
            step += config.train.batch_size

        fig = plt.figure()
        plt.plot(lrs, losses)
        plt.plot(lrs, utils.smooth(losses))
        plt.axvline(minima_lr)
        plt.xscale('log')
        plt.title('loss: {:.8f}, lr: {:.8f}'.format(minima_loss, minima_lr))
        writer.add_figure('lr_search', fig, global_step=0)
        print(minima_lr)

    writer.flush()
    writer.close()

    return minima_lr
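worker_init_fn passed to the DataLoader is not shown either; a minimal sketch, assuming it only re-seeds NumPy per worker so that random augmentations differ across workers, could be:

import numpy as np
import torch


def worker_init_fn(worker_id):
    # derive a per-worker NumPy seed from the torch base seed
    np.random.seed((torch.initial_seed() + worker_id) % 2**32)
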
Example #6
def train_epoch(model, data_loader, fold_probs, optimizer, scheduler, epoch,
                config):
    writer = SummaryWriter(
        os.path.join(config.experiment_path, 'F{}'.format(config.fold),
                     'train'))
    metrics = {
        'loss': Mean(),
        'loss_hist': Concat(),
        'entropy': Mean(),
        'lr': Last(),
    }

    model.train()
    for images, targets, indices in tqdm(data_loader,
                                         desc='[F{}][epoch {}] train'.format(
                                             config.fold, epoch)):
        images, targets, indices = images.to(DEVICE), targets.to(
            DEVICE), indices.to(DEVICE)

        if epoch >= config.train.self_distillation.start_epoch:
            targets = weighted_sum(
                targets, fold_probs[indices],
                config.train.self_distillation.target_weight)
        if config.train.cutmix is not None:
            if np.random.uniform() > (epoch - 1) / (config.epochs - 1):
                images, targets = utils.cutmix(images, targets,
                                               config.train.cutmix)

        logits, etc = model(images)

        loss = compute_loss(input=logits, target=targets, config=config.train)

        metrics['loss'].update(loss.data.cpu().numpy())
        metrics['loss_hist'].update(loss.data.cpu().numpy())
        metrics['entropy'].update(compute_entropy(logits).data.cpu().numpy())
        metrics['lr'].update(np.squeeze(scheduler.get_lr()))

        loss.mean().backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        # FIXME:
        if epoch >= config.train.self_distillation.start_epoch:
            probs = torch.cat(
                [i.softmax(-1) for i in split_target(logits.detach())], -1)
            fold_probs[indices] = weighted_sum(
                fold_probs[indices], probs,
                config.train.self_distillation.pred_ewa)

    for k in metrics:
        if k.endswith('_hist'):
            writer.add_histogram(k,
                                 metrics[k].compute_and_reset(),
                                 global_step=epoch)
        else:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
    writer.add_image('images',
                     torchvision.utils.make_grid(images,
                                                 nrow=compute_nrow(images),
                                                 normalize=True),
                     global_step=epoch)
    if 'stn' in etc:
        writer.add_image('stn',
                         torchvision.utils.make_grid(etc['stn'],
                                                     nrow=compute_nrow(
                                                         etc['stn']),
                                                     normalize=True),
                         global_step=epoch)

    writer.flush()
    writer.close()
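The weighted_sum helper used for self-distillation is not shown; a minimal sketch consistent with the call sites above (reading the weight as applying to the first argument is an assumption) would be:

def weighted_sum(a, b, weight):
    # convex combination of targets and running fold predictions
    return weight * a + (1. - weight) * b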