Example #1
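# Assumes numpy (np) and a timm-style mixup_target helper are in scope.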
def loss_mixup_without_prefetcher(model, loss_fn, input, target, epoch, args):
    if args.mixup > 0.:
        lam = 1.
        if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
            lam = np.random.beta(args.mixup, args.mixup)
        # mix each sample with the batch reversed along dim 0; the out-of-place
        # mul keeps the flipped copy unscaled and avoids the deprecated
        # add_(Number, Tensor) signature
        input = input.mul(lam).add_(input.flip(0), alpha=1. - lam)
        target = mixup_target(target, args.num_classes, lam, args.smoothing)
    # compute the loss outside the mixup branch so the function also returns
    # a loss when mixup is disabled
    output = model(input)
    loss = loss_fn(output, target)
    return loss
Example #2
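# Assumes standard imports (numpy as np, torch, torchvision, time, logging, os,
# collections.OrderedDict) and timm-style helpers: AverageMeter, mixup_target,
# reduce_tensor, and NVIDIA apex amp when use_amp is set.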
def train_epoch(epoch,
                model,
                loader,
                optimizer,
                loss_fn,
                args,
                lr_scheduler=None,
                saver=None,
                output_dir='',
                use_amp=False,
                model_ema=None):

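    # when the prefetcher handles mixup inside the loader, disable it once
    # mixup_off_epoch is reached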
    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
            loader.mixup_enabled = False

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input = input.cuda()
            target = target.cuda()
            if args.mixup > 0.:
                lam = 1.
                if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
                    lam = np.random.beta(args.mixup, args.mixup)
                # out-of-place mul keeps the flipped copy unscaled by lam
                input = input.mul(lam).add_(input.flip(0), alpha=1. - lam)
                target = mixup_target(target, args.num_classes, lam,
                                      args.smoothing)

        output = model(input)

        loss = loss_fn(output, target)
        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

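        # synchronize so batch_time below reflects the finished GPU work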
        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx,
                        len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size /
                        batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size /
                        batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir,
                                     'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(model,
                                optimizer,
                                args,
                                epoch,
                                model_ema=model_ema,
                                batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates,
                                     metric=losses_m.avg)

        end = time.time()

    return OrderedDict([('loss', losses_m.avg)])
Example #3
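# As in Example #2, plus torch.nn.functional as F, a CutMix rand_bbox helper,
# and KD/CutMix args: KD_train, KD_alpha, KD_temperature, teacher_step, beta,
# cutmix_prob.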
def train_epoch(epoch,
                model,
                loader,
                optimizer,
                loss_fn,
                args,
                lr_scheduler=None,
                saver=None,
                output_dir='',
                use_amp=False,
                model_ema=None,
                teacher_model=None):

    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
            loader.mixup_enabled = False

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()
    if args.KD_train:
        teacher_model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input = input.cuda()
            target = target.cuda()
            if args.mixup > 0.:
                lam = 1.
                if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
                    lam = np.random.beta(args.mixup, args.mixup)
                # out-of-place mul keeps the flipped copy unscaled by lam
                input = input.mul(lam).add_(input.flip(0), alpha=1. - lam)
                target = mixup_target(target, args.num_classes, lam,
                                      args.smoothing)

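        # CutMix: with probability cutmix_prob, paste a random patch from a
        # shuffled copy of the batch and weight the two targets by pixel ratio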
        r = np.random.rand(1)
        if args.beta > 0 and r < args.cutmix_prob:
            # generate mixed sample
            lam = np.random.beta(args.beta, args.beta)
            rand_index = torch.randperm(input.size()[0]).cuda()
            target_a = target
            target_b = target[rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
            input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2,
                                                      bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                       (input.size()[-1] * input.size()[-2]))
            # compute output; torch.autograd.Variable is deprecated, tensors
            # carry autograd state directly
            output = model(input)
            loss = loss_fn(output, target_a) * lam + \
                loss_fn(output, target_b) * (1. - lam)
        else:
            # NOTE KD training is mutually exclusive with CutMix, FIX it later
            output = model(input)
            if args.KD_train:
                # teacher_model.cuda()
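                # run the teacher in args.teacher_step chunks to keep peak
                # GPU memory down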
                teacher_outputs_tmp = []
                assert (input.shape[0] % args.teacher_step == 0)
                step_size = int(input.shape[0] // args.teacher_step)
                with torch.no_grad():
                    for k in range(0, int(input.shape[0]), step_size):
                        input_tmp = input[k:k + step_size, :, :, :]
                        teacher_outputs_tmp.append(teacher_model(input_tmp))
                        # torch.cuda.empty_cache()
                teacher_outputs = torch.cat(teacher_outputs_tmp)
                alpha = args.KD_alpha
                T = args.KD_temperature
                # soft-target distillation term (log-probs vs. temperature-
                # softened teacher probs) scaled by alpha * T^2, plus the
                # hard-label cross entropy weighted by (1 - alpha)
                loss = loss_fn(F.log_softmax(output / T, dim=1),
                               F.softmax(teacher_outputs / T, dim=1)) * (alpha * T * T) + \
                    F.cross_entropy(output, target) * (1. - alpha)
            else:
                loss = loss_fn(output, target)
        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx,
                        len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size /
                        batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size /
                        batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir,
                                     'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            save_epoch = epoch + 1 if last_batch else epoch
            saver.save_recovery(model,
                                optimizer,
                                args,
                                save_epoch,
                                model_ema=model_ema,
                                batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates,
                                     metric=losses_m.avg)

        end = time.time()

    return OrderedDict([('loss', losses_m.avg)])