def cleanup_unused_mmapfiles(self):
    """
    Remove decompressed batch files under ``self.root_tmp`` that are not
    tracked by ``self.gz_cache``, while the directory holds more than
    ``self.max_gzcache`` entries.

    Each candidate file is handled while holding its batch's entry in
    ``self.access_locks``, so no two holders of the same lock remove the
    same file at the same time.
    """
    logger = logging.getLogger(
        _l(__name__, self, 'cleanup_unused_mmapfiles'))
    for filename in os.listdir(self.root_tmp):
        matched = DATABATCH_FILENAME_PAT.match(filename)
        if not matched:
            continue
        batch_id = int(matched.group(1))
        with self.access_locks[batch_id]:
            batchf = os.path.join(self.root_tmp, filename)
            # mmap_cache and gz_cache are filled together in __getitem__ and
            # len(gz_cache) >= len(mmap_cache); the intent is that every
            # memory-mapped batch is also tracked in gz_cache, so files
            # absent from gz_cache are candidates for removal.
            # The directory is re-listed on every check because earlier
            # iterations (or other processes) may have removed files.
            if batchf not in self.gz_cache and \
                    len(os.listdir(self.root_tmp)) > self.max_gzcache:
                try:
                    os.remove(batchf)
                except FileNotFoundError:
                    # due to concurrency, the file may have already been
                    # removed; due to the lock, however, no process will
                    # try to remove a file when another process is
                    # removing exactly the same file
                    pass
                except OSError as err:
                    # removal is best-effort, but unexpected failures (e.g.
                    # permissions, file still in use) should not be silent
                    logger.warning(
                        'Failed to remove decompressed batch "{}": {}'.format(
                            batchf, err))
                else:
                    logger.info(
                        'Decompressed batch "{}" removed'.format(batchf))
def release_mmap(self):
    """
    Drop every memory-mapped batch held in ``self.mmap_cache``.

    The cache ends up empty; entries are deleted one by one from a snapshot
    of the current keys so the mapping is not mutated while iterated.
    """
    logger = logging.getLogger(_l(__name__, self, 'release_mmap'))
    cached_batch_ids = tuple(self.mmap_cache)
    for cached_id in cached_batch_ids:
        del self.mmap_cache[cached_id]
    logger.info('All mmap released')
def cleanup_all_mmapfiles(self):
    """
    Delete every decompressed batch file by recreating ``self.root_tmp`` as
    an empty directory.

    Be sure to call this function only if there's no opened memory-mapped
    file (e.g. after ``release_mmap``). Usually this function is unnecessary
    unless the user wants to save some disk space.
    """
    logger = logging.getLogger(_l(__name__, self, 'cleanup_all_mmapfiles'))
    if os.path.isdir(self.root_tmp):
        shutil.rmtree(self.root_tmp)
    # exist_ok: another process/instance sharing the same root may recreate
    # the directory between the rmtree above and this call
    os.makedirs(self.root_tmp, exist_ok=True)
    # gz_cache tracks which decompressed files exist on disk; none do now
    self.gz_cache.clear()
    logger.info('All decompressed batches removed')
def __init__(self, root, transform=None, max_mmap=1, max_gzcache=3):
    """
    :param root: the root directory of the dataset
    :type root: str
    :param transform: the transformations to be performed on loaded data
    :param max_mmap: the maximum number of memory maps to keep (at least 1)
    :type max_mmap: int
    :param max_gzcache: the maximum number of extracted memory map files to
           keep on disk; the effective value is ``max(max_mmap, max_gzcache)``
    :type max_gzcache: int
    """
    self.root = root
    jsonfile = get_dset_filename_by_ext(root, '.json')
    # JSON text is UTF-8; don't depend on the platform's default encoding
    with open(jsonfile, encoding='utf-8') as infile:
        self.metainfo = json.load(infile)
    batch_lens = self.metainfo['lens']
    n_batches = len(batch_lens)
    self.total_frames = np.sum(batch_lens)
    # lens_cumsum[i] is the number of frames in batches 0..i (inclusive)
    self.lens_cumsum = list(accumulate(batch_lens))
    shape = self.metainfo['resolution'] + [self.metainfo['channels']]
    self.frame_shape = tuple(shape)
    # validated_batches[i] is False until batch i's decompressed file has
    # passed the integrity check
    self.validated_batches = [False] * n_batches
    checksumfile = get_dset_filename_by_ext(root, '.' + HASH_ALGORITHM)
    self.expected_hexes = parse_checksum_file(checksumfile)
    self.root_tmp = os.path.join(root, 'tmp')
    # exist_ok: several processes/instances may share the same root
    os.makedirs(self.root_tmp, exist_ok=True)
    self.transform = transform

    max_mmap = max(1, max_mmap)
    # keep at least as many decompressed files as memory maps
    self.max_gzcache = max(max_mmap, max_gzcache)
    self.mmap_cache = LRUCache(maxsize=max_mmap)
    self.gz_cache = LRUCache(maxsize=self.max_gzcache)

    # fine granularity lock for each data batch
    lockfile_tmpl = get_dset_filename_by_ext(root, '.access{}.lock')
    # The lock file names are derived from ``root`` rather than from
    # per-instance state, so presumably every process and every instance of
    # this class constructed with the same root shares the same locks.
    self.access_locks = [
        FileLock(lockfile_tmpl.format(bid)) for bid in range(n_batches)
    ]
    logger = logging.getLogger(_l(__name__, self, '__init__'))
    logger.info('Instantiated: root={}'.format(self.root))
