def train_one_epoch(
            self, epoch, net, loader, optimizer, loss_fn,
            lr_scheduler=None, output_dir=None, amp_autocast=suppress,
            loss_scaler=None, model_ema=None, mixup_fn=None, time_limit=math.inf):
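        # run one training epoch; returns the train metric, mean loss and whether time_limit was exceeded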
        start_tic = time.time()
        if self._augmentation_cfg.mixup_off_epoch and epoch >= self._augmentation_cfg.mixup_off_epoch:
            if self._misc_cfg.prefetcher and loader.mixup_enabled:
                loader.mixup_enabled = False
            elif mixup_fn is not None:
                mixup_fn.mixup_enabled = False

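        # second-order optimizers (e.g. AdaHessian) expose is_second_order and need create_graph=True on backward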
        second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        losses_m = AverageMeter()
        train_metric_score_m = AverageMeter()

        net.train()

        num_updates = epoch * len(loader)
        self._time_elapsed += time.time() - start_tic
        tic = time.time()
        last_tic = time.time()
        train_metric_name = 'accuracy'
        batch_idx = 0
        for batch_idx, (input, target) in enumerate(loader):
            b_tic = time.time()
            if self._time_elapsed > time_limit:
                return {train_metric_name: train_metric_score_m.avg, 'train_loss': losses_m.avg, 'time_limit': True}
            if self._problem_type == REGRESSION:
                target = target.to(torch.float32)
            if not self._misc_cfg.prefetcher:
                # prefetcher would move data to cuda by default
                input, target = input.to(self.ctx[0]), target.to(self.ctx[0])
                if mixup_fn is not None:
                    input, target = mixup_fn(input, target)

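            # forward pass and loss under (optional) AMP autocast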
            with amp_autocast():
                output = net(input)
                if self._problem_type == REGRESSION:
                    output = output.flatten()
                loss = loss_fn(output, target)
            if self._problem_type == REGRESSION:
                train_metric_name = 'rmse'
                train_metric_score = rmse(output, target)
            else:
                if output.shape == target.shape:
                    train_metric_name = 'rmse'
                    train_metric_score = rmse(output, target)
                else:
                    train_metric_score = accuracy(output, target)[0] / 100

            losses_m.update(loss.item(), input.size(0))
            train_metric_score_m.update(train_metric_score.item(), output.size(0))

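            # with an AMP loss scaler, backward, gradient clipping and the optimizer step are delegated to it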
            optimizer.zero_grad()
            if loss_scaler is not None:
                loss_scaler(
                    loss, optimizer,
                    clip_grad=self._optimizer_cfg.clip_grad, clip_mode=self._optimizer_cfg.clip_mode,
                    parameters=model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode),
                    create_graph=second_order)
            else:
                loss.backward(create_graph=second_order)
                if self._optimizer_cfg.clip_grad is not None:
                    dispatch_clip_grad(
                        model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode),
                        value=self._optimizer_cfg.clip_grad, mode=self._optimizer_cfg.clip_mode)
                optimizer.step()

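            # keep an exponential moving average copy of the model weights when EMA is enabled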
            if model_ema is not None:
                model_ema.update(net)

            if self.found_gpu:
                torch.cuda.synchronize()

            num_updates += 1
            if (batch_idx+1) % self._misc_cfg.log_interval == 0:
                lrl = [param_group['lr'] for param_group in optimizer.param_groups]
                lr = sum(lrl) / len(lrl)
                self._logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f',
                                  epoch, batch_idx,
                                  self._train_cfg.batch_size*self._misc_cfg.log_interval/(time.time()-last_tic),
                                  train_metric_name, train_metric_score_m.avg, lr)
                last_tic = time.time()

                if self._misc_cfg.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

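            # per-iteration LR update (timm-style scheduler interface); epoch-level stepping happens outside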
            if lr_scheduler is not None:
                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

            self._time_elapsed += time.time() - b_tic

        throughput = int(self._train_cfg.batch_size * batch_idx / (time.time() - tic))
        self._logger.info('[Epoch %d] training: %s=%f', epoch, train_metric_name, train_metric_score_m.avg)
        self._logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f', epoch, throughput, time.time()-tic)

        end_time = time.time()
        if hasattr(optimizer, 'sync_lookahead'):
            optimizer.sync_lookahead()

        self._time_elapsed += time.time() - end_time

        return {train_metric_name: train_metric_score_m.avg, 'train_loss': losses_m.avg, 'time_limit': False}
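
The rmse helper used above is not a timm utility and is assumed to be defined elsewhere in the codebase. A minimal sketch consistent with how it is called here (same-shape output and target tensors, result supports .item()):

import torch

def rmse(output, target):
    # root-mean-square error; detach so the metric does not keep the autograd graph alive
    return torch.sqrt(torch.mean((output.detach() - target.detach()) ** 2))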
Example #2
def train_one_epoch(epoch,
                    model,
                    loader,
                    optimizer,
                    loss_fn,
                    args,
                    lr_scheduler=None,
                    saver=None,
                    output_dir=None,
                    amp_autocast=suppress,
                    loss_scaler=None,
                    model_ema=None,
                    mixup_fn=None,
                    model_KD=None):
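    # disable mixup once mixup_off_epoch is reached (either in the prefetching loader or in mixup_fn)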
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer,
                           'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

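        # forward pass and loss under (optional) AMP autocast; a KL distillation term is added when a teacher is supplied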
        with amp_autocast():
            output = model(input)
            loss = loss_fn(output, target)

            if model_KD is not None:
                # student probability
                prob_s = F.log_softmax(output, dim=-1)

                # teacher probability
                with torch.no_grad():
                    input_kd = model_KD.normalize_input(input, model)
                    out_t = model_KD.model(input_kd.detach())
                    prob_t = F.softmax(out_t, dim=-1)

                # adding KL loss
                loss += args.alpha_kd * F.kl_div(
                    prob_s, prob_t, reduction='batchmean')

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

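        # with an AMP loss scaler, backward, clipping and the optimizer step are delegated to it; otherwise done manually below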
        optimizer.zero_grad()
        if loss_scaler is not None:
            loss_scaler(loss,
                        optimizer,
                        clip_grad=args.clip_grad,
                        clip_mode=args.clip_mode,
                        parameters=model_parameters(model,
                                                    exclude_head='agc'
                                                    in args.clip_mode),
                        create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.clip_grad is not None:
                dispatch_clip_grad(model_parameters(model,
                                                    exclude_head='agc'
                                                    in args.clip_mode),
                                   value=args.clip_grad,
                                   mode=args.clip_mode)
            optimizer.step()

        if model_ema is not None:
            model_ema.update(model)

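        # synchronize so the per-batch timing below reflects completed GPU work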
        torch.cuda.synchronize()
        num_updates += 1
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                _logger.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx,
                        len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size /
                        batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size /
                        batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir,
                                     'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates,
                                     metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])
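
The distillation term in this example can also be read as a standalone helper. A sketch under the same assumptions as the snippet (the name kd_loss and the default alpha_kd are illustrative, not a library API):

import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, alpha_kd=1.0):
    # KL divergence between student log-probabilities and (detached) teacher probabilities
    prob_s = F.log_softmax(student_logits, dim=-1)
    with torch.no_grad():
        prob_t = F.softmax(teacher_logits, dim=-1)
    return alpha_kd * F.kl_div(prob_s, prob_t, reduction='batchmean')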
Example #3
def train_one_epoch(epoch,
                    model,
                    loader,
                    optimizer,
                    loss_fn,
                    device,
                    args,
                    lr_scheduler=None,
                    saver=None,
                    output_dir=None):

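    # simplified variant: no AMP, prefetcher, mixup, EMA or loss scaler; tensors are moved to device explicitly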
    second_order = hasattr(optimizer,
                           'is_second_order') and optimizer.is_second_order
    #batch_time_m = AverageMeter()
    #data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    #end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in tqdm(enumerate(loader),
                                           total=len(loader)):
        last_batch = batch_idx == last_idx
        #data_time_m.update(time.time() - end)
        input = input.to(device)
        target = target.to(device)
        output = model(input)
        loss = loss_fn(output, target)

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()

        loss.backward(create_graph=second_order)
        if args.clip_grad is not None:
            dispatch_clip_grad(model_parameters(model,
                                                exclude_head='agc'
                                                in args.clip_mode),
                               value=args.clip_grad,
                               mode=args.clip_mode)
        optimizer.step()

        #torch.cuda.synchronize()
        num_updates += 1
        #batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.local_rank == 0:
                _logger.info('Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                             'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                             'LR: {lr:.3e}  '.format(epoch,
                                                     batch_idx,
                                                     len(loader),
                                                     100. * batch_idx /
                                                     last_idx,
                                                     loss=losses_m,
                                                     lr=lr))

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates,
                                     metric=losses_m.avg)

        # end for

    return OrderedDict([('loss', losses_m.avg)])
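
All three examples accumulate running statistics with AverageMeter (from timm.utils). A minimal sketch of the behaviour these loops rely on:

class AverageMeter:
    """Tracks the latest value plus a running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count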