def train_gandataset(
    train_loader, model, gan, criterion, optimizer, epoch, args, display=True
):
    """Train for one epoch on real batches plus GAN-generated fake batches.

    For every real batch, a same-sized fake batch is sampled from ``gan`` and
    trained against an all-ones label (class 1 == fake). Both steps go through
    the externally-defined ``step_fn`` (forward + loss + optimizer step).

    Args:
        train_loader: iterable of ``(images, target)`` batches.
        model: classifier being trained (real vs. fake).
        gan: generator; ``gan.generate_input(n)`` yields its input args.
        criterion: loss passed through to ``step_fn``.
        optimizer: optimizer passed through to ``step_fn``.
        epoch: current epoch index (display only).
        args: namespace providing ``gpu`` and ``print_freq``.
        display: print progress every ``args.print_freq`` batches.

    Returns:
        (average top-1 accuracy, average loss) over real AND fake batches.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1],
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        batch_size = images.size(0)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

        # --- real batch ---
        output, loss = step_fn(images, target, model, criterion, optimizer)

        # measure accuracy and record loss
        acc1 = accuracy(output, target, topk=(1,))[0]
        losses.update(loss.item(), batch_size)
        top1.update(acc1, batch_size)

        # --- fake batch, same size as the real one ---
        gan_images = gan(*gan.generate_input(batch_size))
        # Label 1 marks GAN samples as the "fake" class.
        fake_target = torch.ones(batch_size).cuda(args.gpu,
                                                  non_blocking=True).long()
        output, loss = step_fn(gan_images, fake_target, model, criterion,
                               optimizer)

        # measure accuracy and record loss
        # FIX: the fake-batch accuracy must be measured against fake_target;
        # the original compared against the real loader labels, making the
        # reported accuracy for this step meaningless.
        acc1 = accuracy(output, fake_target, topk=(1,))[0]
        losses.update(loss.item(), batch_size)
        top1.update(acc1, batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and display:
            progress.display(i)

    return top1.avg, losses.avg
def validate(val_loader, model, criterion, args, display=True):
    """Evaluate ``model`` on ``val_loader``.

    Runs a full pass in eval mode with gradients disabled and accumulates
    loss and top-1 accuracy, optionally printing progress every
    ``args.print_freq`` batches.

    Returns:
        (average top-1 accuracy, average loss) over the whole loader.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1],
                             prefix='Test: ')

    # Inference only: freeze batch-norm / dropout behavior.
    model.eval()

    with torch.no_grad():
        tic = time.time()
        for batch_idx, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # Forward pass and loss.
            output = model(images)
            loss = criterion(output, target)

            # Bookkeeping: running loss and top-1 accuracy averages.
            (acc1,) = accuracy(output, target, topk=(1, ))
            n = images.size(0)
            losses.update(loss.item(), n)
            top1.update(acc1.item(), n)

            # Per-batch wall-clock time.
            batch_time.update(time.time() - tic)
            tic = time.time()

            if display and batch_idx % args.print_freq == 0:
                progress.display(batch_idx)

        if display:
            # TODO: this should also be done with the ProgressMeter
            print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))

    return top1.avg, losses.avg
