def get_trainloader(self) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]:
    """
    Iterate over the training data as ``(inputs, targets)`` pairs.

    Each batch drawn from ``self.vdset`` through
    ``self.train_batch_sampler`` is regrouped by
    ``more_trans.rearrange_temporal_batch`` with a window of
    ``1 + self.net_module.temporal_batch_size``.  Along dimension 2 of the
    regrouped tensor, everything but the last slice forms the inputs and
    the last slice alone forms the targets.  Both tensors are moved to
    ``self.device`` before being yielded.

    :return: a generator of ``(inputs, targets)`` tensor pairs
    """
    loader = DataLoader(self.vdset, batch_sampler=self.train_batch_sampler)
    for batch in loader:
        window = 1 + self.net_module.temporal_batch_size
        clips = more_trans.rearrange_temporal_batch(batch, window)
        # all but the last temporal slice feed the network ...
        past = clips[:, :, :-1, :, :]
        # ... and the last slice (kept as a length-1 dimension) is the target
        future = clips[:, :, -1:, :, :]
        yield past.to(self.device), future.to(self.device)
def LoCHW2BTHW(tensors: Iterable[torch.Tensor], T: int) -> torch.Tensor:
    """
    Stack tensors, regroup them temporally and drop dimension 1.

    The tensors are stacked along a new leading dimension, passed through
    ``rearrange_temporal_batch`` with window ``T``, and index 0 of
    dimension 1 of the 5-D result is returned (so the output is 4-D).

    NOTE(review): the name suggests a list of (C, H, W) frames mapped to a
    (B, T, H, W) tensor -- confirm against ``rearrange_temporal_batch``.

    :param tensors: the tensors to stack; they must share one shape
    :param T: the temporal window passed to ``rearrange_temporal_batch``
    :return: the regrouped tensor with dimension 1 removed
    """
    stacked = torch.stack(tuple(tensors))
    regrouped = rearrange_temporal_batch(stacked, T)
    return regrouped[:, 0, :, :, :]