def __getitem__(self, frame_id):
    """
    Returns a frame of dimension HWC upon the request of a frame ID.

    The containing data batch is decompressed on demand, integrity-checked
    once per instance, and memory-mapped. Access to a batch is serialized
    by its lock in ``self.access_locks``. Calling this method without
    contiguous or nearly contiguous indices is very slow, since every mmap
    cache miss may decompress a whole batch.

    :param frame_id: the frame index, ``0 <= frame_id < len(self)``
    :return: a copy of the frame in numpy array of dimension HWC, passed
             through ``self.transform`` if it is set
    :rtype: np.ndarray
    :raise IndexError: if ``frame_id`` is out of range
    :raise RuntimeError: if the batch still fails the integrity check after
           being decompressed again once
    """
    logger = logging.getLogger(_l(__name__, self, '__getitem__'))
    if frame_id < 0 or frame_id >= len(self):
        raise IndexError('Invalid index: {}'.format(frame_id))
    batch_id, rel_frame_id = self.locate_batch(frame_id)
    logger.debug('Waiting for lock ID {}'.format(batch_id))
    with self.access_locks[batch_id]:
        if batch_id not in self.mmap_cache:
            batchf = self.batch_filename_by_id(batch_id)

            def _decompress():
                # (re)create batchf from its gzipped counterpart
                extract_gzip(
                    self.batch_filename_by_id(batch_id, gzipped=True),
                    batchf)
                assert os.path.isfile(batchf), \
                    '"{}" not found after decompressed'.format(batchf)

            if not os.path.isfile(batchf):
                logger.info('Decompressing "{}"'.format(batchf))
                _decompress()
            if not self.validated_batches[batch_id]:
                expected_hex = self.expected_hexes[batch_id]
                if not check_file_integrity(batchf, expected_hex):
                    logger.warning(
                        'File integrity failed at "{}"; retrying'.format(
                            batchf))
                    # probably there's error with read last time; attempt
                    # to decompress again for once
                    os.remove(batchf)
                    _decompress()
                    if not check_file_integrity(batchf, expected_hex):
                        logger.error('File integrity failed at "{}"; '
                                     'RuntimeError raised'.format(batchf))
                        raise RuntimeError(
                            'Data batch {} corrupted'.format(batch_id))
                self.validated_batches[batch_id] = True
                logger.info(
                    'File integrity check completed for batch {}'.format(
                        batch_id))
            # till here file "batchf" has been available
            self.gz_cache[batchf] = True
            shape = (self.metainfo['lens'][batch_id], ) + self.frame_shape
            logger.debug('keys before mmap cache adjustment: {}'.format(
                list(self.mmap_cache.keys())))
            self.mmap_cache[batch_id] = np.memmap(
                str(batchf), mode='r', dtype=self.metainfo['dtype'],
                shape=shape)
            logger.debug('keys after mmap cache adjustment: {}'.format(
                list(self.mmap_cache.keys())))
        else:
            # Refresh the batch's recency in gz_cache on a cache hit too.
            # Otherwise gz_cache's LRU order drifts from mmap_cache's and a
            # still-mapped batch could be evicted from gz_cache and then
            # deleted from disk by cleanup_unused_mmapfiles.
            self.gz_cache[self.batch_filename_by_id(batch_id)] = True
        frame = np.copy(self.mmap_cache[batch_id][rel_frame_id])
    if self.transform is not None:
        frame = self.transform(frame)
    self.cleanup_unused_mmapfiles()
    return frame
def _find_normalize_transform(transform):
    """
    Return the first ``trans.Normalize`` found in ``transform``.

    ``transform`` itself is checked first, then its instance attributes, and
    finally the elements of list/tuple attributes (e.g. the transform list
    held by a Compose-like container).

    :raise ValueError: if no ``trans.Normalize`` is found
    """
    if isinstance(transform, trans.Normalize):
        return transform
    attrs = list(getattr(transform, '__dict__', {}).values())
    for attr in attrs:
        if isinstance(attr, trans.Normalize):
            return attr
    for attr in attrs:
        if isinstance(attr, (list, tuple)):
            for item in attr:
                if isinstance(item, trans.Normalize):
                    return item
    raise ValueError(
        'No Normalize transform found in {!r}'.format(transform))