def validate(val_loader, model, criterions, loss_weights, args, display=True):
    """Evaluate an attention-supervised model on ``val_loader``.

    Each batch yields ``(images, target, heatvols, volmask)`` where
    ``volmask`` flags samples that carry a valid human heat volume. The model
    returns ``(output, attn)``; cross-entropy is computed on all samples,
    while the CC and KL losses compare the self-attention maps with the human
    heat volumes of the valid subset only.

    Args:
        val_loader: iterable of ``(images, target, heatvols, volmask)``.
        criterions: dict with keys 'ce', 'cc', 'kl' mapping to loss modules.
        loss_weights: dict with matching keys of scalar weights.
        args: namespace providing ``gpu`` and ``print_freq``.
        display: print progress every ``args.print_freq`` batches.

    Returns:
        (average top-1 accuracy, average combined 'full' loss).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    loss_meters = {
        name: AverageMeter(f'{name} Loss', ':.4e')
        for name in criterions
    }
    # FIX: format spec was ':4e' (missing '.'), inconsistent with every other
    # loss meter's ':.4e'.
    loss_meters['full'] = AverageMeter('Full loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time] + list(loss_meters.values()) + [top1],
        prefix='Test: ',
    )

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target, heatvols, volmask) in enumerate(val_loader):
            # Real videos get zeroed attn mask (eventually, not currently)
            real_inds = (target == 0).nonzero()
            valid_heatvol_inds = volmask.nonzero()
            # all_inds = torch.unique(torch.cat((real_inds, valid_heatvol_inds)))
            all_inds = valid_heatvol_inds.squeeze()
            # Not all videos have valid heatvols, so extract only those that do
            valid_heatvols = heatvols.index_select(0, all_inds)

            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)
                all_inds = all_inds.cuda(args.gpu, non_blocking=True)
                valid_heatvols = valid_heatvols.cuda(args.gpu,
                                                    non_blocking=True)

            # compute output and retrieve self attn map
            output, attn = model(images)

            # Compute loss for valid attn maps only
            valid_attn = attn.index_select(0, all_inds)
            # FIX: removed leftover debug print (with a stray 'c' typo) that
            # spammed stdout on every validation batch.
            if valid_attn.size(0) != 0:
                # Rescale heatvol to match the size of the attn vol.
                valid_heatvols = nn.functional.interpolate(
                    valid_heatvols, size=valid_attn.shape[-2:], mode='area')

            # Compute Cross Entropy loss between the predictions and targets.
            # Compute KL Divergence loss and Correlation Coefficient loss
            # between human heat volumes and self attn maps.
            bs = valid_attn.size(0)
            losses = {
                'ce': loss_weights['ce'] * criterions['ce'](output, target),
                'cc': loss_weights['cc'] * criterions['cc'](valid_heatvols,
                                                            valid_attn),
            }
            if valid_attn.nelement() != 0:
                # nelement() != 0 already implies size(0) != 0, so the old
                # inner NaN fallback was unreachable and has been dropped.
                losses['kl'] = loss_weights['kl'] * criterions['kl'](
                    F.log_softmax(valid_attn.view(bs, -1), dim=-1),
                    F.softmax(valid_heatvols.view(bs, -1), dim=-1),
                )
            losses['full'] = sum(losses.values())

            # measure accuracy and record loss
            acc1 = accuracy(output, target, topk=(1, ))[0]
            for name, loss in losses.items():
                # FIX (consistency with train): skip NaN losses so a single
                # bad batch cannot poison the running averages.
                if not torch.isnan(loss):
                    loss_meters[name].update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0 and display:
                progress.display(i)

        if display:
            # TODO: this should also be done with the ProgressMeter
            print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))

    return top1.avg, loss_meters['full'].avg
def train(
    train_loader,
    model,
    criterions,
    loss_weights,
    optimizer,
    logger,
    epoch,
    args,
    display=True,
):
    """Train an attention-supervised model for one epoch.

    Each batch yields ``(images, target, heatvols, volmask)``; ``volmask``
    flags samples with a valid human heat volume. Cross-entropy is applied to
    all samples; CC and KL losses compare the model's self-attention maps
    against the human heat volumes of the valid subset. The weighted sum of
    all losses ('full') is backpropagated.

    Args:
        criterions: dict with keys 'ce', 'cc', 'kl' mapping to loss modules.
        loss_weights: dict with matching keys of scalar weights.
        logger: metrics sink exposing ``log_metrics(dict, step=...)``.
        epoch: current epoch index (display and global-step bookkeeping).
        args: namespace providing ``gpu`` and ``print_freq``.
        display: print/log progress every ``args.print_freq`` batches.

    Returns:
        (average top-1 accuracy, average combined 'full' loss).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    loss_meters = {
        name: AverageMeter(f'{name} Loss', ':.4e')
        for name in criterions
    }
    # FIX: format spec was ':4e' (missing '.'), inconsistent with every other
    # loss meter's ':.4e'.
    loss_meters['full'] = AverageMeter('Full loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time] + list(loss_meters.values()) + [top1],
        prefix="Epoch: [{}]".format(epoch),
    )
    # Global step counter, continuous across epochs.
    itr = epoch * len(train_loader)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target, heatvols, volmask) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        itr += 1

        # Real videos get zeroed attn mask (eventually, not currently)
        real_inds = (target == 0).nonzero()
        valid_heatvol_inds = volmask.nonzero()
        # all_inds = torch.unique(torch.cat((real_inds, valid_heatvol_inds)))
        all_inds = valid_heatvol_inds.squeeze()
        # Not all videos have valid heatvols, so extract only those that do
        valid_heatvols = heatvols.index_select(0, all_inds)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            all_inds = all_inds.cuda(args.gpu, non_blocking=True)
            valid_heatvols = valid_heatvols.cuda(args.gpu, non_blocking=True)

        # compute output and retrieve self attn map
        output, attn = model(images)

        # Compute loss for valid attn maps only
        valid_attn = attn.index_select(0, all_inds)
        if valid_attn.size(0) != 0:
            # Rescale heatvol to match the size of the attn vol.
            valid_heatvols = nn.functional.interpolate(
                valid_heatvols, size=valid_attn.shape[-2:], mode='area')

        # Compute Cross Entropy loss between the predictions and targets.
        # Compute KL Divergence loss and Correlation Coefficient loss between
        # human heat volumes and self attn maps.
        bs = valid_attn.size(0)
        losses = {
            'ce': loss_weights['ce'] * criterions['ce'](output, target),
            'cc': loss_weights['cc'] * criterions['cc'](valid_heatvols,
                                                        valid_attn),
        }
        if valid_attn.nelement() != 0:
            # nelement() != 0 already implies size(0) != 0, so the old inner
            # NaN fallback was unreachable and has been dropped.
            losses['kl'] = loss_weights['kl'] * criterions['kl'](
                F.log_softmax(valid_attn.view(bs, -1), dim=-1),
                F.softmax(valid_heatvols.view(bs, -1), dim=-1),
            )
        losses['full'] = sum(losses.values())

        # measure accuracy and record loss
        acc1 = accuracy(output, target, topk=(1, ))[0]
        for name, loss in losses.items():
            # Skip NaN losses so they cannot poison the running averages.
            if not torch.isnan(loss):
                loss_meters[name].update(loss.item(), images.size(0))
        top1.update(acc1, images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        losses['full'].backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and display:
            progress.display(i)
            # FIX: log the combined loss explicitly. The original logged the
            # stale loop variable `loss`, which only equaled the full loss by
            # accident of dict insertion order ('full' inserted last).
            logger.log_metrics({
                'Accuracy/train': acc1,
                'Loss/train': losses['full']
            }, step=itr)
            # logger.save()

    return top1.avg, loss_meters['full'].avg
def train(train_loader, model, criterion, optimizer, logger, epoch, args,
          display=True):
    """Run one supervised training epoch over ``train_loader``.

    Forward, loss, backward, and optimizer step per batch; progress is
    printed and metrics are logged every ``args.print_freq`` batches.

    Returns:
        (average top-1 accuracy, average loss) for the epoch.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1],
        prefix="Epoch: [{}]".format(epoch),
    )
    # Global step counter, continuous across epochs.
    itr = epoch * len(train_loader)

    # Enable training-time behavior (dropout, batch-norm updates).
    model.train()

    tic = time.time()
    for batch_idx, (images, target) in enumerate(train_loader):
        # Time spent waiting on the data loader.
        data_time.update(time.time() - tic)
        itr += 1

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

        # Forward pass and loss.
        output = model(images)
        loss = criterion(output, target)

        # Running statistics for this batch.
        n = images.size(0)
        acc1 = accuracy(output, target, topk=(1, ))[0]
        losses.update(loss.item(), n)
        top1.update(acc1, n)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Total wall-clock time for the batch.
        batch_time.update(time.time() - tic)
        tic = time.time()

        if display and batch_idx % args.print_freq == 0:
            progress.display(batch_idx)
            logger.log_metrics({
                'Accuracy/train': acc1,
                'Loss/train': loss
            }, step=itr)
            # logger.save()

    return top1.avg, losses.avg