Example #1
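These loops assume the standard imports below; the AverageMeter, ProgressMeter, and accuracy helpers they share are sketched after Example #2, and step_fn (used only in this example) is sketched right after this one.

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
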
def train_gandataset(
    train_loader, model, gan, criterion, optimizer, epoch, args, display=True
):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1],
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        batch_size = images.size(0)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        output, loss = step_fn(images, target, model, criterion, optimizer)

        # measure accuracy and record loss
        acc1 = accuracy(output, target, topk=(1,))[0]
        losses.update(loss.item(), images.size(0))
        top1.update(acc1.item(), images.size(0))

        # Second pass on a GAN-generated batch; generated samples are
        # labeled as class 1 (real samples appear to use class 0).
        gan_images = gan(*gan.generate_input(batch_size))
        fake_target = torch.ones(batch_size, dtype=torch.long).cuda(args.gpu, non_blocking=True)

        output, loss = step_fn(gan_images, fake_target, model, criterion, optimizer)

        # measure accuracy on the generated batch (against fake_target) and record loss
        acc1 = accuracy(output, fake_target, topk=(1,))[0]
        losses.update(loss.item(), batch_size)
        top1.update(acc1.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and display:
            progress.display(i)

    return top1.avg, losses.avg
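
step_fn is called above but never defined in this snippet. A minimal sketch of what it presumably does, one supervised forward/backward/update step returning the logits and the loss, is given below; only the signature and return convention are taken from the call sites, the body is an assumption.

def step_fn(images, target, model, criterion, optimizer):
    """One optimization step: forward pass, loss, backward pass, update.

    Sketch inferred from the call sites in train_gandataset; returns the
    logits so the caller can compute accuracy, and the loss tensor so the
    caller can log it.
    """
    output = model(images)
    loss = criterion(output, target)

    # standard gradient step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return output, loss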
Example #2
def validate(val_loader, model, criterion, args, display=True):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1],
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1 = accuracy(output, target, topk=(1, ))[0]
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0 and display:
                progress.display(i)

        if display:
            # TODO: this should also be done with the ProgressMeter
            print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))

    return top1.avg, losses.avg
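
The AverageMeter, ProgressMeter, and accuracy helpers used throughout these examples are never defined in them. The loops follow the structure of the official PyTorch ImageNet example (pytorch/examples/imagenet/main.py), so the standard definitions from that script are reproduced below to keep the snippets self-contained.

class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the given k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # indices of the top-k predictions, shape (maxk, batch)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res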
Example #3
def validate(val_loader, model, criterions, loss_weights, args, display=True):
    batch_time = AverageMeter('Time', ':6.3f')
    loss_meters = {
        name: AverageMeter(f'{name} Loss', ':.4e')
        for name in criterions
    }
    loss_meters['full'] = AverageMeter('Full loss', ':.4e')

    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time] + list(loss_meters.values()) + [top1],
        prefix='Test: ',
    )

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (images, target, heatvols, volmask) in enumerate(val_loader):

            # Real videos get zeroed attn mask (eventually, not currently)
            real_inds = (target == 0).nonzero()
            valid_heatvol_inds = volmask.nonzero()
            # all_inds = torch.unique(torch.cat((real_inds, valid_heatvol_inds)))
            # squeeze(1), not squeeze(), so a single valid index still yields
            # a 1-D index tensor for index_select below
            all_inds = valid_heatvol_inds.squeeze(1)

            # Not all videos have valid heatvols, so extract only those that do
            valid_heatvols = heatvols.index_select(0, all_inds)

            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            all_inds = all_inds.cuda(args.gpu, non_blocking=True)
            valid_heatvols = valid_heatvols.cuda(args.gpu, non_blocking=True)

            # compute output and retrieve self attn map
            output, attn = model(images)
            # Compute loss for valid attn maps only
            valid_attn = attn.index_select(0, all_inds)

            if valid_attn.size(0) != 0:
                # Rescale heatvol to match the size of the attn vol.
                valid_heatvols = nn.functional.interpolate(
                    valid_heatvols, size=valid_attn.shape[-2:], mode='area')

            # Compute the cross-entropy loss between predictions and targets,
            # and the KL-divergence and correlation-coefficient losses between
            # the human heat volumes and the self-attention maps.
            bs = valid_attn.size(0)
            losses = {
                'ce': loss_weights['ce'] * criterions['ce'](output, target),
                'cc': loss_weights['cc'] * criterions['cc'](valid_heatvols, valid_attn),
            }
            if valid_attn.nelement() != 0:
                # Only defined when the batch contains valid heat volumes.
                losses['kl'] = loss_weights['kl'] * criterions['kl'](
                    F.log_softmax(valid_attn.view(bs, -1), dim=-1),
                    F.softmax(valid_heatvols.view(bs, -1), dim=-1),
                )
            losses['full'] = sum(losses.values())

            # measure accuracy and record loss
            acc1 = accuracy(output, target, topk=(1, ))[0]
            for name, loss in losses.items():
                loss_meters[name].update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0 and display:
                progress.display(i)

        if display:
            # TODO: this should also be done with the ProgressMeter
            print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))

    return top1.avg, loss_meters['full'].avg
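
The criterions and loss_weights dicts are supplied by the caller. A plausible construction, consistent with the keys above and with the F.log_softmax(input)/F.softmax(target) call pattern that nn.KLDivLoss expects, is sketched below; cc_loss is hypothetical, standing in for the codebase's real correlation-coefficient criterion, and the weights are placeholders.

def cc_loss(heatvols, attn):
    # Hypothetical correlation-coefficient loss: negative Pearson
    # correlation per sample, averaged over the batch.
    h = heatvols.flatten(1) - heatvols.flatten(1).mean(dim=1, keepdim=True)
    a = attn.flatten(1) - attn.flatten(1).mean(dim=1, keepdim=True)
    corr = (h * a).sum(dim=1) / (h.norm(dim=1) * a.norm(dim=1) + 1e-8)
    return -corr.mean()


criterions = {
    'ce': nn.CrossEntropyLoss(),
    'kl': nn.KLDivLoss(reduction='batchmean'),
    'cc': cc_loss,
}
loss_weights = {'ce': 1.0, 'kl': 0.1, 'cc': 0.1}  # placeholder weights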
Example #4
def train(
    train_loader,
    model,
    criterions,
    loss_weights,
    optimizer,
    logger,
    epoch,
    args,
    display=True,
):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    loss_meters = {
        name: AverageMeter(f'{name} Loss', ':.4e')
        for name in criterions
    }
    loss_meters['full'] = AverageMeter('Full loss', ':.4e')

    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time] + list(loss_meters.values()) + [top1],
        prefix="Epoch: [{}]".format(epoch),
    )

    itr = epoch * len(train_loader)
    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target, heatvols, volmask) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        itr += 1

        # Real videos get zeroed attn mask (eventually, not currently)
        real_inds = (target == 0).nonzero()
        valid_heatvol_inds = volmask.nonzero()
        # all_inds = torch.unique(torch.cat((real_inds, valid_heatvol_inds)))
        # squeeze(1), not squeeze(), so a single valid index still yields
        # a 1-D index tensor for index_select below
        all_inds = valid_heatvol_inds.squeeze(1)

        # Not all videos have valid heatvols, so extract only those that do
        valid_heatvols = heatvols.index_select(0, all_inds)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)
        all_inds = all_inds.cuda(args.gpu, non_blocking=True)
        valid_heatvols = valid_heatvols.cuda(args.gpu, non_blocking=True)

        # compute output and retrieve self attn map
        output, attn = model(images)
        # Compute loss for valid attn maps only
        valid_attn = attn.index_select(0, all_inds)

        if valid_attn.size(0) != 0:
            # Rescale heatvol to match the size of the attn vol.
            valid_heatvols = nn.functional.interpolate(
                valid_heatvols, size=valid_attn.shape[-2:], mode='area')
        # Compute the cross-entropy loss between predictions and targets,
        # and the KL-divergence and correlation-coefficient losses between
        # the human heat volumes and the self-attention maps.
        bs = valid_attn.size(0)
        losses = {
            'ce': loss_weights['ce'] * criterions['ce'](output, target),
            'cc': loss_weights['cc'] * criterions['cc'](valid_heatvols, valid_attn),
        }
        if valid_attn.nelement() != 0:
            # Only defined when the batch contains valid heat volumes.
            losses['kl'] = loss_weights['kl'] * criterions['kl'](
                F.log_softmax(valid_attn.view(bs, -1), dim=-1),
                F.softmax(valid_heatvols.view(bs, -1), dim=-1),
            )
        losses['full'] = sum(losses.values())

        # measure accuracy and record loss
        acc1 = accuracy(output, target, topk=(1, ))[0]
        for name, loss in losses.items():
            if not torch.isnan(loss):
                loss_meters[name].update(loss.item(), images.size(0))
        top1.update(acc1.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        losses['full'].backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and display:
            progress.display(i)
            # Log the full weighted loss; the bare `loss` name here was a
            # stale loop variable from the meter-update loop above.
            logger.log_metrics(
                {
                    'Accuracy/train': acc1.item(),
                    'Loss/train': losses['full'].item(),
                },
                step=itr,
            )

            # logger.save()

    return top1.avg, loss_meters['full'].avg
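
As a sanity check on the interpolation and KL steps above: nn.functional.interpolate with a 2-element size and mode='area' expects an (N, C, H, W) input, and nn.KLDivLoss takes log-probabilities as input and probabilities as target. A standalone walkthrough under hypothetical tensor sizes:

# Hypothetical sizes: 4 samples with valid heat volumes, 28x28 human heat
# maps, 7x7 self-attention maps coming out of the model.
heatvols = torch.rand(4, 1, 28, 28)
attn = torch.rand(4, 1, 7, 7)

# Rescale the heat volumes down to the attention resolution.
heatvols = nn.functional.interpolate(heatvols, size=attn.shape[-2:], mode='area')
assert heatvols.shape == (4, 1, 7, 7)

# Flatten per sample and compare the two as distributions, as in the loops.
bs = attn.size(0)
kl = nn.KLDivLoss(reduction='batchmean')(
    F.log_softmax(attn.view(bs, -1), dim=-1),
    F.softmax(heatvols.view(bs, -1), dim=-1),
)
print(kl.item())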
Example #5
def train(train_loader,
          model,
          criterion,
          optimizer,
          logger,
          epoch,
          args,
          display=True):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1],
        prefix="Epoch: [{}]".format(epoch),
    )

    itr = epoch * len(train_loader)
    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        itr += 1

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1 = accuracy(output, target, topk=(1, ))[0]
        losses.update(loss.item(), images.size(0))
        top1.update(acc1.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and display:
            progress.display(i)
            logger.log_metrics(
                {
                    'Accuracy/train': acc1.item(),
                    'Loss/train': loss.item(),
                },
                step=itr,
            )

            # logger.save()

    return top1.avg, losses.avg
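
Tying the pieces together, an epoch-level driver for this last train/validate pair might look like the sketch below; the args fields, the logger object, and the best-accuracy tracking are assumptions, not part of the original snippets.

best_acc1 = 0.0
for epoch in range(args.epochs):  # args.epochs is an assumed field
    train_acc1, train_loss = train(train_loader, model, criterion,
                                   optimizer, logger, epoch, args)
    val_acc1, val_loss = validate(val_loader, model, criterion, args)

    # Log validation metrics at the end-of-epoch iteration count.
    logger.log_metrics({'Accuracy/val': val_acc1, 'Loss/val': val_loss},
                       step=(epoch + 1) * len(train_loader))
    best_acc1 = max(best_acc1, val_acc1)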