def train_pred9_f1to8(vdset: vmdata.VideoDataset,
                      trainset: Sequence[int],
                      testset: Sequence[int],
                      savedir: str,
                      statdir: str,
                      device: Union[str, torch.device] = 'cpu',
                      max_epoch: int = 1,
                      lr: float = 0.001,
                      lam_dark: float = 1.0,
                      lam_nrgd: float = 0.2):
    """
    Train an attention CAE (``basicmodels.EzFirstCAE`` over the pred9_f1to8
    encoder/decoder) to predict the last frame of each sliding window of
    ``1 + pred9_f1to8.temporal_batch_size`` frames from the frames before it.

    Each epoch runs a training stage over ``trainset`` followed by an
    evaluation stage over ``testset``. The loss is
    ``mse(attn * out, attn * target) + lam_dark * dark(attn)
    + lam_nrgd * nonrigid(attn)``.

    :param vdset: the video dataset; its ``transform`` must be, or contain,
           a ``trans.Normalize`` (needed by the darkness penalty)
    :param trainset: frame indices used for training
    :param testset: frame indices used for evaluation
    :param savedir: checkpoint directory; checkpoints are saved after every
           training batch using template ``checkpoint_{0}_{1}.pth`` with
           progress ``(epoch, batch)``
    :param statdir: stats directory; training stats use template
           ``stats_{0}_{1}.npz`` and evaluation stats ``stats_eval_{0}_{1}.npz``
    :param device: the torch device to train on
    :param max_epoch: number of epochs
    :param lr: Adam learning rate
    :param lam_dark: weight of the darkness penalty
    :param lam_nrgd: weight of the non-rigid penalty
    :raise ValueError: if no ``trans.Normalize`` is found in ``vdset.transform``

    The global autograd mode is left unchanged on return.
    """
    logger = logging.getLogger(_l(__name__, 'train_pred9_f1to8'))
    if isinstance(device, str):
        device = torch.device(device)
    encoder = pred9_f1to8.STCAEEncoder()
    decoder = pred9_f1to8.STCAEDecoder()
    # NOTE(review): the attention branch is also an STCAEDecoder; confirm
    # this reuse of the decoder architecture is intended.
    attention = pred9_f1to8.STCAEDecoder()
    normalize = _find_normalize_transform(vdset.transform)
    ezcae = basicmodels.EzFirstCAE(encoder, decoder, attention).to(device)
    mse = nn.MSELoss().to(device)
    darkp = basicmodels.DarknessPenalty(normalize).to(device)
    nrgdp = basicmodels.NonrigidPenalty().to(device)

    def criterion(_outputs: torch.Tensor, _attns: torch.Tensor,
                  _targets: torch.Tensor) -> Tuple[torch.Tensor, np.ndarray]:
        # returns the weighted total loss and its three unweighted parts
        loss1 = mse(_attns * _outputs, _attns * _targets)
        loss2 = darkp(_attns)
        loss3 = nrgdp(_attns.view(-1, 1, *_attns.shape[-2:]))
        _loss = loss1 + lam_dark * loss2 + lam_nrgd * loss3
        _loss123 = np.array(
            [loss1.item(), loss2.item(), loss3.item()], dtype=np.float64)
        return _loss, _loss123

    cpsaver = trainlib.CheckpointSaver(
        ezcae, savedir, checkpoint_tmpl='checkpoint_{0}_{1}.pth',
        fired=lambda pg: True)
    stsaver = trainlib.StatSaver(statdir, statname_tmpl='stats_{0}_{1}.npz',
                                 fired=lambda pg: True)
    # The eval stage restarts the batch counter at 0 each epoch; a separate
    # template keeps it from overwriting the training stats of the same
    # (epoch, batch).
    eval_stsaver = trainlib.StatSaver(
        statdir, statname_tmpl='stats_eval_{0}_{1}.npz',
        fired=lambda pg: True)
    optimizer = optim.Adam(ezcae.parameters(), lr=lr)

    window = 1 + pred9_f1to8.temporal_batch_size
    stat_names = ['loss', 'loss_mse', 'loss_dark', 'loss_nrgd']
    stat_fmt = ' '.join('{}={{:.2f}}'.format(n) for n in stat_names)
    for epoch in range(max_epoch):
        for stage, dataset in [('train', trainset), ('eval', testset)]:
            is_train = stage == 'train'
            swsam = more_sampler.SlidingWindowBatchSampler(
                dataset, window, shuffled=True, batch_size=8)
            dataloader = DataLoader(vdset, batch_sampler=swsam)
            getattr(ezcae, stage)()  # ezcae.train() or ezcae.eval()
            # scoped: restores the previous grad mode after the stage
            with torch.set_grad_enabled(is_train):
                for j, inputs in enumerate(dataloader):
                    progress = epoch, j
                    inputs = more_trans.rearrange_temporal_batch(
                        inputs, window)
                    # all but the last frame are inputs; the last is the target
                    inputs, targets = (inputs[:, :, :-1, :, :],
                                       inputs[:, :, -1:, :, :])
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs, attns = ezcae(inputs)
                    loss, loss123 = criterion(outputs, attns, targets)
                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        # weights only change during training
                        cpsaver(progress)
                    stat_vals = [loss.item()] + list(loss123)
                    saver = stsaver if is_train else eval_stsaver
                    saver(progress, **dict(zip(stat_names, stat_vals)))
                    logger.info(('[epoch{}/batch{}] '.format(epoch, j)
                                 + stat_fmt).format(*stat_vals))