def visualize(todir: str, root: str, transform, normalize_stats,
              indices: Sequence[int], net: nn.Module, net_module,
              temperature: float = 1.0, bwth: float = None,
              device: str = 'cpu', batch_size: int = 1,
              predname_tmpl: str = 'pred{}.png',
              attnname_tmpl: str = 'attn{}_pred{}.png') -> None:
    r"""
    Visualize predictions and input gradients.

    For every sliding window of ``1 + net_module.temporal_batch_size``
    frames, the last frame is predicted from the preceding ones, the MSE
    loss against the true last frame is back-propagated to the inputs, and
    two kinds of images are written under ``todir``: the prediction next
    to its target, and every input frame together with its attention
    (input gradient) map.

    :param todir: directory under which to save the visualization; created
           if it does not exist
    :param root: the dataset root
    :param transform: transform on dataset inputs
    :param normalize_stats: the mean-std tuple used in normalization
    :param indices: indices of dataset inputs involved in visualization
    :param net: the trained network
    :param net_module: the module from which ``net`` is loaded, as returned
           by ``load_trained_net``
    :param temperature: multiplier applied to the input gradients; the
           larger it is, the more contrast there is in the attention map
    :param bwth: if not specified, plot the attention map as a sigmoidal
           map, where 0.5 means zero gradient, and :math:`0.5 \pm 0.5`
           means positive and negative gradients respectively; if
           specified as a value in [0.0, 1.0], take the absolute value of
           the gradient scaled by ``temperature``, map it through
           :math:`2\sigma(\cdot) - 1`, and zero all values lower than the
           ``bwth`` threshold
    :param device: where to do inference
    :param batch_size: batch size when doing inference; will be set to 1
           if ``device`` is 'cpu'
    :param predname_tmpl: the basename template of prediction
           visualization; formatted with the target frame index
    :param attnname_tmpl: the basename template of attention (inputs
           gradient) visualization; formatted with the input frame index
           and the target frame index
    :raise ValueError: if ``bwth`` is given but lies outside [0, 1]
    """
    if bwth is not None and not (0.0 <= bwth <= 1.0):
        raise ValueError(
            'bwth must be in range [0,1], but got {}'.format(bwth))
    logger = _l('visualize')
    tbatch_size = net_module.temporal_batch_size
    window = 1 + tbatch_size  # input frames plus the frame to predict
    if device == 'cpu':
        batch_size = 1

    # Two identical, unshuffled samplers: ``sam`` drives the DataLoader
    # while ``sam_`` is iterated in lockstep to recover the frame indices
    # of each batch.
    sam = SlidingWindowBatchSampler(indices, window, shuffled=False,
                                    batch_size=batch_size, drop_last=True)
    sam_ = SlidingWindowBatchSampler(indices, window, shuffled=False,
                                     batch_size=batch_size, drop_last=True)
    denormalize = DeNormalize(*normalize_stats)
    os.makedirs(todir, exist_ok=True)

    mse = nn.MSELoss().to(device)
    sigmoid = nn.Sigmoid().to(device)
    # NOTE(review): ``net`` is left in whatever train/eval mode the caller
    # put it in -- confirm that eval mode is set before calling.
    net = net.to(device)

    with vmdata.VideoDataset(root, transform=transform) as vdset:
        loader = DataLoader(vdset, batch_sampler=sam)
        for frames, iframes in zip(loader, iter(sam_)):
            frames = rearrange_temporal_batch(frames, window)
            iframes = np.array(iframes).reshape((batch_size, window))
            inputs, targets = frames[:, :, :-1, :, :], frames[:, :, -1:, :, :]
            iinputs, itargets = iframes[:, :-1], iframes[:, -1]
            inputs, targets = inputs.to(device), targets.to(device)

            # Back-propagate the prediction error down to the inputs.
            # NOTE: backward() also accumulates into the ``.grad`` of any
            # network parameter that requires grad; only ``inputs.grad``
            # is used here.
            inputs.requires_grad_()
            outputs = net(inputs)
            loss = mse(outputs, targets)
            loss.backward()
            attns = inputs.grad * temperature

            first_frame, last_frame = np.min(iframes), np.max(iframes)
            attn_stats = attns.detach()
            logger.info(
                '[f{}-{}/eval/attn] l1norm={} l2norm={} numel={} max={}'.
                format(first_frame, last_frame,
                       torch.norm(attn_stats, 1).item(),
                       torch.norm(attn_stats, 2).item(),
                       torch.numel(attn_stats),
                       # log a plain scalar like the other statistics
                       torch.max(attn_stats).item()))
            logger.info('[f{}-{}/eval/loss] mse={} B={}'.format(
                first_frame, last_frame, loss.item(), targets.size(0)))

            if bwth is not None:
                # |grad| >= 0, so sigmoid(|grad|) lies in [0.5, 1); rescale
                # to [0, 1) before thresholding.
                attns = sigmoid(torch.abs(attns)) * 2 - 1
                mask = (attns >= bwth).to(attns.dtype)
                attns = mask * attns
            else:
                attns = sigmoid(attns)

            inputs, attns, outputs, targets = postprocess(
                inputs, attns, outputs, targets, denormalize, tbatch_size)
            for b in range(batch_size):
                f = os.path.join(todir, predname_tmpl.format(itargets[b]))
                draw_left_right(f, (outputs[b], targets[b]))
                for t in range(tbatch_size):
                    f = os.path.join(
                        todir,
                        attnname_tmpl.format(iinputs[b, t], itargets[b]))
                    draw_up_down(f, (inputs[b, t], attns[b, t]))
