Пример #1
0
    def __init__(self,
                 net_module,
                 root: str,
                 transform,
                 trainset_indices: Sequence[int],
                 gr_strength=0.0,
                 max_epoch: int = 1,
                 batch_size: int = 8,
                 device='cpu'):
        """
        Build the autoencoder from ``net_module`` and set up training state.

        :param net_module: module object providing ``STCAEEncoder``,
               ``STCAEDecoder`` and ``temporal_batch_size``
        :param root: the dataset root, passed to ``vmdata.VideoDataset``
        :param transform: transform passed to ``vmdata.VideoDataset``
        :param trainset_indices: indices handed to the sliding-window batch
               sampler used for training
        :param gr_strength: stored as ``self.gr_strength``; not consumed
               inside this constructor
        :param max_epoch: forwarded to the base-class constructor
        :param batch_size: batch size of the training batch sampler
        :param device: forwarded to the base-class constructor; the loss
               module is moved to it as well
        """
        log = _l(self, '__init__')
        self.net_module = net_module
        encoder = self.net_module.STCAEEncoder()
        decoder = self.net_module.STCAEDecoder()
        autoencoder = Autoencoder(encoder, decoder)
        log.debug('Loaded autoencoder net from module {}'.format(
            self.net_module.__name__))

        super().__init__(autoencoder, max_epoch=max_epoch, device=device)

        # Data source and optimisation machinery.
        # NOTE(review): ``self.net`` is presumably set by the base-class
        # constructor from the network passed above -- confirm.
        self.vdset = vmdata.VideoDataset(root, transform=transform)
        self.criterion = nn.MSELoss().to(device)
        net_params = self.net.parameters()
        learning_rate = type(self).optim_lr
        self.optimizer = optim.Adam(net_params, lr=learning_rate)

        # Plain configuration kept on the instance.
        self.trainset_indices = trainset_indices
        self.gr_strength = gr_strength
        self.batch_size = batch_size
        self.stat_names = ('loss', )
        self.run_stages = ('train', )

        # Factory evaluated lazily: each call builds a fresh sampler over
        # windows of ``1 + temporal_batch_size`` indices; shuffling is enabled
        # only when more than one epoch is requested.
        self.train_batch_sampler_ = lambda: (
            more_sampler.SlidingWindowBatchSampler(
                self.trainset_indices,
                1 + self.net_module.temporal_batch_size,
                batch_size=batch_size,
                shuffled=(self.max_epoch > 1)))
Пример #2
0
def main():
    """
    Parse the command line and run ``analyze_flow`` on each requested frame.

    When ``args.debug`` is set, the requested frame ids are printed and the
    function returns without analyzing anything.
    """
    args = make_parser().parse_args()
    dataset_root = vmdata.dataset_root(args.camera, (8, 0, 0))
    with vmdata.VideoDataset(dataset_root) as vdset:
        if args.debug:
            # Dry run: only show which frames were requested.
            print(args.fids)
            return
        for frame_id in args.fids:
            # Progress goes to stderr, separate from anything on stdout.
            print('Running on frame {}'.format(frame_id), file=sys.stderr)
            analyze_flow(vdset, frame_id)
