Example 1
    def _log_device_tuning(self, global_step):
        if not ivy.exists(self._writer):
            raise Exception('torch must be installed in order to use the file writer for tensorboard logging.')
        if not ivy.exists(self._dev_manager):
            raise Exception('Cannot log device manager tuning if the device manager does not exist. '
                            'Please set either tune_device_allocation=True or tune_splitting=True.')

        # device allocation
        # ToDo: log more useful tuning metrics here
        if self._multi_dev and self._spec.tune_device_allocation:
            self._writer.add_scalar('dev_tuning/device_alloc/tune_count',
                                    self._dev_manager._da_tune_count, global_step)
            self._writer.add_scalar('dev_tuning/device_alloc/unit_tune_count',
                                    self._dev_manager._unit_da_tune_count, global_step)
            self._writer.add_scalar('dev_tuning/device_alloc/step_time',
                                    self._dev_manager._da_step_time, global_step)
            for ds, split in self._dev_manager._dev_strs_da.items():
                self._writer.add_scalar('dev_tuning/device_alloc/split_sizes/{}'.format(ds), split, global_step)

        # per-device splitting
        # ToDo: log more useful tuning metrics here
        if self._spec.tune_splitting:
            self._writer.add_scalar('dev_tuning/splitting/tune_count',
                                    self._dev_manager._ds_tune_count, global_step)
            self._writer.add_scalar('dev_tuning/splitting/step_time',
                                    self._dev_manager._ds_step_time, global_step)
            for ds, split in self._dev_manager._dev_strs_ds.items():
                self._writer.add_scalar('dev_tuning/splitting/split_factors/{}'.format(ds), split, global_step)
Example 2
 def _dev_manager_execute_with_grads(self, network, batch):
     # ToDo: assign this function in constructor rather than performing checks on each training step
     dev_manager_exists = ivy.exists(self._dev_manager)
     tuned = not dev_manager_exists or self._dev_manager.tuned
     if self._compile_network_once_tuned and tuned:
         network.compile_on_next_step()
         self._compile_network_once_tuned = False
     if self._compile_optimizer_once_tuned and tuned:
         self._optimizer.compile_on_next_step()
         self._compile_optimizer_once_tuned = False
     if dev_manager_exists:
         if self._multi_dev:
             if not isinstance(batch, ivy.MultiDevContainer):
                 batch = batch.to_multi_dev(self._spec.dev_strs)
             return self._dev_manager.map(distributed={"batch": batch.at_devs()},
                                          to_clone={"network_v": network.v})
         ret = None
         oom = False
         while ret is None:
             try:
                 ret = self._split_execute_with_grads(network, self._spec.dev_strs[0], batch, network.v)
             except RuntimeError as e:
                 if oom:
                     raise Exception('Out of Memory Error raised twice consecutively: {}'.format(e))
                 oom = True
             self._dev_manager.tune_step(oom)
         return ret
     return self._split_execute_with_grads(network, self._spec.dev_strs[0], batch, network.v)
Example 3
    def __init__(self, dataset_spec, batch_size, starting_idx, num_sequences, window_size=1,
                 num_workers=1, cache_size=0, unused_key_chains=None, custom_init_fn=None,
                 container_load_mode='dynamic', custom_container_load_fn=None, preshuffle_data=True,
                 shuffle_buffer_size=0, with_prefetching=True, queue_timeout=None, post_proc_fn=None,
                 prefetch_to_devs='gpu:0', single_pass=False, array_strs=None, float_strs=None, uint8_strs=None,
                 custom_img_strs=None, custom_img_fns=None, custom_strs=None, custom_fns=None, array_mode='pickled',
                 load_gray_as_rgb=True, containers_to_skip=None, **kwargs):

        kw = locals_to_kwargs(locals())

        unused_key_chains = ivy.default(unused_key_chains, [])
        array_strs = ivy.default(array_strs, [])
        float_strs = ivy.default(float_strs, [])
        uint8_strs = ivy.default(uint8_strs, [])
        custom_img_strs = ivy.default(custom_img_strs, [[]])
        custom_img_fns = ivy.default(custom_img_fns, [])
        custom_strs = ivy.default(custom_strs, [[]])
        custom_fns = ivy.default(custom_fns, [])
        containers_to_skip = ivy.default(containers_to_skip, [])
        prefetch_to_devs = prefetch_to_devs if ivy.gpu_is_available() or isinstance(prefetch_to_devs, list) else False
        assert container_load_mode in ['preload', 'dynamic', 'custom']
        if container_load_mode == 'custom':
            assert ivy.exists(custom_container_load_fn)
        else:
            assert ivy.exists(dataset_spec.cont_fname_template)

        super(SeqDataLoaderSpec, self).__init__(dataset_spec,
                                                batch_size=batch_size,
                                                window_size=window_size,
                                                starting_idx=starting_idx,
                                                num_sequences=num_sequences,
                                                num_workers=num_workers,
                                                cache_size=cache_size,
                                                unused_key_chains=unused_key_chains,
                                                custom_init_fn=custom_init_fn,
                                                container_load_mode=container_load_mode,
                                                custom_container_load_fn=custom_container_load_fn,
                                                preshuffle_data=preshuffle_data,
                                                shuffle_buffer_size=shuffle_buffer_size,
                                                with_prefetching=with_prefetching,
                                                post_proc_fn=post_proc_fn,
                                                prefetch_to_devs=prefetch_to_devs,
                                                single_pass=single_pass,
                                                array_strs=array_strs,
                                                float_strs=float_strs,
                                                uint8_strs=uint8_strs,
                                                custom_img_strs=custom_img_strs,
                                                custom_img_fns=custom_img_fns,
                                                custom_strs=custom_strs,
                                                custom_fns=custom_fns,
                                                array_mode=array_mode,
                                                load_gray_as_rgb=load_gray_as_rgb,
                                                containers_to_skip=containers_to_skip,
                                                **kwargs)
        self.queue_timeout = ivy.default(queue_timeout, ivy.queue_timeout())  # conflicts with ivy.Container argument

        self._kwargs = kw
Example 4
def get_json_args(json_spec_path, keys_to_ignore, keychains_to_ignore, keychain_to_show, defaults=False,
                  store_duplicates=False, current_dir_only=False, spec_names=None):
    spec_names = ivy.default(spec_names,
                             [item.split('.json')[0] for item in os.listdir(json_spec_path) if '.json' in item])
    if defaults:
        defaults = '.defaults'
    else:
        defaults = ''
    cont = ivy.Container()
    if current_dir_only:
        for spec_name in spec_names:
            fpath = os.path.join(json_spec_path, spec_name + '.json' + defaults)
            if os.path.isfile(fpath):
                cont[spec_name] = parse_json_to_cont(fpath)
    else:
        for spec_name in spec_names:
            cont[spec_name] = \
                json_spec_from_fpath(json_spec_path, spec_name + '.json' + defaults, store_duplicates)
    for keychain_to_ignore in keychains_to_ignore:
        if keychain_to_ignore in cont:
            cont[keychain_to_ignore] = 'not_shown'
    cont = cont.set_at_keys(dict(zip(keys_to_ignore, ['not_shown']*len(keys_to_ignore))))
    if ivy.exists(keychain_to_show):
        cont = cont[keychain_to_show]
    return cont
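
A brief usage sketch of get_json_args follows; the spec directory path, key names, and spec names are illustrative placeholders, not taken from the source.

# Hypothetical call: gather the current (non-default) JSON args for two spec
# files in a given directory, hiding any 'log_dir' keys. All values illustrative.
json_args = get_json_args('/path/to/json_specs',
                          keys_to_ignore=['log_dir'],
                          keychains_to_ignore=[],
                          keychain_to_show=None,
                          defaults=False,
                          store_duplicates=True,
                          current_dir_only=True,
                          spec_names=['trainer', 'data_loader'])
print(json_args)
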
Example 5
 def close(self) -> None:
     """
     Close this trainer, and destroy all child objects or processes which may not be garbage collected.
     """
     if ivy.exists(self._dev_manager):
         self._dev_manager.__del__()
     self._spec.data_loader.close()
Example 6
 def step(self):
     total_iterations = min(
         self._trainer_global_step +
         self._train_steps_per_tune_step,
         self._trainer_total_iterations)
     self._trainer_global_step = self._trainer.train(
         self._trainer_global_step, total_iterations)
     self.timestep += 1
     ret_dict = {
         'timestep': self.timestep,
         'cost': ivy.to_numpy(self._trainer.moving_average_loss)
     }
     if self._trainer_global_step >= self._trainer_total_iterations:
         if self._save_at_end:
             self._trainer._save()
         if self._log_at_end and ivy.exists(
                 self._trainer._training_batch):
             self._trainer._log_scalars()
         if self._vis_at_end:
             dl = self._trainer.spec.data_loader
             net = self._trainer.spec.network
             tb = self._trainer._training_batch
             gs = self._trainer._global_step
             self._trainer._write_image_summaries(dl, net, tb, gs)
         ret_dict[tune.result.DONE] = True
     return ret_dict
Example 7
 def _log_device_utilization(self, global_step):
     if not ivy.exists(self._writer):
         raise Exception('torch must be installed in order to use the file writer for tensorboard logging.')
     self._writer.add_scalar('dev_util/CPU', ivy.dev_util('cpu'), global_step)
     for ds in self._spec.dev_strs:
         if 'gpu' not in ds:
             continue
         ds_formatted = ds.replace(':', '_').capitalize()
         self._writer.add_scalar('dev_util/{}'.format(ds_formatted), ivy.dev_util(ds), global_step)
Example 8
 def _log_scalars(self):
     if ivy.exists(self._writer):
         if self._spec.log_time:
             self._writer.add_scalar('time between logs', time.perf_counter() - self._start_time, self._global_step)
         if self._spec.log_learning_rate:
             self._writer.add_scalar('learning rate', self._learning_rate, self._global_step)
     self._write_scalar_summaries(self._spec.data_loader, self._network, self._training_batch,
                                  self._global_step)
     self._start_time = time.perf_counter()
Example 9
 def _update_seq_info_for_window(self, seq_info):
     if not ivy.exists(seq_info):
         return
     seq_idx = int(seq_info.seq_idx[0])
     seq_len = int(seq_info.length[0])
     new_len = self._compute_seq_len(seq_idx, seq_len,
                                     self._spec.containers_to_skip)
     seq_info = seq_info.copy()
     seq_info.length = ivy.ones_like(seq_info.length) * new_len
     return seq_info
Example 10
 def restore(self, checkpoint_path):
     checkpoint = ivy.Container.from_disk_as_hdf5(checkpoint_path)
     loaded_v = checkpoint.network.map(
         lambda x, kc: ivy.variable(ivy.to_dev(x, self._net._dev_str)))
     if ivy.exists(self._net.v):
         # if build_mode is 'on_call', the network variables will not have been built yet
         assert (self._net.v.shapes == loaded_v.shapes).all_true(
             assert_is_bool=True)
     self._net.v = loaded_v
     self._optimizer.set_state(
         checkpoint.optimizer.map(
             lambda x, kc: ivy.to_dev(x, self._net.spec.dev_strs[0])))
Example 11
 def _log_memory(self, global_step):
     if not ivy.exists(self._writer):
         raise Exception('torch must be installed in order to use the file writer for tensorboard logging.')
     self._writer.add_scalar('memory/RAM/global/percent_used', ivy.percent_used_mem_on_dev('cpu'), global_step)
     self._writer.add_scalar('memory/RAM/local/percent_used',
                             ivy.percent_used_mem_on_dev('cpu', process_specific=True), global_step)
     for ds in self._spec.dev_strs:
         if 'gpu' not in ds:
             continue
         ds_formatted = ds.replace(':', '_').capitalize()
         self._writer.add_scalar('memory/{}/global/percent_used'.format(ds_formatted),
                                 ivy.percent_used_mem_on_dev(ds), global_step)
Example 12
 def _initialize_model(self, checkpoint_path=None):
     self._pre_init()
     if self._net_spec.build_mode == 'explicit':
         self._network.build()
     first_batch = self._spec.data_loader.get_first_batch()
     if ivy.exists(self._dev_manager):
         self._dev_manager.dim_size = first_batch.shape[0]
     # for on_call builds
     self._compute_cost(self._network, first_batch[0:1], self._spec.dev_strs[0])
     # compile
     if self._spec.compile_graph:
         valid_modes = ['network', 'optimizer', 'all']
         assert self._spec.compile_graph in valid_modes, \
             'invalid value for compile_graph, must be one of {}'.format(valid_modes)
         if self._spec.compile_graph in ['network', 'all']:
             self._compile_network_once_tuned = True
         if self._spec.compile_graph in ['optimizer', 'all']:
             self._compile_optimizer_once_tuned = True
     if self._spec.save_spec:
         self._save_spec_to_disk()
     self._save_info_to_disk()
     self._init_checkpoint_manager()
     if not checkpoint_path:
         checkpoint_path = self._chkpt_manager.latest_checkpoint_fpath
     if self._spec.ld_chkpt is True and not ivy.exists(checkpoint_path):
         raise Exception('Unable to load checkpoint, no checkpoint files found.')
     elif self._spec.ld_chkpt in [True, 'try'] and ivy.exists(checkpoint_path):
         self._chkpt.restore(checkpoint_path)
         logging.info('loaded checkpoints from {}'.format(checkpoint_path))
         starting_iteration = int(checkpoint_path.split('-')[-1].split('.')[0])
         logging.info('#--------------#\n# MODEL LOADED #\n#--------------#')
         self._post_init()
         if ivy.exists(self._spec.starting_iteration):
             assert starting_iteration == self._spec.starting_iteration
         return starting_iteration
     else:
         logging.info('#-------------#\n# MODEL BUILT #\n#-------------#')
     self._global_step = self._spec.starting_iteration
     self._post_init()
     return ivy.default(self._spec.starting_iteration, 0)
Example 13
 def _float_img_fn(self, filepaths_in_window):
     imgs = list()
     for filepath in filepaths_in_window:
         str_path = bytearray(ivy.to_numpy(filepath).tolist()).decode()
         full_path = os.path.abspath(
             os.path.join(self._container_data_dir, str_path))
         if not ivy.exists(cv2):
             raise Exception(
                 'in order to use _float_img_fn, opencv for python must be installed. '
                 'To install opencv, run pip install opencv-python.')
         img_rgba = cv2.imread(full_path, -1)
         img = ivy.array(
             np.frombuffer(img_rgba.tobytes(),
                           np.float32).reshape((1, ) + img_rgba.shape[:-1]))
         imgs.append(img)
     return ivy.concatenate(imgs, 0)
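
A minimal numpy sketch of the byte-level convention this loader relies on: a single-channel float32 image stored as a 4-channel uint8 (RGBA) image can be recovered by reinterpreting the raw bytes as float32. The shapes and data below are illustrative assumptions.

import numpy as np

# Illustrative float32 image, e.g. a depth map (values and shape are arbitrary).
float_img = np.random.uniform(size=(480, 640)).astype(np.float32)

# Encoded form: the same bytes viewed as an (H, W, 4) uint8 RGBA image.
img_rgba = np.frombuffer(float_img.tobytes(), np.uint8).reshape(480, 640, 4)

# Decoding, mirroring the reshape in _float_img_fn: back to a (1, H, W) float32 array.
recovered = np.frombuffer(img_rgba.tobytes(), np.float32).reshape((1,) + img_rgba.shape[:-1])
assert np.array_equal(recovered[0], float_img)
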
Example 14
def build_data_loader(data_loader_class=None,
                      dataset_dirs_args=None,
                      dataset_dirs_class=None,
                      dataset_dirs=None,
                      dataset_spec_args=None,
                      dataset_spec_class=None,
                      dataset_spec=None,
                      data_loader_spec_args=None,
                      data_loader_spec_class=None,
                      data_loader_spec=None,
                      json_spec_path=None,
                      spec_cont=None,
                      class_priority=False):
    """
    build data loader
    """

    # build data loader specification
    data_loader_spec = ivy.default(
        data_loader_spec,
        build_data_loader_spec(
            dataset_dirs_args=dataset_dirs_args,
            dataset_dirs_class=dataset_dirs_class,
            dataset_dirs=dataset_dirs,
            dataset_spec_args=dataset_spec_args,
            dataset_spec_class=dataset_spec_class,
            dataset_spec=dataset_spec,
            data_loader_spec_args=data_loader_spec_args,
            data_loader_spec_class=data_loader_spec_class,
            json_spec_path=json_spec_path,
            spec_cont=spec_cont))

    # override data_loader_class if specified in data_loader_spec
    data_loader_class = ivy.default(ivy.default(
        _import_arg_specified_class_if_present(data_loader_spec, 'data_loader_class'),
        data_loader_class, rev=class_priority),
        None)

    # verify data_loader_class exists
    if not ivy.exists(data_loader_class):
        raise Exception('data_loader_class must either be specified in this build_data_loader() method, '
                        'or the data_loader_class attribute must be specified in the data_loader_spec instance.')

    # return data loader
    return data_loader_class(data_loader_spec)
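
As a rough usage sketch (the argument values below are placeholders, and SeqDataLoader is the loader class whose methods appear in the later examples):

# Hypothetical sketch: build a loader from an explicit class plus spec arguments.
# The directory, spec values and argument choices are illustrative assumptions.
loader = build_data_loader(
    data_loader_class=SeqDataLoader,
    dataset_dirs_args={'dataset_dir': '/path/to/dataset'},
    dataset_spec_args={'sequence_lengths': 100},
    data_loader_spec_args={'batch_size': 4, 'window_size': 2,
                           'starting_idx': 0, 'num_sequences': -1})
first_batch = loader.get_first_batch()
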
Example 15
def build_network(network_class=None,
                  dataset_dirs_args=None,
                  dataset_dirs_class=None,
                  dataset_dirs=None,
                  dataset_spec_args=None,
                  dataset_spec_class=None,
                  dataset_spec=None,
                  network_spec_args=None,
                  network_spec_class=None,
                  network_spec=None,
                  json_spec_path=None,
                  spec_cont=None,
                  class_priority=False):
    """
    build network
    """

    # build network specification
    network_spec = ivy.default(
        network_spec,
        build_network_specification(
            dataset_dirs_args=dataset_dirs_args,
            dataset_dirs_class=dataset_dirs_class,
            dataset_dirs=dataset_dirs,
            dataset_spec_args=dataset_spec_args,
            dataset_spec_class=dataset_spec_class,
            dataset_spec=dataset_spec,
            network_spec_args=network_spec_args,
            network_spec_class=network_spec_class,
            json_spec_path=json_spec_path,
            spec_cont=spec_cont))

    # override network_class if specified in network_spec
    network_class = ivy.default(ivy.default(
        _import_arg_specified_class_if_present(network_spec, 'network_class'),
        network_class, rev=class_priority),
        None)

    # verify network_class exists
    if not ivy.exists(network_class):
        raise Exception('network_class must either be specified in this build_network() method, '
                        'or the network_class attribute must be specified in the network_spec instance.')

    # return network
    return network_class(network_spec)
Example 16
 def _save_info_to_disk(self):
     info_dir = os.path.join(self._spec.log_dir, 'info')
     os.makedirs(info_dir, exist_ok=True)
     info_filepath = _get_valid_filepath(info_dir, 'info', '.txt')
     if not ivy.exists(git):
         logging.warning('no gitpython installation found, not saving git commit hash to disk. '
                         'To install gitpython, run pip install gitpython.')
         return
     try:
         repo = git.Repo(search_parent_directories=True)
         sha = repo.head.object.hexsha
     except (git.exc.InvalidGitRepositoryError, ValueError):
         sha = 'NOT A GIT REPO'
     with open(info_filepath, 'w+') as info_file:
         info_file.writelines(['time of execution:\n',
                               str(datetime.now()) + '\n\n',
                               'git commit hash at time of execution:\n',
                               sha + '\n'])
Example 17
 def to_filepaths(self):
     if not ivy.exists(self._fpath_template):
         raise Exception(
             'to_filepaths method is not valid if fpath_template has not been specified '
             'in the constructor.')
     seq_idxs = self._seq_idxs.values()
     sizes = [
         self._raw_sizes
         if self._constant_size else self._raw_sizes[seq_idx]
         for seq_idx in seq_idxs
     ]
     rets = [[
         self._fpath_template % (seq_idx, win_idx)
         for win_idx in range(size) if not SeqDataLoader._skip_cont(
             seq_idx, win_idx, self._conts_to_skip)
     ] for seq_idx, size in zip(seq_idxs, sizes)]
     return [
         r + [''] * (self._max_seq_len - len(r)) for r in rets
         if ''.join(r) != ''
     ]
Example 18
def prune_checkpoints_in_dir(chkpts_dir, cutoff, last_only, remove_all):
    print('pruning checkpoints in {}'.format(chkpts_dir))
    checkpoint_fnames = os.listdir(chkpts_dir)
    if len(checkpoint_fnames) == 0:
        print('No checkpoints found in {}'.format(chkpts_dir))
        return
    checkpoint_fnames.sort(key=lambda x: int(x.split('-')[-1].split('.')[0]))
    if last_only:
        for cfn in checkpoint_fnames[:-1]:
            os.remove(os.path.join(chkpts_dir, cfn))
        return
    if remove_all:
        for cfn in checkpoint_fnames:
            os.remove(os.path.join(chkpts_dir, cfn))
        return
    for checkpoint_fname in checkpoint_fnames:
        checkpoint_val = int(checkpoint_fname.split('-')[-1].split('.')[0])
        if ivy.exists(cutoff) and checkpoint_val > cutoff:
            os.remove(os.path.join(chkpts_dir, checkpoint_fname))
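
Two illustrative calls (the directory path and step cutoff are placeholders):

# Hypothetical call: delete every checkpoint whose step number exceeds 50000.
prune_checkpoints_in_dir('/path/to/log/chkpts', cutoff=50000,
                         last_only=False, remove_all=False)

# Hypothetical call: keep only the most recent checkpoint, deleting the rest.
prune_checkpoints_in_dir('/path/to/log/chkpts', cutoff=None,
                         last_only=True, remove_all=False)
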
Example 19
 def _custom_img_fn(self, filepaths_in_window, fn):
     imgs = list()
     for filepath in filepaths_in_window:
         str_path = bytearray(ivy.to_numpy(filepath).tolist()).decode()
         full_path = os.path.abspath(
             os.path.join(self._container_data_dir, str_path))
         if not ivy.exists(cv2):
             raise Exception(
                 'in order to use _custom_img_fn, opencv for python must be installed. '
                 'To install opencv, run pip install opencv-python.')
         img_raw = cv2.imread(full_path, -1)
         img = fn(img_raw)
         imgs.append(img)
     img0 = imgs[0]
     if isinstance(img0, ivy.Container):
         return ivy.Container.concat(imgs, 0)
     elif ivy.is_array(img0):
         return ivy.concatenate(imgs, 0)
     else:
         raise Exception(
             'custom image functions should either return an array or an ivy.Container instance, '
             'but found {} of type {}'.format(img0, type(img0)))
Example 20
 def _uint8_img_fn(self, filepaths_in_window):
     imgs = list()
     for filepath in filepaths_in_window:
         str_path = bytearray(ivy.to_numpy(filepath).tolist()).decode()
         full_path = os.path.abspath(
             os.path.join(self._container_data_dir, str_path))
         if not ivy.exists(cv2):
             raise Exception(
                 'in order to use _uint8_img_fn, opencv for python must be installed. '
                 'To install opencv, run pip install opencv-python.')
         img_rgb = cv2.imread(full_path, -1)
         if len(img_rgb.shape) == 2:
             if not self._spec.load_gray_as_rgb:
                 raise Exception(
                     'Found an image with shape {}, but load_gray_as_rgb is set to False. '
                     'Set this to True in order to tile grayscale images to RGB.'
                     .format(img_rgb.shape))
             img_rgb = np.tile(np.expand_dims(img_rgb, -1), (1, 1, 3))
         img = ivy.array(np.expand_dims(img_rgb.astype(np.float32),
                                        0)) / 255
         imgs.append(img)
     return ivy.concatenate(imgs, 0)
Example 21
 def _log_nested(self, nest, global_step, name_hierarchy, spec):
     if not ivy.exists(self._writer):
         raise Exception('torch must be installed in order to use the file writer for tensorboard logging.')
     if 'global_vector_norm' in spec:
         self._writer.add_scalar(name_hierarchy + '/global vector norm',
                                 ivy.to_scalar(ivy.to_native(nest.vector_norm(global_norm=True))), global_step)
     for k, v in nest.items():
         new_name_hierarchy = name_hierarchy + '/' + k
         if isinstance(v, dict):
             self._log_nested(v, global_step, new_name_hierarchy, spec)
         else:
             if 'mean' in spec:
                 self._writer.add_scalar(new_name_hierarchy + '/mean',
                                         ivy.to_scalar(ivy.to_native(ivy.reduce_mean(v))), global_step)
             if 'abs_mean' in spec:
                 self._writer.add_scalar(new_name_hierarchy + '/abs mean',
                                         ivy.to_scalar(ivy.to_native(ivy.reduce_mean(ivy.abs(v)))), global_step)
             if 'var' in spec:
                 self._writer.add_scalar(new_name_hierarchy + '/var',
                                         ivy.to_scalar(ivy.to_native(ivy.reduce_var(v))), global_step)
             if 'abs_var' in spec:
                 self._writer.add_scalar(new_name_hierarchy + '/abs var',
                                         ivy.to_scalar(ivy.to_native(ivy.reduce_var(ivy.abs(v)))), global_step)
             if 'min' in spec:
                 self._writer.add_scalar(new_name_hierarchy + '/min',
                                         ivy.to_scalar(ivy.to_native(ivy.reduce_min(v))), global_step)
             if 'abs_min' in spec:
                 self._writer.add_scalar(new_name_hierarchy + '/abs min',
                                         ivy.to_scalar(ivy.to_native(ivy.reduce_min(ivy.abs(v)))), global_step)
             if 'max' in spec:
                 self._writer.add_scalar(new_name_hierarchy + '/max',
                                         ivy.to_scalar(ivy.to_native(ivy.reduce_max(v))), global_step)
             if 'abs_max' in spec:
                 self._writer.add_scalar(new_name_hierarchy + '/abs max',
                                         ivy.to_scalar(ivy.to_native(ivy.reduce_max(ivy.abs(v)))), global_step)
             if 'vector_norm' in spec:
                 self._writer.add_scalar(new_name_hierarchy + '/vector norm',
                                         ivy.to_scalar(ivy.to_native(ivy.vector_norm(v))), global_step)
Example 22
def print_json_args(base_dir=None, keys_to_ignore=None, keychains_to_ignore=None):
    if not ivy.exists(base_dir):
        base_dir = os.getcwd()
    ivy.set_framework('numpy')
    parser = argparse.ArgumentParser()
    parser.add_argument('-sd', '--sub_directory', type=str,
                        help='A sub-directory to print the json args for, default is base_dir passed in.')
    parser.add_argument('-dd', '--diff_directory', type=str,
                        help='The directory from which to compare the difference in specifications.')
    parser.add_argument('-kti', '--keys_to_ignore', type=str, default=keys_to_ignore,
                        help='Keys to ignore when printing the specification.')
    parser.add_argument('-kcti', '--keychains_to_ignore', type=str, default=keychains_to_ignore,
                        help='Key-chains to ignore when printing the specification.')
    parser.add_argument('-kcts', '--keychain_to_show', type=str,
                        help='The key-chain to show. Default is None, in which case all key-chains are shown.')
    parser.add_argument('-sn', '--spec_names', type=str,
                        help='The specification names for the json files. Default is the ivy_builder default of '
                             '[ dataset_dirs | dataset | data_loader | network | trainer ].')
    parser.add_argument('-d', '--show_defaults', action='store_true',
                        help='Whether to show the default json arguments. '
                             'If unset, the current arguments are shown, not the default values.')
    parser.add_argument('-c', '--current_dir_only', action='store_true',
                        help='Whether to only show the json arguments for the current directory, '
                             'without also searching through parent directories.')
    parser.add_argument('-sdo', '--show_diff_only', action='store_true',
                        help='Whether to only show the difference between the current directory '
                             'and the diff directory.')
    parser.add_argument('-sso', '--show_same_only', action='store_true',
                        help='Whether to only show the same entries between the current directory '
                             'and the diff directory.')
    parsed_args = parser.parse_args()
    if (parsed_args.show_diff_only or parsed_args.show_same_only) and not parsed_args.diff_directory:
        raise Exception('show_diff_only and show_same_only flags are only applicable if diff_directory is set.')
    if parsed_args.show_diff_only and parsed_args.show_same_only:
        raise Exception('show_diff_only and show_same_only cannot both be set, please choose one to set.')
    if ivy.exists(parsed_args.spec_names):
        spec_names = [kc[1:-1] for kc in ''.join(parsed_args.spec_names[1:-1]).split(', ')]
    else:
        spec_names = None
    if ivy.exists(parsed_args.sub_directory):
        sub_dir = os.path.normpath(os.path.join(base_dir, parsed_args.sub_directory))
    else:
        sub_dir = base_dir
    if ivy.exists(parsed_args.keys_to_ignore):
        keys_to_ignore = [kc[1:-1] for kc in ''.join(parsed_args.keys_to_ignore[1:-1]).split(', ')]
    else:
        keys_to_ignore = list()
    if ivy.exists(parsed_args.keychains_to_ignore):
        keychains_to_ignore = [kc[1:-1] for kc in ''.join(parsed_args.keychains_to_ignore[1:-1]).split(',')]
    else:
        keychains_to_ignore = list()
    these_json_args = get_json_args(
        sub_dir, keys_to_ignore, keychains_to_ignore, parsed_args.keychain_to_show, parsed_args.show_defaults,
        store_duplicates=True, current_dir_only=parsed_args.current_dir_only, spec_names=spec_names)
    if ivy.exists(parsed_args.diff_directory):
        other_sub_dir = os.path.normpath(os.path.join(base_dir, parsed_args.diff_directory))
        if other_sub_dir == sub_dir:
            raise Exception('Invalid diff_directory {} selected, it is the same as the sub_directory {}.'.format(
                other_sub_dir, sub_dir))
        other_json_args = get_json_args(
            other_sub_dir, keys_to_ignore, keychains_to_ignore, parsed_args.keychain_to_show, parsed_args.show_defaults,
            store_duplicates=True, current_dir_only=parsed_args.current_dir_only, spec_names=spec_names)
        diff_keys = 'diff'
        for sub_folder, other_sub_folder in zip(sub_dir.split('/'), other_sub_dir.split('/')):
            if sub_folder != other_sub_folder:
                diff_keys = [sub_folder, other_sub_folder]
                break
        if parsed_args.show_diff_only:
            mode = 'diff_only'
        elif parsed_args.show_same_only:
            mode = 'same_only'
        else:
            mode = 'all'
        diff_json_args = ivy.Container.diff(these_json_args, other_json_args, mode=mode, diff_keys=diff_keys)
        keyword_color_dict = {'duplicated': 'magenta'}
        if isinstance(diff_keys, list):
            diff_keys_dict = dict(zip(diff_keys, ['red'] * 2))
            keyword_color_dict = {**keyword_color_dict, **diff_keys_dict}
        print(ivy.Container(diff_json_args, keyword_color_dict=keyword_color_dict))
    else:
        print(ivy.Container(these_json_args, keyword_color_dict={'duplicated': 'magenta'}))
    ivy.unset_framework()
Example 23
    def _get_dataset(self, starting_example, ending_example):
        class ContainerIdxMap:
            def __init__(self,
                         sizes,
                         fpath_template=None,
                         seq_idxs=None,
                         start=None,
                         end=None,
                         max_seq_len=None,
                         conts_to_skip=None,
                         pruned_sizes=None):
                if isinstance(sizes, (tuple, list)):
                    pruned_sizes = ivy.default(pruned_sizes, [
                        SeqDataLoader._compute_seq_len(i, sl, conts_to_skip)
                        for i, sl in enumerate(sizes)
                    ])
                    num_empty = sum([ps == 0 for ps in pruned_sizes])
                    self._raw_sizes = dict(
                        zip(range(start, end + 1 + num_empty),
                            sizes[start:end + 1 + num_empty]))
                    self._pruned_sizes = dict(
                        zip(range(start, end + 1 + num_empty),
                            pruned_sizes[start:end + 1 + num_empty]))
                elif isinstance(sizes, (int, dict)):
                    self._raw_sizes = sizes
                    self._pruned_sizes = ivy.default(pruned_sizes, sizes)
                    if isinstance(self._pruned_sizes, int):
                        pruned_dict = dict()
                        for seq_idx, win_idx in conts_to_skip:
                            if seq_idx not in pruned_dict:
                                pruned_dict[seq_idx] = list()
                            pruned_dict[seq_idx].append(win_idx)
                        pruned_dict = {
                            k: len(set(v))
                            for k, v in pruned_dict.items()
                        }
                        pruned_sizes_dict = {
                            k: self._pruned_sizes - num_pruned
                            for k, num_pruned in pruned_dict.items()
                        }
                        num_empty = sum(
                            [size == 0 for size in pruned_sizes_dict.values()])
                        pruned_sizes = collections.defaultdict(
                            lambda: self._pruned_sizes, pruned_sizes_dict)
                    else:
                        num_empty = sum([ps == 0 for ps in self._pruned_sizes])
                else:
                    raise Exception(
                        'Invalid type for sizes, expected one of int, dict, tuple or list, '
                        'but found {} of type {}'.format(sizes, type(sizes)))
                self._constant_size = isinstance(self._raw_sizes, int)
                if max_seq_len:
                    self._max_seq_len = max_seq_len
                else:
                    self._max_seq_len = self._pruned_sizes if self._constant_size else max(
                        self._pruned_sizes.values())
                self._fpath_template = fpath_template
                self._conts_to_skip = conts_to_skip
                if seq_idxs:
                    self._seq_idxs = seq_idxs
                else:
                    vals = [
                        v
                        for i, v in enumerate(range(start, end + 1 +
                                                    num_empty))
                        if pruned_sizes[i] > 0
                    ]
                    keys = range(0, min(end - start + 1 + num_empty,
                                        len(vals)))
                    self._seq_idxs = dict(zip(keys, vals))

            def __getitem__(self, slice_obj):
                if isinstance(slice_obj, slice):
                    seq_idxs = collections.OrderedDict([
                        (i, self._seq_idxs[idx]) for i, idx in enumerate(
                            range(slice_obj.start, slice_obj.stop,
                                  ivy.default(slice_obj.step, 1)))
                    ])
                elif isinstance(slice_obj, int):
                    seq_idxs = collections.OrderedDict(
                        {0: self._seq_idxs[slice_obj]})
                else:
                    raise Exception(
                        'Invalid type for slice_obj, expected either slice or int, '
                        'but found {} of type {}'.format(
                            slice_obj, type(slice_obj)))
                if self._constant_size:
                    sizes = self._raw_sizes
                else:
                    sizes = collections.OrderedDict({
                        seq_idx: self._raw_sizes[seq_idx]
                        for seq_idx in seq_idxs.values()
                    })
                return ContainerIdxMap(sizes,
                                       self._fpath_template,
                                       seq_idxs,
                                       max_seq_len=self._max_seq_len,
                                       conts_to_skip=self._conts_to_skip,
                                       pruned_sizes=self._pruned_sizes)

            def __len__(self):
                return len(self._seq_idxs)

            def shuffle(self):
                mapped_idxs = list(self._seq_idxs.values())
                np.random.shuffle(mapped_idxs)
                self._seq_idxs = collections.OrderedDict(
                    zip(self._seq_idxs.keys(), mapped_idxs))

            def to_idxs(self):
                seq_idxs = self._seq_idxs.values()
                sizes = [
                    self._raw_sizes
                    if self._constant_size else self._raw_sizes[seq_idx]
                    for seq_idx in seq_idxs
                ]
                rets = [[(seq_idx, win_idx) for win_idx in range(size)
                         if not SeqDataLoader._skip_cont(
                             seq_idx, win_idx, self._conts_to_skip)]
                        for seq_idx, size in zip(seq_idxs, sizes)]
                return [
                    r + [(None, None)] * (self._max_seq_len - len(r))
                    for r in rets if list(set(r)) != [None]
                ]

            def to_filepaths(self):
                if not ivy.exists(self._fpath_template):
                    raise Exception(
                        'to_filepaths method is not valid if fpath_template has not been specified '
                        'in the constructor.')
                seq_idxs = self._seq_idxs.values()
                sizes = [
                    self._raw_sizes
                    if self._constant_size else self._raw_sizes[seq_idx]
                    for seq_idx in seq_idxs
                ]
                rets = [[
                    self._fpath_template % (seq_idx, win_idx)
                    for win_idx in range(size) if not SeqDataLoader._skip_cont(
                        seq_idx, win_idx, self._conts_to_skip)
                ] for seq_idx, size in zip(seq_idxs, sizes)]
                return [
                    r + [''] * (self._max_seq_len - len(r)) for r in rets
                    if ''.join(r) != ''
                ]

            @property
            def sizes(self):
                return self._pruned_sizes

        # container filepaths
        if self._spec.container_load_mode in ['preload', 'dynamic']:
            fpath_template = os.path.join(
                self._container_data_dir,
                self._spec.dataset_spec.cont_fname_template)
        else:
            fpath_template = None
        container_idx_map = ContainerIdxMap(
            self._spec.dataset_spec.unpruned_sequence_lengths,
            fpath_template,
            start=starting_example,
            end=ending_example,
            conts_to_skip=self._spec.containers_to_skip)

        if self._spec.num_sequences != -1:
            container_idx_map = container_idx_map[0:self._spec.num_sequences]

        # shuffle sequences
        if self._spec.preshuffle_data:
            container_idx_map.shuffle()

        # extract sequence lengths
        if self._fixed_sequence_length:
            self._sequence_lengths =\
                collections.OrderedDict(zip(range(len(container_idx_map)),
                                            [self._spec.dataset_spec.sequence_lengths] * len(container_idx_map)))
            self._windows_per_seq = self._sequence_lengths[
                0] - self._window_size + 1
            # windowing values
            window_idxs_per_seq = ivy.reshape(
                ivy.arange(self._windows_per_seq, 0, 1),
                (self._windows_per_seq, 1))
            gather_idxs_list = list()
            for x in window_idxs_per_seq:
                gather_idxs_list.append(
                    ivy.expand_dims(
                        ivy.arange(x[0] + self._window_size, x[0], 1), 0))
            gather_idxs = ivy.concatenate(gather_idxs_list, 0)
            self._gather_idxs = \
                ivy.to_numpy(ivy.reshape(gather_idxs, (self._windows_per_seq * self._window_size, 1))).tolist()
        else:
            self._sequence_lengths = container_idx_map.sizes

        # maybe pre-load containers
        if self._spec.container_load_mode == 'preload':
            # load containers with vector data and image filepath entries
            container_slices = self._get_containers_w_filepath_img_entries_as_tensor_slices(
                container_idx_map.to_filepaths())
            if self._first_frame_validity_fn is not None:
                container_slices =\
                    self._first_frame_validity_fn(container_slices, [ending_example - starting_example + 1])

            # prune unwanted chains of keys
            if 'unused_key_chains' in self._spec:
                container_slices = self._prune_unused_key_chains(
                    container_slices)

            dataset = Dataset(ivy.Container.list_stack([
                c[0]
                for c in container_slices.unstack(0, container_slices.shape[0])
            ], 0),
                              'base',
                              container_slices.shape[0],
                              numpy_loading=True,
                              cache_size=self._base_cache_size,
                              queue_timeout=self._spec.queue_timeout)
        else:
            if self._spec.container_load_mode == 'dynamic':
                # load containers with filepath entries
                dataset = Dataset(ivy.Container({'fpaths': container_idx_map}),
                                  'base',
                                  len(container_idx_map),
                                  trans_fn=lambda cont: cont.map(
                                      lambda x_, kc: x_.to_filepaths()),
                                  elementwise_query_fn=False,
                                  numpy_loading=True,
                                  cache_size=self._base_cache_size,
                                  queue_timeout=self._spec.queue_timeout)
                dataset = dataset.map('loaded_json', self._load_json_files,
                                      self._num_workers.loaded_json)
                dataset = dataset.map('parsed_json', self._parse_json_strings,
                                      self._num_workers.parsed_json)
            else:
                dataset = Dataset(ivy.Container({'idx_map':
                                                 container_idx_map}),
                                  'base',
                                  len(container_idx_map),
                                  trans_fn=lambda cont: self._spec.
                                  custom_container_load_fn(self, cont),
                                  elementwise_query_fn=False,
                                  numpy_loading=True,
                                  cache_size=self._base_cache_size,
                                  queue_timeout=self._spec.queue_timeout)
            if 'unused_key_chains' in self._spec:
                dataset = dataset.map('keychain_pruned',
                                      self._prune_unused_key_chains,
                                      self._num_workers.keychain_pruned)
            if self._first_frame_validity_fn is not None:
                dataset = dataset.map(
                    'valid_first_frames',
                    lambda x_: self._first_frame_validity_fn(x_, None),
                    self._num_workers.valid_first_frames)
        if not (self._spec.dataset_spec.sequence_lengths == 1
                and self._window_size == 1):
            # ToDo: add other conditionals which make the loading more efficient if only one of the
            #  above two conditions is True
            dataset = dataset.map(
                'windowed', self._group_container_into_windowed_container,
                self._num_workers.windowed)
            dataset = dataset.unbatch(
                'unbatched',
                self._num_workers.unbatched,
                batch_sizes=[
                    max(seq_len, self._window_size) - self._window_size + 1
                    for seq_len in self._sequence_lengths.values()
                    if seq_len > 0
                ])
        if self._spec.shuffle_buffer_size > 0:
            dataset = dataset.shuffle('shuffled',
                                      self._spec.shuffle_buffer_size,
                                      self._num_workers.shuffled)
        dataset = dataset.map('loaded_data',
                              self._load_data_from_filepath_tensors,
                              self._num_workers.loaded_data)
        dataset = dataset.batch('batched', self._batch_size,
                                self._num_workers.batched)
        dataset = dataset.map(
            'from_np',
            lambda cont: cont.map(lambda x_, kc: ivy.array(x_, dev_str='cpu')),
            self._num_workers.from_np,
            numpy_loading=False)
        if ivy.exists(self._spec.post_proc_fn):
            dataset = dataset.map('post_processed', self._spec.post_proc_fn,
                                  self._num_workers.post_processed)
        if self._spec.with_prefetching:
            dataset = dataset.prefetch('prefetch')
        # ToDo: find way to make pre-fetching to GPU actually pre-fetch, ideally using multi-processing.
        #  For example, swapping prefetch and to_gpu ops around would work if to_gpu could accept self._num_workers.
        if self._spec.prefetch_to_devs:
            if isinstance(self._spec.prefetch_to_devs, str):
                dataset = dataset.to_dev('to_dev', self._spec.prefetch_to_devs)
            elif len(self._spec.prefetch_to_devs) == 1:
                dataset = dataset.to_dev('to_dev',
                                         self._spec.prefetch_to_devs[0])
            else:
                dataset = dataset.to_devs('to_devs',
                                          self._spec.prefetch_to_devs)
        return dataset
Example 24
def build_trainer(data_loader_class=None,
                  network_class=None,
                  trainer_class=None,
                  dataset_dirs_args=None,
                  dataset_dirs_class=None,
                  dataset_dirs=None,
                  dataset_spec_args=None,
                  dataset_spec_class=None,
                  dataset_spec=None,
                  data_loader_spec_args=None,
                  data_loader_spec_class=None,
                  data_loader_spec=None,
                  data_loader=None,
                  network_spec_args=None,
                  network_spec_class=None,
                  network_spec=None,
                  network=None,
                  trainer_spec_args=None,
                  trainer_spec_class=None,
                  trainer_spec=None,
                  json_spec_path=None,
                  spec_cont=None,
                  class_priority=False):
    """
    build trainer
    """

    # build trainer spec
    trainer_spec = ivy.default(
        trainer_spec,
        build_trainer_spec(
            data_loader_class=data_loader_class,
            network_class=network_class,
            dataset_dirs_args=dataset_dirs_args,
            dataset_dirs_class=dataset_dirs_class,
            dataset_dirs=dataset_dirs,
            dataset_spec_args=dataset_spec_args,
            dataset_spec_class=dataset_spec_class,
            dataset_spec=dataset_spec,
            data_loader_spec_args=data_loader_spec_args,
            data_loader_spec_class=data_loader_spec_class,
            data_loader_spec=data_loader_spec,
            data_loader=data_loader,
            network_spec_args=network_spec_args,
            network_spec_class=network_spec_class,
            network_spec=network_spec,
            network=network,
            trainer_spec_args=trainer_spec_args,
            trainer_spec_class=trainer_spec_class,
            json_spec_path=json_spec_path,
            spec_cont=spec_cont))

    # override trainer_class if specified in trainer_spec
    trainer_class = ivy.default(ivy.default(
        _import_arg_specified_class_if_present(trainer_spec, 'trainer_class'),
        trainer_class, rev=class_priority),
        None)

    # verify trainer_class exists
    if not ivy.exists(trainer_class):
        raise Exception('trainer_class must either be specified in this build_trainer() method, '
                        'or the trainer_class attribute must be specified in the trainer_spec instance.')

    # return trainer
    return trainer_class(trainer_spec)
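
A hedged end-to-end sketch; the three classes and the spec arguments are placeholders, while train() and close() mirror the calls made in the other examples above.

# Hypothetical sketch: build a trainer from user-defined classes and run it.
trainer = build_trainer(
    data_loader_class=MyDataLoader,    # hypothetical user-defined loader class
    network_class=MyNetwork,           # hypothetical user-defined network class
    trainer_class=MyTrainer,           # hypothetical user-defined trainer class
    trainer_spec_args={'log_dir': '/path/to/log'})    # illustrative spec args
trainer.train(0, 10000)   # (starting step, total iterations), as in the tune trainable above
trainer.close()
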
Example 25
    def custom_container_load_fn(self, cont):

        new_cont = ivy.Container()
        all_idxs = cont.idx_map.to_idxs()

        actions_seqs_list = list()

        seq_idxs_seqs_list = list()
        idxs_seqs_list = list()
        lengths_seqs_list = list()

        for seq in all_idxs:

            action_arrays_list = list()

            seq_idx_arrays_list = list()
            idx_arrays_list = list()

            found_end = False
            j = 0
            idx = 0
            last_idx = 0
            seq_idx = seq[0][0]

            for j, (_, idx) in enumerate(seq):
                if not ivy.exists(idx) and not found_end:
                    found_end = True
                    last_idx = j - 1
                if found_end:
                    idx = last_idx

                action_as_list = self._actions_dict[str(seq_idx)][str(idx)]
                action_arrays_list.append(
                    ivy.array(action_as_list, dtype_str='float32')[0])

                seq_idx_arrays_list.append(
                    ivy.array([seq_idx], dtype_str='float32'))
                idx_arrays_list.append(ivy.array([idx], dtype_str='float32'))
            length_arrays_list = [
                ivy.array([last_idx + 1 if found_end else idx + 1],
                          dtype_str='float32')
            ] * (j + 1)

            action_arrays = ivy.concatenate(action_arrays_list, 0)
            actions_seqs_list.append(action_arrays)

            seq_idx_arrays = ivy.concatenate(seq_idx_arrays_list, 0)
            seq_idxs_seqs_list.append(seq_idx_arrays)
            idx_arrays = ivy.concatenate(idx_arrays_list, 0)
            idxs_seqs_list.append(idx_arrays)
            length_arrays = ivy.concatenate(length_arrays_list, 0)
            lengths_seqs_list.append(length_arrays)

        new_cont.actions = actions_seqs_list

        new_cont.seq_info = ivy.Container()
        new_cont.seq_info.seq_idx = seq_idxs_seqs_list
        new_cont.seq_info.idx = idxs_seqs_list
        new_cont.seq_info.length = lengths_seqs_list

        return new_cont
Example 26
    def __init__(
            self,
            data_loader_class,
            network_class,
            trainer_class,
            dataset_dirs_args: dict = None,
            dataset_dirs_class: DatasetDirs.__base__ = DatasetDirs,
            dataset_dirs: DatasetDirs = None,
            dataset_spec_args: dict = None,
            dataset_spec_class: DatasetSpec.__base__ = DatasetSpec,
            dataset_spec: DatasetSpec = None,
            data_loader_spec_args: dict = None,
            data_loader_spec_class: DataLoaderSpec.__base__ = DataLoaderSpec,
            data_loader_spec: DataLoaderSpec = None,
            data_loader=None,
            network_spec_args: dict = None,
            network_spec_class: NetworkSpec.__base__ = NetworkSpec,
            network_spec: NetworkSpec = None,
            network=None,
            trainer_spec_args: dict = None,
            trainer_spec_class: TrainerSpec.__base__ = TrainerSpec,
            trainer_spec: TrainerSpec = None,
            trainer=None,
            tuner_spec_args: dict = None,
            tuner_spec_class: TunerSpec.__base__ = TunerSpec,
            tuner_spec: TunerSpec = None,
            json_spec_path: str = None,
            spec_cont: dict = None):
        """
        base class for any tune trainers
        """
        if not ivy.exists(tune):
            raise Exception(
                'ray[tune] is needed in order to use the Tuner class, but it is not installed. '
                'Please install via: pip install ray[tune].')
        self._data_loader_class = data_loader_class
        self._network_class = network_class
        self._trainer_class = trainer_class
        self._dataset_dirs_args = ivy.default(dataset_dirs_args, dict())
        self._dataset_dirs_class = dataset_dirs_class
        self._dataset_dirs = dataset_dirs
        self._dataset_spec_args = ivy.default(dataset_spec_args, dict())
        self._dataset_spec_class = dataset_spec_class
        self._dataset_spec = dataset_spec
        self._data_loader_spec_args = ivy.default(data_loader_spec_args,
                                                  dict())
        self._data_loader_spec_class = data_loader_spec_class
        self._data_loader_spec = data_loader_spec
        self._data_loader = data_loader
        self._network_spec_args = ivy.default(network_spec_args, dict())
        self._network_spec_class = network_spec_class
        self._network_spec = network_spec
        self._network = network
        self._trainer_spec_args = ivy.default(trainer_spec_args, dict())
        self._trainer_spec_class = trainer_spec_class
        self._trainer_spec = trainer_spec
        self._trainer = trainer
        self._tuner_spec_args = ivy.default(tuner_spec_args, dict())
        self._tuner_spec_class = tuner_spec_class
        self._tuner_spec = tuner_spec
        self._json_spec_path = json_spec_path
        self._spec_cont = spec_cont

        # initialized on _setup
        self._trainer = None

        # builder
        while len(ivy.framework_stack) > 0:
            logging.info(
                'unsetting framework {}, framework stack must be empty when '
                'initializing tuner class.'.format(ivy.framework_stack[-1]))
            ivy.unset_framework()
        self._builder = builder_module

        # tuner spec
        self._spec = self._builder.build_tuner_spec(
            data_loader_class=self._data_loader_class,
            network_class=self._network_class,
            trainer_class=self._trainer_class,
            dataset_dirs_args=self._dataset_dirs_args,
            dataset_dirs_class=self._dataset_dirs_class,
            dataset_dirs=self._dataset_dirs,
            dataset_spec_args=self._dataset_spec_args,
            dataset_spec_class=self._dataset_spec_class,
            dataset_spec=self._dataset_spec,
            data_loader_spec_args=self._data_loader_spec_args,
            data_loader_spec_class=self._data_loader_spec_class,
            data_loader_spec=self._data_loader_spec,
            data_loader=self._data_loader,
            network_spec_args=self._network_spec_args,
            network_spec_class=self._network_spec_class,
            network_spec=self._network_spec,
            network=self._network,
            trainer_spec_args=self._trainer_spec_args,
            trainer_spec_class=self._trainer_spec_class,
            trainer_spec=self._trainer_spec,
            trainer=self._trainer,
            tuner_spec_args=self._tuner_spec_args,
            tuner_spec_class=self._tuner_spec_class,
            json_spec_path=self._json_spec_path,
            spec_cont=self._spec_cont)
        self._spec = _convert_tuner_spec(self._spec)
Example 27
    def __init__(self, data_loader_spec: SeqDataLoaderSpec):
        super(SeqDataLoader, self).__init__(data_loader_spec)

        # cpus
        if 'num_workers' in data_loader_spec:
            self._total_num_workers = data_loader_spec.num_workers
        else:
            self._total_num_workers = multiprocessing.cpu_count()

        # first frame validity
        if 'first_frame_validity_fn' in data_loader_spec:
            self._first_frame_validity_fn = data_loader_spec.first_frame_validity_fn
        else:
            self._first_frame_validity_fn = None

        # data loader specification
        self._spec = data_loader_spec
        self._container_data_dir = os.path.join(
            self._spec.dataset_spec.dirs.dataset_dir, 'containers/')
        self._batch_size = self._spec.batch_size
        self._base_cache_size = self._spec.cache_size * self._spec.batch_size * self._spec.window_size
        self._window_size = self._spec.window_size
        start_idx = self._spec.starting_idx
        end_idx = start_idx + self._spec.num_sequences - 1

        # specs before pruning via containers_to_skip
        self._spec.dataset_spec.unpruned_sequence_lengths = self._spec.dataset_spec.sequence_lengths
        self._spec.unpruned_num_sequences = self._spec.num_sequences

        # sequence lengths and windows per sequence
        if 'sequence_lengths' in self._spec.dataset_spec:
            self._fixed_sequence_length = isinstance(
                self._spec.dataset_spec.sequence_lengths, int)
            if self._fixed_sequence_length:
                self._windows_per_seq = self._spec.dataset_spec.sequence_lengths - (
                    self._window_size - 1)
            else:
                # update sequences lengths
                self._spec.dataset_spec.sequence_lengths =\
                    [self._compute_seq_len(i, sl, self._spec.containers_to_skip)
                     for i, sl in enumerate(self._spec.dataset_spec.sequence_lengths)]
                self._spec.num_sequences =\
                    sum([sl > 0 for sl in self._spec.dataset_spec.sequence_lengths[start_idx:end_idx+1]])
                self._windows_per_seq = ivy.array(
                    self._spec.dataset_spec.sequence_lengths) - (
                        self._window_size - 1)
        else:
            self._fixed_sequence_length = False

        # new end idx following containers_to_skip pruning
        end_idx = start_idx + self._spec.num_sequences - 1

        # compute num workers for each component
        self._compute_num_workers()

        # custom init
        self._custom_init_fn = self._spec.custom_init_fn
        if ivy.exists(self._custom_init_fn):
            self._custom_init_fn(self)

        # dataset
        self._dataset = self._get_dataset(start_idx, end_idx)
        self._iterator = iter(self._dataset)

        # dummy batch
        self._first_batch = None

        # counter
        self._counter = 0