def train_pred9_f1to8_no_attn(vdset: vmdata.VideoDataset,
                              trainset: Sequence[int],
                              testset: Sequence[int],
                              savedir: str,
                              statdir: str,
                              device: Union[str, torch.device] = 'cpu',
                              max_epoch: int = 1,
                              lr: float = 0.001):
    """
    Train a plain CAE (``basicmodels.CAE`` over the pred9_f1to8
    encoder/decoder, no attention) with MSE loss to predict the last frame
    of each sliding window from the frames before it.

    Each epoch runs a training stage over ``trainset`` followed by an
    evaluation stage over ``testset``.

    :param vdset: the video dataset
    :param trainset: frame indices used for training
    :param testset: frame indices used for evaluation
    :param savedir: checkpoint directory; checkpoints are saved after every
           training batch using template ``checkpoint_{0}_{1}.pth`` with
           progress ``(epoch, batch)``
    :param statdir: stats directory; training stats use template
           ``stats_{0}_{1}.npz`` and evaluation stats ``stats_eval_{0}_{1}.npz``
    :param device: the torch device to train on
    :param max_epoch: number of epochs
    :param lr: Adam learning rate

    The global autograd mode is left unchanged on return.
    """
    logger = logging.getLogger(_l(__name__, 'train_pred9_f1to8_no_attn'))
    if isinstance(device, str):
        device = torch.device(device)
    encoder = pred9_f1to8.STCAEEncoder()
    decoder = pred9_f1to8.STCAEDecoder()
    cae = basicmodels.CAE(encoder, decoder).to(device)
    mse = nn.MSELoss().to(device)
    cpsaver = trainlib.CheckpointSaver(
        cae, savedir, checkpoint_tmpl='checkpoint_{0}_{1}.pth',
        fired=lambda pg: True)
    stsaver = trainlib.StatSaver(statdir, statname_tmpl='stats_{0}_{1}.npz',
                                 fired=lambda pg: True)
    # The eval stage restarts the batch counter at 0 each epoch; a separate
    # template keeps it from overwriting the training stats of the same
    # (epoch, batch).
    eval_stsaver = trainlib.StatSaver(
        statdir, statname_tmpl='stats_eval_{0}_{1}.npz',
        fired=lambda pg: True)
    optimizer = optim.Adam(cae.parameters(), lr=lr)

    # same window length as train_pred9_f1to8: the model's temporal batch
    # plus one target frame (previously hard-coded as 9)
    window = 1 + pred9_f1to8.temporal_batch_size
    for epoch in range(max_epoch):
        for stage, dataset in [('train', trainset), ('eval', testset)]:
            is_train = stage == 'train'
            swsam = more_sampler.SlidingWindowBatchSampler(
                dataset, window, shuffled=True, batch_size=8)
            dataloader = DataLoader(vdset, batch_sampler=swsam)
            getattr(cae, stage)()  # cae.train() or cae.eval()
            # scoped: restores the previous grad mode after the stage
            with torch.set_grad_enabled(is_train):
                for j, inputs in enumerate(dataloader):
                    progress = epoch, j
                    inputs = more_trans.rearrange_temporal_batch(
                        inputs, window)
                    # all but the last frame are inputs; the last is the target
                    inputs, targets = (inputs[:, :, :-1, :, :],
                                       inputs[:, :, -1:, :, :])
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = cae(inputs)
                    loss = mse(outputs, targets)
                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        # weights only change during training
                        cpsaver(progress)
                    loss_val = loss.item()
                    saver = stsaver if is_train else eval_stsaver
                    saver(progress, loss=loss_val)
                    logger.info('[epoch{}/batch{}] loss={:.2f}'.format(
                        epoch, j, loss_val))