Code example #1
File: vmdata.py Project: kkew3/dolphins-mc
 def cleanup_unused_mmapfiles(self):
     logger = logging.getLogger(
         _l(__name__, self, 'cleanup_unused_mmapfiles'))
     for filename in os.listdir(self.root_tmp):
         matched = DATABATCH_FILENAME_PAT.match(filename)
         if matched:
             batch_id = int(matched.group(1))
             with self.access_locks[batch_id]:
                 batchf = os.path.join(self.root_tmp, filename)
                 # Since len(self.gz_cache) >= len(self.mmap_cache) and they
                 # are updated together, the latter must be a subset of the
                 # former.
                 if batchf not in self.gz_cache and \
                         len(os.listdir(self.root_tmp)) > self.max_gzcache:
                     try:
                         os.remove(batchf)
                     except OSError:
                         # due to concurrency, the file may have already been
                         # removed; due to the lock, however, no process will
                         # try to remove a file when another process is
                         # removing exactly the same file
                         pass
                     else:
                         logger.info(
                             'Decompressed batch "{}" removed'.format(
                                 batchf))
Code example #2
File: vmdata.py Project: kkew3/dolphins-mc
 def release_mmap(self):
     """
     Release all memory mapped dataset.
     """
     logger = logging.getLogger(_l(__name__, self, 'release_mmap'))
     keys = list(self.mmap_cache.keys())
     for k in keys:
         del self.mmap_cache[k]
     logger.info('All mmap released')
Code example #3
File: vmdata.py Project: kkew3/dolphins-mc
 def cleanup_all_mmapfiles(self):
     """
     Be sure to call this function only when there is no open memory-mapped
     file. Usually this function is unnecessary unless the user wants to save
     some disk space.
     """
     logger = logging.getLogger(_l(__name__, self, 'cleanup_all_mmapfiles'))
     if os.path.isdir(self.root_tmp):
         shutil.rmtree(self.root_tmp)
     if not os.path.isdir(self.root_tmp):
         os.mkdir(self.root_tmp)
     logger.info('All decompressed batches removed')
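
The three maintenance methods above work together: release_mmap() closes the open memory maps, cleanup_unused_mmapfiles() prunes decompressed batches that have fallen out of the gz cache, and cleanup_all_mmapfiles() wipes the temporary directory entirely. A minimal usage sketch (the dataset root path is hypothetical):

import vmdata

vdset = vmdata.VideoDataset('/path/to/dataset')  # hypothetical root
frame = vdset[0]                  # decompresses and memory-maps the first batch
vdset.cleanup_unused_mmapfiles()  # prune decompressed batches no longer cached
vdset.release_mmap()              # close all open memory maps first ...
vdset.cleanup_all_mmapfiles()     # ... then it is safe to wipe the tmp directory
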
Code example #4
File: vmdata.py Project: kkew3/dolphins-mc
    def __init__(self, root, transform=None, max_mmap=1, max_gzcache=3):
        """
        :param root: the root directory of the dataset
        :type root: str
        :param transform: the transformations to be performed on loaded data
        :param max_mmap: the maximum number of memory map to keep
        :type max_mmap: int
        :param max_gzcache: the maximum number of extracted memory map files to
               keep on disk
        :type max_gzcache: int
        """
        self.root = root
        jsonfile = get_dset_filename_by_ext(root, '.json')
        with open(jsonfile) as infile:
            self.metainfo = json.load(infile)
        self.total_frames = np.sum(self.metainfo['lens'])
        self.lens_cumsum = list(accumulate(self.metainfo['lens']))
        shape = self.metainfo['resolution'] + [self.metainfo['channels']]
        self.frame_shape = tuple(shape)

        # validated_batches[i] stays False until batch i passes integrity check
        self.validated_batches = [False] * len(self.metainfo['lens'])
        checksumfile = get_dset_filename_by_ext(root, '.' + HASH_ALGORITHM)
        self.expected_hexes = parse_checksum_file(checksumfile)

        self.root_tmp = os.path.join(root, 'tmp')
        if not os.path.isdir(self.root_tmp):
            os.mkdir(self.root_tmp)

        self.transform = transform

        max_mmap = max(1, max_mmap)
        self.mmap_cache = LRUCache(maxsize=max_mmap)
        self.gz_cache = LRUCache(maxsize=max(max_mmap, max_gzcache))
        self.max_gzcache = max(max_mmap, max_gzcache)

        # fine granularity lock for each data batch
        lockfile_tmpl = get_dset_filename_by_ext(root, '.access{}.lock')
        # note that the absolute path is used to construct the file lock, so
        # that the lock is shared not only across different processes but also
        # across several instances of this class, as long as they are given
        # the same root
        self.access_locks = [
            FileLock(lockfile_tmpl.format(bid))
            for bid in range(len(self.metainfo['lens']))
        ]

        logger = logging.getLogger(_l(__name__, self, '__init__'))
        logger.info('Instantiated: root={}'.format(self.root))
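
A sketch of constructing the dataset, assuming root already contains the metadata JSON (with its lens, resolution, channels and dtype entries), the checksum file and the gzipped batches; the path below is hypothetical:

import vmdata

vdset = vmdata.VideoDataset('/data/videos/dset1',
                            transform=None,
                            max_mmap=2,
                            max_gzcache=6)
print(vdset.total_frames)           # total number of frames in the dataset
print(len(vdset.metainfo['lens']))  # number of data batches
print(vdset.frame_shape)            # per-frame shape: resolution plus channels
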
Code example #5
File: vmdata.py Project: kkew3/dolphins-mc
    def __getitem__(self, frame_id):
        """
        Return the frame with the given frame ID as an array of dimension HWC.
        Note that calling this method with indices that are not contiguous or
        nearly contiguous is very inefficient.

        :param frame_id: the frame index
        :return: the frame in numpy array of dimension HWC
        :rtype: np.ndarray
        """
        logger = logging.getLogger(_l(__name__, self, '__getitem__'))
        if frame_id < 0 or frame_id >= len(self):
            raise IndexError('Invalid index: {}'.format(frame_id))

        batch_id, rel_frame_id = self.locate_batch(frame_id)
        logger.debug('Waiting for lock ID {}'.format(batch_id))
        with self.access_locks[batch_id]:
            if batch_id not in self.mmap_cache:
                batchf = self.batch_filename_by_id(batch_id)
                if not os.path.isfile(batchf):
                    logger.info('Decompressing "{}"'.format(batchf))
                    extract_gzip(
                        self.batch_filename_by_id(batch_id, gzipped=True),
                        batchf)
                    assert os.path.isfile(batchf), \
                        '"{}" not found after decompression'.format(batchf)
                if not self.validated_batches[batch_id]:
                    if not check_file_integrity(batchf,
                                                self.expected_hexes[batch_id]):
                        logger.warning(
                            'File integrity check failed at "{}"; retrying'
                            .format(batchf))
                        # the previous decompression or read probably failed;
                        # attempt to decompress once more
                        os.remove(batchf)
                        extract_gzip(
                            self.batch_filename_by_id(batch_id, gzipped=True),
                            batchf)
                        assert os.path.isfile(batchf), \
                            '"{}" not found after decompression'.format(batchf)
                        if not check_file_integrity(
                                batchf, self.expected_hexes[batch_id]):
                            logger.error('File integrity failed at "{}"; '
                                         'RuntimeError raised'.format(batchf))
                            raise RuntimeError(
                                'Data batch {} corrupted'.format(batch_id))
                    self.validated_batches[batch_id] = True
                    logger.info(
                        'File integrity check completed for batch {}'.format(
                            batch_id))

                # at this point the file "batchf" is guaranteed to exist
                self.gz_cache[batchf] = True

                shape = (self.metainfo['lens'][batch_id], ) + self.frame_shape
                logger.debug('keys before mmap cache adjustment: {}'.format(
                    list(self.mmap_cache.keys())))
                self.mmap_cache[batch_id] = np.memmap(
                    str(batchf),
                    mode='r',
                    dtype=self.metainfo['dtype'],
                    shape=shape)
                logger.debug('keys after mmap cache adjustment: {}'.format(
                    list(self.mmap_cache.keys())))
        frame = np.copy(self.mmap_cache[batch_id][rel_frame_id])
        if self.transform is not None:
            frame = self.transform(frame)
        self.cleanup_unused_mmapfiles()
        return frame
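
Continuing the sketch above: because each access may have to decompress and memory-map an entire batch, scanning frames in order keeps the mmap and gz caches warm, whereas random access pays that cost over and over.

# sequential access reuses the currently mapped batch
for i in range(100):
    frame = vdset[i]  # np.ndarray of dimension HWC

# random access may map (or even decompress) a different batch on every call
import random
frame = vdset[random.randrange(len(vdset))]
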
Code example #6
def train_pred9_f1to8(vdset: vmdata.VideoDataset,
                      trainset: Sequence[int],
                      testset: Sequence[int],
                      savedir: str,
                      statdir: str,
                      device: Union[str, torch.device] = 'cpu',
                      max_epoch: int = 1,
                      lr: float = 0.001,
                      lam_dark: float = 1.0,
                      lam_nrgd: float = 0.2):
    logger = logging.getLogger(_l(__name__, 'train_pred9_f1to8'))
    if isinstance(device, str):
        device = torch.device(device)

    encoder = pred9_f1to8.STCAEEncoder()
    decoder = pred9_f1to8.STCAEDecoder()
    attention = pred9_f1to8.STCAEDecoder()
    if isinstance(vdset.transform, trans.Normalize):
        normalize = vdset.transform
    else:
        normalize = next(
            iter(x for x in vdset.transform.__dict__.values()
                 if isinstance(x, trans.Normalize)))

    ezcae = basicmodels.EzFirstCAE(encoder, decoder, attention).to(device)
    mse = nn.MSELoss().to(device)
    darkp = basicmodels.DarknessPenalty(normalize).to(device)
    nrgdp = basicmodels.NonrigidPenalty().to(device)

    def criterion(_outputs: torch.Tensor, _attns: torch.Tensor,
                  _targets: torch.Tensor) -> Tuple[torch.Tensor, np.ndarray]:
        loss1 = mse(_attns * _outputs, _attns * _targets)
        loss2 = darkp(_attns)
        loss3 = nrgdp(_attns.view(-1, 1, *_attns.shape[-2:]))
        _loss = loss1 + lam_dark * loss2 + lam_nrgd * loss3
        _loss123 = np.array(
            [loss1.item(), loss2.item(),
             loss3.item()], dtype=np.float64)
        return _loss, _loss123

    cpsaver = trainlib.CheckpointSaver(
        ezcae,
        savedir,
        checkpoint_tmpl='checkpoint_{0}_{1}.pth',
        fired=lambda pg: True)
    stsaver = trainlib.StatSaver(statdir,
                                 statname_tmpl='stats_{0}_{1}.npz',
                                 fired=lambda pg: True)
    alpha = 0.9  # resistance (smoothing factor) of the moving-average estimate of the mean loss
    optimizer = optim.Adam(ezcae.parameters(), lr=lr)

    for epoch in range(max_epoch):
        for stage, dataset in [('train', trainset), ('eval', testset)]:
            swsam = more_sampler.SlidingWindowBatchSampler(
                dataset,
                1 + pred9_f1to8.temporal_batch_size,
                shuffled=True,
                batch_size=8)
            dataloader = DataLoader(vdset, batch_sampler=swsam)
            moving_average = None
            getattr(ezcae, stage)()  # ezcae.train() or ezcae.eval()
            torch.set_grad_enabled(stage == 'train')
            for j, inputs in enumerate(dataloader):
                progress = epoch, j
                inputs = more_trans.rearrange_temporal_batch(
                    inputs, 1 + pred9_f1to8.temporal_batch_size)
                inputs, targets = (inputs[:, :, :-1, :, :],
                                   inputs[:, :, -1:, :, :])
                inputs, targets = inputs.to(device), targets.to(device)
                outputs, attns = ezcae(inputs)
                loss, loss123 = criterion(outputs, attns, targets)

                if stage == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                stat_names = ['loss', 'loss_mse', 'loss_dark', 'loss_nrgd']
                stat_vals = [loss.item()] + list(loss123)
                if stage == 'train':
                    moving_average = loss123 if moving_average is None else \
                        alpha * moving_average + (1 - alpha) * loss123
                    cpsaver(progress)
                    stsaver(progress, **dict(zip(stat_names, stat_vals)))
                logger.info(('[epoch{}/batch{}] '.format(epoch, j) +
                             ' '.join('{}={{:.2f}}'.format(n)
                                      for n in stat_names)).format(*stat_vals))
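
A sketch of invoking this training routine. Note that it looks up a trans.Normalize inside vdset.transform, so the dataset must have been constructed with a transform that contains one. The index ranges and directories below are hypothetical:

train_pred9_f1to8(
    vdset,
    trainset=range(0, 8000),     # frame indices used for training
    testset=range(8000, 10000),  # frame indices used for evaluation
    savedir='runs/pred9/checkpoints',
    statdir='runs/pred9/stats',
    device='cuda:0',             # or 'cpu'
    max_epoch=5,
    lr=1e-3)
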
Code example #7
def train_pred9_f1to8_no_attn(vdset: vmdata.VideoDataset,
                              trainset: Sequence[int],
                              testset: Sequence[int],
                              savedir: str,
                              statdir: str,
                              device: Union[str, torch.device] = 'cpu',
                              max_epoch: int = 1,
                              lr: float = 0.001):
    logger = logging.getLogger(_l(__name__, 'train_pred9_f1to8_no_attn'))
    if isinstance(device, str):
        device = torch.device(device)

    encoder = pred9_f1to8.STCAEEncoder()
    decoder = pred9_f1to8.STCAEDecoder()

    cae = basicmodels.CAE(encoder, decoder).to(device)
    mse = nn.MSELoss().to(device)

    cpsaver = trainlib.CheckpointSaver(
        cae,
        savedir,
        checkpoint_tmpl='checkpoint_{0}_{1}.pth',
        fired=lambda pg: True)
    stsaver = trainlib.StatSaver(statdir,
                                 statname_tmpl='stats_{0}_{1}.npz',
                                 fired=lambda pg: True)
    alpha = 0.9  # resistance (smoothing factor) of the moving-average estimate of the mean loss
    optimizer = optim.Adam(cae.parameters(), lr=lr)

    for epoch in range(max_epoch):
        for stage, dataset in [('train', trainset), ('eval', testset)]:
            swsam = more_sampler.SlidingWindowBatchSampler(dataset,
                                                           9,
                                                           shuffled=True,
                                                           batch_size=8)
            dataloader = DataLoader(vdset, batch_sampler=swsam)
            moving_average = None
            getattr(cae, stage)()  # cae.train() or cae.eval()
            torch.set_grad_enabled(stage == 'train')
            for j, inputs in enumerate(dataloader):
                progress = epoch, j
                inputs = more_trans.rearrange_temporal_batch(inputs, 9)
                inputs, targets = (inputs[:, :, :-1, :, :],
                                   inputs[:, :, -1:, :, :])
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = cae(inputs)
                loss = mse(outputs, targets)

                if stage == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                loss_val = loss.item()
                if stage == 'train':
                    moving_average = loss_val if moving_average is None else \
                        alpha * moving_average + (1 - alpha) * loss_val
                    cpsaver(progress)
                    stsaver(progress, loss=loss_val)
                logger.info('[epoch{}/batch{}] loss={:.2f}'.format(
                    epoch, j, loss_val))
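
The attention-free variant is invoked the same way, minus the two penalty weights; again the directories are hypothetical:

train_pred9_f1to8_no_attn(
    vdset,
    trainset=range(0, 8000),
    testset=range(8000, 10000),
    savedir='runs/pred9_no_attn/checkpoints',
    statdir='runs/pred9_no_attn/stats',
    device='cuda:0',             # or 'cpu'
    max_epoch=5,
    lr=1e-3)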