Пример #3
0
def visualize(todir: str,
              root: str,
              transform,
              normalize_stats,
              indices: Sequence[int],
              net: nn.Module,
              net_module,
              temperature: float = 1.0,
              bwth: float = None,
              device: str = 'cpu',
              batch_size: int = 1,
              predname_tmpl: str = 'pred{}.png',
              attnname_tmpl: str = 'attn{}_pred{}.png') -> None:
    r"""
    Visualize predictions and input gradients.

    For every sliding window of ``1 + net_module.temporal_batch_size`` frames
    the last frame is the prediction target and the preceding frames are the
    network inputs. The gradient of the MSE loss with respect to the inputs
    is rendered as an "attention" map next to each input frame.

    :param todir: directory under which to save the visualization; created
           if it does not exist
    :param root: the dataset root
    :param transform: transform on dataset inputs
    :param normalize_stats: the mean-std tuple used in normalization
    :param indices: indices of dataset inputs involved in visualization
    :param net: the trained network
    :param net_module: the module from which ``net`` is loaded, as returned by
           ``load_trained_net``
    :param temperature: factor the input gradient is multiplied by; the
           larger it is, the more contrast there is in the attention map
    :param bwth: if not specified, plot the attention map as a sigmoidal map,
           where 0.5 means zero gradient, and :math:`0.5 \pm 0.5` means
           positive and negative gradients respectively; if specified as a
           float in range [0.0, 1.0], take the absolute value of the
           gradient scaled by ``temperature``, map it through
           :math:`2 \sigma(\cdot) - 1`, and zero out all values lower than
           the ``bwth`` threshold
    :param device: where to do inference
    :param batch_size: batch size when doing inference; will be set to 1 if
           ``device`` is 'cpu'
    :param predname_tmpl: the basename template of prediction visualization
    :param attnname_tmpl: the basename template of attention (inputs gradient)
           visualization
    :raise ValueError: if ``bwth`` is given but lies outside [0.0, 1.0]
    """
    if bwth is not None and not (0.0 <= bwth <= 1.0):
        raise ValueError(
            'bwth must be in range [0,1], but got {}'.format(bwth))
    logger = _l('visualize')
    tbatch_size = net_module.temporal_batch_size
    if device == 'cpu':
        batch_size = 1
    # Two identically configured, unshuffled samplers: ``sam`` drives the
    # DataLoader, while ``sam_`` is zipped with the loader below so the frame
    # indices of each batch are available for naming the output files.
    # ``drop_last=True`` is what makes the fixed-shape reshape of the indices
    # and the ``range(batch_size)`` loop below valid for every batch.
    sam = SlidingWindowBatchSampler(indices,
                                    1 + tbatch_size,
                                    shuffled=False,
                                    batch_size=batch_size,
                                    drop_last=True)
    sam_ = SlidingWindowBatchSampler(indices,
                                     1 + tbatch_size,
                                     shuffled=False,
                                     batch_size=batch_size,
                                     drop_last=True)
    denormalize = DeNormalize(*normalize_stats)
    os.makedirs(todir, exist_ok=True)

    mse = nn.MSELoss().to(device)
    sigmoid = nn.Sigmoid().to(device)
    net = net.to(device)
    with vmdata.VideoDataset(root, transform=transform) as vdset:
        loader = DataLoader(vdset, batch_sampler=sam)
        for frames, iframes in zip(loader, iter(sam_)):
            frames = rearrange_temporal_batch(frames, 1 + tbatch_size)
            iframes = np.array(iframes).reshape((batch_size, 1 + tbatch_size))
            # Along dim 2 of ``frames`` (and the last axis of ``iframes``) the
            # final entry is the prediction target, the rest are the inputs.
            inputs, targets = frames[:, :, :-1, :, :], frames[:, :, -1:, :, :]
            iinputs, itargets = iframes[:, :-1], iframes[:, -1]
            inputs, targets = inputs.to(device), targets.to(device)

            # Track gradients on the inputs so that backward() fills
            # ``inputs.grad`` with d(loss)/d(inputs).
            inputs.requires_grad_()
            outputs = net(inputs)
            loss = mse(outputs, targets)
            loss.backward()
            attns = inputs.grad * temperature

            # Log plain Python numbers only; ``max`` used to be formatted as
            # a tensor repr, unlike the other statistics.
            attn_stats = attns.detach()
            fmin, fmax = np.min(iframes), np.max(iframes)
            logger.info(
                '[f{}-{}/eval/attn] l1norm={} l2norm={} numel={} max={}'.
                format(fmin, fmax,
                       torch.norm(attn_stats, 1).item(),
                       torch.norm(attn_stats, 2).item(),
                       torch.numel(attn_stats),
                       torch.max(attn_stats).item()))
            logger.info('[f{}-{}/eval/loss] mse={} B={}'.format(
                fmin, fmax, loss.item(), targets.size(0)))

            if bwth is not None:
                # |grad| >= 0, so sigmoid(|grad|) is in [0.5, 1); rescale to
                # [0, 1) and suppress everything below the threshold.
                attns = sigmoid(torch.abs(attns)) * 2 - 1
                mask = (attns >= bwth).to(attns.dtype)
                attns = mask * attns
            else:
                attns = sigmoid(attns)

            inputs, attns, outputs, targets = postprocess(
                inputs, attns, outputs, targets, denormalize, tbatch_size)

            for b in range(batch_size):
                # One prediction-vs-target image per window ...
                f = os.path.join(todir, predname_tmpl.format(itargets[b]))
                draw_left_right(f, (outputs[b], targets[b]))
                # ... and one input-vs-attention image per input frame.
                for t in range(tbatch_size):
                    f = os.path.join(
                        todir, attnname_tmpl.format(iinputs[b, t],
                                                    itargets[b]))
                    draw_up_down(f, (inputs[b, t], attns[b, t]))
Пример #4
0
# Send records of level INFO and above to a log file whose name embeds
# ``_rid``.
logging.basicConfig(
    filename='main.{}.log'.format(_rid),
    format='%(name)s %(asctime)s -- %(message)s',
    level=logging.INFO,
)

import sys

import torchvision.transforms as trans