def train_pred9_f1to8(vdset: vmdata.VideoDataset,
                      trainset: Sequence[int],
                      testset: Sequence[int],
                      savedir: str,
                      statdir: str,
                      device: Union[str, torch.device] = 'cpu',
                      max_epoch: int = 1,
                      lr: float = 0.001,
                      lam_dark: float = 1.0,
                      lam_nrgd: float = 0.2):
    """
    Train the ``pred9_f1to8`` encoder/decoder/attention model.

    Every epoch runs a training pass over ``trainset`` followed by an
    evaluation pass over ``testset``.  For each batch, the last frame of a
    sliding window is predicted from the preceding frames; the loss is the
    attention-weighted MSE plus ``lam_dark`` times the darkness penalty
    plus ``lam_nrgd`` times the non-rigid penalty.  A checkpoint and the
    loss statistics are saved, and the statistics logged, after every
    batch of both passes.

    The global autograd mode is only altered while a pass is running and
    is restored afterwards, even if an exception is raised.

    :param vdset: the video dataset; its ``transform`` must either be a
           ``trans.Normalize`` or hold one as an attribute
    :param trainset: dataset indices used for the training pass
    :param testset: dataset indices used for the evaluation pass
    :param savedir: directory passed to the checkpoint saver
    :param statdir: directory passed to the statistics saver
    :param device: where to train, as a device string or ``torch.device``
    :param max_epoch: number of epochs to run
    :param lr: learning rate of the Adam optimizer
    :param lam_dark: weight of the darkness penalty in the total loss
    :param lam_nrgd: weight of the non-rigid penalty in the total loss
    :raise StopIteration: if no ``trans.Normalize`` can be found among the
           attributes of ``vdset.transform``
    """
    logger = logging.getLogger(_l(__name__, 'train_pred9_f1to8'))
    if isinstance(device, str):
        device = torch.device(device)

    encoder = pred9_f1to8.STCAEEncoder()
    decoder = pred9_f1to8.STCAEDecoder()
    # the attention branch uses the same decoder architecture
    attention = pred9_f1to8.STCAEDecoder()
    if isinstance(vdset.transform, trans.Normalize):
        normalize = vdset.transform
    else:
        # look for a Normalize among the attributes of the transform
        normalize = next(
            x for x in vdset.transform.__dict__.values()
            if isinstance(x, trans.Normalize))

    ezcae = basicmodels.EzFirstCAE(encoder, decoder, attention).to(device)
    mse = nn.MSELoss().to(device)
    darkp = basicmodels.DarknessPenalty(normalize).to(device)
    nrgdp = basicmodels.NonrigidPenalty().to(device)

    def criterion(_outputs: torch.Tensor, _attns: torch.Tensor,
                  _targets: torch.Tensor) -> Tuple[torch.Tensor, np.ndarray]:
        """Return the total loss and its three unweighted components."""
        loss1 = mse(_attns * _outputs, _attns * _targets)
        loss2 = darkp(_attns)
        loss3 = nrgdp(_attns.view(-1, 1, *_attns.shape[-2:]))
        _loss = loss1 + lam_dark * loss2 + lam_nrgd * loss3
        _loss123 = np.array(
            [loss1.item(), loss2.item(), loss3.item()], dtype=np.float64)
        return _loss, _loss123

    cpsaver = trainlib.CheckpointSaver(
        ezcae, savedir, checkpoint_tmpl='checkpoint_{0}_{1}.pth',
        fired=lambda pg: True)
    stsaver = trainlib.StatSaver(
        statdir, statname_tmpl='stats_{0}_{1}.npz',
        fired=lambda pg: True)
    optimizer = optim.Adam(ezcae.parameters(), lr=lr)

    window = 1 + pred9_f1to8.temporal_batch_size
    stat_names = ['loss', 'loss_mse', 'loss_dark', 'loss_nrgd']
    stats_tmpl = ' '.join('{}={{:.2f}}'.format(n) for n in stat_names)

    for epoch in range(max_epoch):
        for stage, dataset in [('train', trainset), ('eval', testset)]:
            training = stage == 'train'
            swsam = more_sampler.SlidingWindowBatchSampler(
                dataset, window, shuffled=True, batch_size=8)
            dataloader = DataLoader(vdset, batch_sampler=swsam)
            getattr(ezcae, stage)()  # ezcae.train() or ezcae.eval()
            # Used as a context manager so that the previous global grad
            # mode is restored when the stage ends.
            with torch.set_grad_enabled(training):
                for j, inputs in enumerate(dataloader):
                    # NOTE(review): the train and eval stages use the same
                    # (epoch, j) key, so the savers may overwrite train
                    # output with eval output -- confirm against trainlib.
                    progress = epoch, j
                    inputs = more_trans.rearrange_temporal_batch(
                        inputs, window)
                    inputs, targets = (inputs[:, :, :-1, :, :],
                                       inputs[:, :, -1:, :, :])
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs, attns = ezcae(inputs)
                    loss, loss123 = criterion(outputs, attns, targets)
                    if training:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    stat_vals = [loss.item()] + list(loss123)
                    cpsaver(progress)
                    stsaver(progress, **dict(zip(stat_names, stat_vals)))
                    logger.info('[epoch{}/batch{}] '.format(epoch, j)
                                + stats_tmpl.format(*stat_vals))