import vmdata
import ezfirstae.loaddata as ld
import ezfirstae.train as train

# Run settings and output locations (directory names embed ``_rid``).
device = 'cuda'
max_epoch = 1
statdir = 'stat.{}'.format(_rid)
savedir = 'save.{}'.format(_rid)

# Dataset location and the preprocessing applied to every input; the
# normalization statistics are looked up for this root with ``bw=True``.
root = vmdata.dataset_root(9, (8, 0, 0))
normalize = trans.Normalize(
    *vmdata.get_normalization_stats(root, bw=True))
transform = ld.PreProcTransform(
    normalize, pool_scale=8, downsample_scale=3)

if __name__ == '__main__':
    # Root logger; output goes wherever logging was configured to write.
    logger = logging.getLogger()
    logger.info(
        'Begin training: model=ezfirstae.models.pred9_f1to8(no-attention)')
    with vmdata.VideoDataset(root,
                             transform=transform,
                             max_mmap=3,
                             max_gzcache=100) as vdset:
        # Split all dataset indices into a training part and a testing part.
        # NOTE(review): ``(5, 1)`` is presumably the train:test proportion --
        # confirm against ``ld.contiguous_partition_dataset``.
        all_indices = range(len(vdset))
        trainset, testset = ld.contiguous_partition_dataset(
            all_indices, (5, 1))
        try:
            train.train_pred9_f1to8_no_attn(
                vdset, trainset, testset, savedir, statdir, device,
                max_epoch)
        except KeyboardInterrupt:
            # Ctrl-C ends training quietly; control then falls out of the
            # ``with`` block, which closes the dataset.
            logger.warning('User interrupt')
            print('Cleaning up ...', file=sys.stderr)
Пример #5
0
#!/usr/bin/env python
import torch
import torchvision.transforms as trans

import train
import vmdata
import more_trans
import salicae

if __name__ == '__main__':
    # Dataset location and per-channel normalization for its frames.
    root = vmdata.dataset_root(9, (8, 0, 0))
    normalize = trans.Normalize(*vmdata.get_normalization_stats(root))

    # Model and training settings.
    net = salicae.SaliencyCAE(vgg_arch='Ashallow', batch_norm=False)
    device = torch.device('cuda')
    batch_size = 16
    lasso_strength = 1.
    savedir = 'save'
    statdir = 'stats'

    # Open the dataset as a context manager so it is released when training
    # finishes or raises, instead of being left open until interpreter exit.
    with vmdata.VideoDataset(root,
                             max_mmap=4,
                             max_gzcache=10,
                             transform=trans.Compose([
                                 more_trans.MedianBlur(),
                                 trans.ToTensor(),
                                 normalize,
                             ])) as dset:
        train.train(net, dset, device, batch_size, lasso_strength, statdir,
                    savedir)
Пример #6
0
import vmdata
import ezfirstae.loaddata as ld
import ezfirstae.train as train

# Optimisation settings handed to the training entry point.
max_epoch = 1
lr = 5e-5
lam_dark = 0.1
lam_nrgd = 0.0
device = 'cuda'

# Output directories; both names embed ``_rid``.
statdir = 'stat.{}'.format(_rid)
savedir = 'save.{}'.format(_rid)

# Dataset root and the preprocessing transform built on its normalization
# statistics (requested with ``bw=True``).
root = vmdata.dataset_root(9, (8, 0, 0))
normalize = trans.Normalize(
    *vmdata.get_normalization_stats(root, bw=True))
transform = ld.PreProcTransform(
    normalize, pool_scale=8, downsample_scale=3)

if __name__ == '__main__':
    logger = logging.getLogger()
    # Record the model and the two loss weights this run was started with.
    logger.info(
        'Begin training: model=ezfirstae.models.pred9_f1to8 lam_dark={}'
        ' lam_nrgd={}'.format(lam_dark, lam_nrgd))
    with vmdata.VideoDataset(root, transform=transform) as vdset:
        # Split all dataset indices into a training part and a testing part.
        # NOTE(review): ``(5, 1)`` is presumably the train:test proportion --
        # confirm against ``ld.contiguous_partition_dataset``.
        all_indices = range(len(vdset))
        trainset, testset = ld.contiguous_partition_dataset(
            all_indices, (5, 1))
        try:
            train.train_pred9_f1to8(
                vdset, trainset, testset, savedir, statdir, device,
                max_epoch, lr, lam_dark, lam_nrgd)
        except KeyboardInterrupt:
            # Ctrl-C ends training quietly; the ``with`` block still closes
            # the dataset on the way out.
            logger.warning('User interrupt')
            print('Cleaning up ...', file=sys.stderr)