def train_pred9_f1to8_no_attn(vdset: vmdata.VideoDataset,
                              trainset: Sequence[int],
                              testset: Sequence[int],
                              savedir: str,
                              statdir: str,
                              device: Union[str, torch.device] = 'cpu',
                              max_epoch: int = 1,
                              lr: float = 0.001):
    """
    Train the ``pred9_f1to8`` encoder/decoder model without attention.

    Every epoch runs a training pass over ``trainset`` followed by an
    evaluation pass over ``testset``.  For each batch, the last frame of a
    9-frame sliding window is predicted from the preceding 8 frames and
    compared with the true frame using the MSE loss.  A checkpoint and the
    loss value are saved, and the loss logged, after every batch of both
    passes.

    The global autograd mode is only altered while a pass is running and
    is restored afterwards, even if an exception is raised.

    :param vdset: the video dataset
    :param trainset: dataset indices used for the training pass
    :param testset: dataset indices used for the evaluation pass
    :param savedir: directory passed to the checkpoint saver
    :param statdir: directory passed to the statistics saver
    :param device: where to train, as a device string or ``torch.device``
    :param max_epoch: number of epochs to run
    :param lr: learning rate of the Adam optimizer
    """
    logger = logging.getLogger(_l(__name__, 'train_pred9_f1to8_no_attn'))
    if isinstance(device, str):
        device = torch.device(device)

    encoder = pred9_f1to8.STCAEEncoder()
    decoder = pred9_f1to8.STCAEDecoder()
    cae = basicmodels.CAE(encoder, decoder).to(device)
    mse = nn.MSELoss().to(device)

    cpsaver = trainlib.CheckpointSaver(
        cae, savedir, checkpoint_tmpl='checkpoint_{0}_{1}.pth',
        fired=lambda pg: True)
    stsaver = trainlib.StatSaver(
        statdir, statname_tmpl='stats_{0}_{1}.npz',
        fired=lambda pg: True)
    optimizer = optim.Adam(cae.parameters(), lr=lr)

    # 8 input frames plus the frame to predict
    # NOTE(review): presumably equal to 1 + pred9_f1to8.temporal_batch_size
    # -- confirm before replacing the literal.
    window = 9

    for epoch in range(max_epoch):
        for stage, dataset in [('train', trainset), ('eval', testset)]:
            training = stage == 'train'
            swsam = more_sampler.SlidingWindowBatchSampler(
                dataset, window, shuffled=True, batch_size=8)
            dataloader = DataLoader(vdset, batch_sampler=swsam)
            getattr(cae, stage)()  # cae.train() or cae.eval()
            # Used as a context manager so that the previous global grad
            # mode is restored when the stage ends.
            with torch.set_grad_enabled(training):
                for j, inputs in enumerate(dataloader):
                    # NOTE(review): the train and eval stages use the same
                    # (epoch, j) key, so the savers may overwrite train
                    # output with eval output -- confirm against trainlib.
                    progress = epoch, j
                    inputs = more_trans.rearrange_temporal_batch(
                        inputs, window)
                    inputs, targets = (inputs[:, :, :-1, :, :],
                                       inputs[:, :, -1:, :, :])
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = cae(inputs)
                    loss = mse(outputs, targets)
                    if training:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    loss_val = loss.item()
                    cpsaver(progress)
                    stsaver(progress, loss=loss_val)
                    logger.info('[epoch{}/batch{}] loss={:.2f}'.format(
                        epoch, j, loss_val))