コード例 #1
0
    def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
        """
        Build the network via `_build2`, log its parameter count, compare every head output shape
        with the requested one and load pretrained weights if a checkpoint path is set.

        :param s_in: input shape
        :param s_out: requested output shape, has to be a single Shape
        :return: the output shapes of the network, one per head
        :raises ValueError: if an output shape mismatches and `assert_output_match` is set
        """
        assert isinstance(s_out, Shape), "Attempting to build a network with an output that is not a Shape!"
        # copies are taken before `_build2` gets to see the shapes
        expected_out = s_out.copy(copy_id=True)
        self.shape_in = s_in.copy(copy_id=True)
        net_out = self._build2(s_in, s_out)
        logger = LoggerManager().get_logger()
        logger.info('Network built, it has %d parameters!' % self.get_num_parameters())

        # every head has to produce the requested shape
        assert isinstance(net_out, ShapeList), "The network must output a list of Shapes, one shape per head! (ShapeList)"
        for head_shape in net_out.shapes:
            if expected_out == head_shape:
                continue
            text = "One or more output shapes mismatch: %s, expected: %s" % (net_out, expected_out)
            if self.assert_output_match:
                raise ValueError(text)
            # one warning is enough, stop at the first mismatch
            logger.warning(text)
            break

        # pretrained weights, only when a checkpoint path is given
        has_checkpoint = len(self.checkpoint_path) > 0
        if has_checkpoint:
            weights_path = CheckpointCallback.find_pretrained_weights_path(
                self.checkpoint_path, self.model_name, raise_missing=has_checkpoint)
            if self.is_external():
                num_replacements = 1
            else:
                num_replacements = 999
            was_loaded = CheckpointCallback.load_network(weights_path, self.get_network(), num_replacements)
            self.loaded_weights(was_loaded)

        # the first head defines the output shape of the network
        self.shape_out = net_out.shapes[0].copy(copy_id=True)
        self.shape_in_list = self.shape_in.shape
        self.shape_out_list = self.shape_out.shape
        return net_out
コード例 #2
0
    def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
        """
        Build the stem, every cell listed in `cell_order` and the (aux) heads placed after cells,
        then log a table with the shapes and parameter counts of all built parts.

        :param s_in: input shape, given to the stem
        :param s_out: requested output shape, given to every head that is built
        :return: ShapeList with the output shape of every head that was built
        """
        LoggerManager().get_logger().info('Building %s:' % self.__class__.__name__)
        rows = [('cell index', 'name', 'class', 'input shapes', '', 'output shapes', '#params')]

        def get_row(idx, name: str, obj: AbstractModule) -> tuple:
            """ one table row for an already built module, the inner shape column is empty if none is cached """
            s_in_str = obj.get_shape_in().str()
            s_inner = obj.get_cached('shape_inner')
            s_inner_str = '' if s_inner is None else s_inner.str()
            s_out_str = obj.get_shape_out().str()
            return str(idx), name, obj.__class__.__name__, s_in_str, s_inner_str, s_out_str, count_parameters(obj)

        # keep a copy of the requested output shape, the name `s_out` is rebound to cell outputs in the loop below
        s_out_data = s_out.copy()
        out_shapes = self.stem.build(s_in)
        final_out_shapes = []
        rows.append(get_row('', '-', self.stem))

        # cells and (aux) heads
        updated_cell_order = []
        for i, cell_name in enumerate(self.cell_order):
            strategy_name, cell = self._get_cell(name=cell_name, cell_index=i)
            assert self.stem.num_outputs() == cell.num_inputs() == cell.num_outputs(), 'Cell does not fit the network!'
            updated_cell_order.append(cell.name)
            # the cell receives the most recently produced shapes as inputs
            s_ins = out_shapes[-cell.num_inputs():]
            with StrategyManagerDefault(strategy_name):
                s_out = cell.build(s_ins.copy(),
                                   features_mul=self.features_mul,
                                   features_fixed=self.features_first_cell if i == 0 else -1)
            out_shapes.extend(s_out)
            rows.append(get_row(i, cell_name, cell))
            self.cells.append(cell)

            # optional (aux) head after every cell
            head = self._head_positions.get(i, None)
            if head is not None:
                if head.weight > 0:
                    # the head is built on the last output of this cell
                    final_out_shapes.append(head.build(s_out[-1], s_out_data))
                    rows.append(get_row('', '-', head))
                else:
                    LoggerManager().get_logger().info('not adding head after cell %d, weight <= 0' % i)
                    del self._head_positions[i]
            else:
                assert i != len(self.cell_order) - 1, "Must have a head after the final cell"

        # remove heads that are impossible to add
        # NOTE(review): only the entry in `self.heads` is deleted here, the key stays in `self._head_positions`
        for i in self._head_positions.keys():
            if i >= len(self.cells):
                LoggerManager().get_logger().warning('Can not add a head after cell %d which does not exist, deleting the head!' % i)
                head = self._head_positions.get(i)
                for j, head2 in enumerate(self.heads):
                    if head is head2:
                        self.heads.__delitem__(j)
                        break

        s_out = ShapeList(final_out_shapes)
        rows.append(('complete network', '', '', self.get_shape_in().str(), '', s_out.str(), count_parameters(self)))
        log_in_columns(LoggerManager().get_logger(), rows, start_space=4)
        # store the names of the cells that were actually built
        self.set(cell_order=updated_cell_order)
        return s_out
コード例 #3
0
ファイル: verify.py プロジェクト: Light-Reflection/uninas
def verify():
    """
    Evaluate a network on the ImageNet test loader and log its top-1 / top-5 results.

    Arguments are read from the command line, unknown ones are ignored. The network is moved
    to the GPU and evaluated without gradients, optionally stopping after `--num_batches` batches.
    """
    logger = LoggerManager().get_logger()

    parser = argparse.ArgumentParser('get_network')
    cli_arguments = (
        ('--config_path', dict(type=str, default='FairNasC')),
        ('--weights_path', dict(type=str, default='{path_tmp}/s3/')),
        ('--data_dir', dict(type=str, default='{path_data}/ImageNet_ILSVRC2012/')),
        ('--data_batch_size', dict(type=int, default=128)),
        ('--data_num_workers', dict(type=int, default=8)),
        ('--num_batches', dict(type=int, default=-1, help='>0 to stop early')),
    )
    for flag, options in cli_arguments:
        parser.add_argument(flag, **options)
    args, _ = parser.parse_known_args()

    # ImageNet with default augmentations / cropping
    aug_dict = {
        "cls_augmentations": "TimmImagenetAug",
        "DartsImagenetAug#0.crop_size": 224,
    }
    data_set = get_imagenet(data_dir=args.data_dir,
                            num_workers=args.data_num_workers,
                            batch_size=args.data_batch_size,
                            aug_dict=aug_dict)

    # network in eval mode, on the GPU
    network = get_network(args.config_path,
                          data_set.get_data_shape(),
                          data_set.get_label_shape(),
                          args.weights_path)
    network.eval()
    network = network.cuda()

    # accumulate the per-batch results, weighted by the batch size
    top1 = top5 = num_samples = 0
    with torch.no_grad():
        for batch_idx, (data, targets) in enumerate(data_set.test_loader()):
            # a limit <= 0 disables early stopping
            if 0 < args.num_batches <= batch_idx:
                break
            outputs = network(data.cuda())
            t1, t5 = accuracy(outputs, targets.cuda(), topk=(1, 5))
            batch_size = data.size(0)
            top1 += t1 * batch_size
            top5 += t5 * batch_size
            num_samples += batch_size

    logger.info('results:')
    logger.info('\ttested images: %d' % num_samples)
    top1_values = (top1 / num_samples, top1, num_samples)
    logger.info('\ttop1: %.4f (%d/%d)' % top1_values)
    top5_values = (top5 / num_samples, top5, num_samples)
    logger.info('\ttop5: %.4f (%d/%d)' % top5_values)
コード例 #4
0
ファイル: checkpoint.py プロジェクト: Light-Reflection/uninas
 def find_last_checkpoint_path(cls,
                               save_dir: str,
                               index=0,
                               try_general_checkpoint=True) -> str:
     """
     Search for a checkpoint file, trying in this order:
     1) 'save_dir' itself, if it is a file
     2) the general checkpoint file in 'save_dir', if 'try_general_checkpoint' is set
     3) the listed checkpoint (for the CheckpointCallback with index 'index') whose path sorts last
     4) any 'checkpoint.pt' found recursively below 'save_dir'

     :return: path of the checkpoint, an empty string if none is found
     """
     save_dir = replace_standard_paths(save_dir)

     # the given path may already be the checkpoint
     if os.path.isfile(save_dir):
         return save_dir

     # general checkpoint
     if try_general_checkpoint:
         general_path = cls._general_checkpoint_file(save_dir)
         if os.path.isfile(general_path):
             return general_path

     # tracked checkpoints of the callback, ordered by their path
     infos = list(cls.list_infos(save_dir, index))
     infos.sort(key=lambda inf: inf.checkpoint_path)
     if len(infos) > 0:
         return infos[-1].checkpoint_path

     # any checkpoint.pt somewhere in the dir, the first match wins
     matches = glob.glob('%s/**/checkpoint.pt' % save_dir, recursive=True)
     if len(matches) > 0:
         return matches[0]

     # nothing found
     LoggerManager().get_logger().info(
         'Can not find a uninas checkpoint (history) in: %s' % save_dir)
     return ''
コード例 #5
0
ファイル: checkpoint.py プロジェクト: Light-Reflection/uninas
    def load_network(cls,
                     file_path: str,
                     network: nn.Module,
                     num_replacements=1) -> bool:
        """
        Load the weights of a method checkpoint into a network,
        renaming the parameter keys so that they match the network.

        :param file_path: where to look for the checkpoint
        :param network: module that receives the weights (strict loading)
        :param num_replacements: max number of times a key prefix is replaced within each key
        :return: True if weights were loaded, False if no checkpoint content was found
        """
        checkpoint = cls.load_last_checkpoint(file_path)
        if len(checkpoint) == 0:
            return False

        # the weights are either stored under 'state_dict' or the checkpoint is the state dict itself
        source_state = checkpoint.get('state_dict', checkpoint)

        # rename keys according to the mappings
        key_mappings = {'net.': ''}
        renamed_state = {}
        for old_key, value in source_state.items():
            new_key = old_key
            for pattern, replacement in key_mappings.items():
                new_key = new_key.replace(pattern, replacement, num_replacements)
            renamed_state[new_key] = value
        network.load_state_dict(renamed_state, strict=True)

        LoggerManager().get_logger().info('Loaded weights from file: %s' % file_path)
        return True
コード例 #6
0
 def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
     """
     Build the network with the parent implementation, then log a table of
     the stem, cells, heads and complete network with their shapes and parameter counts.

     :return: the (not flattened) output shapes of the network
     """
     s_in_net = s_in.copy(copy_id=True)
     super()._build(s_in, s_out)

     header = ('cell index', 'input shapes', 'output shapes', '#params')
     stem_row = ('stem', s_in.str(), self.get_stem_output_shape(), count_parameters(self.get_stem()))
     rows = [header, stem_row]

     logger = LoggerManager().get_logger()
     logger.info('%s (%s):' % (self.__class__.__name__, self.model_name))

     # one row per cell
     cell_infos = zip(self.get_cell_input_shapes(flatten=False),
                      self.get_cell_output_shapes(flatten=False),
                      self.get_cells())
     for idx, (cell_in, cell_out, cell) in enumerate(cell_infos):
         rows.append((idx, cell_in.str(), cell_out.str(), count_parameters(cell)))

     # heads and summary
     rows.append(('head(s)',
                  self.get_heads_input_shapes(),
                  self.get_network_output_shapes(flatten=False),
                  count_parameters(self.get_heads())))
     rows.append(("complete network",
                  s_in_net.str(),
                  self.get_network_output_shapes(flatten=False),
                  count_parameters(self)))
     log_in_columns(LoggerManager().get_logger(), rows, start_space=4)
     return self.get_network_output_shapes(flatten=False)
コード例 #7
0
    def __init__(self, args: Namespace, wildcards: dict, descriptions: dict = None):
        """
        Set up a task: seed the random number generators, prepare the save dir,
        store the task config and system info, and start logging.

        :param args: global argparse namespace
        :param wildcards: passed on to `save_as_json` when the task config is stored
        :param descriptions: optional descriptions, passed on to `log_args`
        """
        super().__init__()

        # args, seed
        self.args = args
        self.save_dir = self._parsed_argument('save_dir', args)
        self.is_test_run = self._parsed_argument('is_test_run', args)
        self.seed = self._parsed_argument('seed', args)
        self.is_deterministic = self._parsed_argument('is_deterministic', args)
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if self.is_deterministic:
            # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
            os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
            # `torch.set_deterministic` is deprecated and missing in current PyTorch versions,
            # prefer its replacement and only fall back to the old name where the new one does not exist
            set_deterministic = getattr(torch, 'use_deterministic_algorithms', None)
            if set_deterministic is None:
                set_deterministic = torch.set_deterministic
            set_deterministic(self.is_deterministic)

        # maybe delete old dir, note arguments, save run_config
        if self._parsed_argument('save_del_old', args):
            shutil.rmtree(self.save_dir, ignore_errors=True)
        os.makedirs(self.save_dir, exist_ok=True)
        save_as_json(args, get_task_config_path(self.save_dir), wildcards)
        # NOTE(review): plain string concatenation, assumes that save_dir ends with a path separator
        dump_system_info(self.save_dir + 'sysinfo.txt')

        # logging
        self.log_file = '%slog_task.txt' % self.save_dir
        LoggerManager().set_logging(default_save_file=self.log_file)
        self.logger = self.new_logger(index=None)
        log_args(self.logger, None, self.args, add_git_hash=True, descriptions=descriptions)
        Register.log_all(self.logger)

        # reset weight strategies so that consecutive tasks do not conflict with each other
        StrategyManager().reset()

        self.methods = []
コード例 #8
0
 def run(self):
     """
     Write the current interactive config to a temporary run config file and run it as a new task.
     Any exception is logged and shown in a warning dialog instead of being raised.
     """
     try:
         config_path = replace_standard_paths('{path_tmp}/tmp.run_config')
         self.interactive.to_json(config_path)
         task = Main.new_task(config_path)
         task.run()
     except Exception as err:
         message = str(err)
         LoggerManager().get_logger().error(message, exc_info=err)
         tkm.showwarning(message=message)
コード例 #9
0
ファイル: checkpoint.py プロジェクト: Light-Reflection/uninas
 def load(cls, file_path: str, pl_module: AbstractMethod = None) -> dict:
     """
     Load a method checkpoint from a file and return it.
     If 'pl_module' is given, its state is restored from the checkpoint as well.

     :return: the loaded checkpoint, an empty dict if 'file_path' is not a file
     """
     file_path = replace_standard_paths(file_path)

     # nothing to load
     if not os.path.isfile(file_path):
         LoggerManager().get_logger().info(
             'Can not load weights, does not exist / not a file: %s' % file_path)
         return {}

     LoggerManager().get_logger().info('Found checkpoint: %s' % file_path)
     checkpoint = torch.load(file_path)
     if pl_module is None:
         return checkpoint

     # restore the module, then let it react to the loaded checkpoint
     pl_module.load_state_dict(checkpoint['state_dict'])
     pl_module.on_load_checkpoint(checkpoint)
     LoggerManager().get_logger().info('Loaded weights from file: %s' % file_path)
     return checkpoint
コード例 #10
0
 def profile_macs(self) -> np.int64:
     """
     Profile the macs for a single forward pass on a single data point.

     :return: the result of the network's own profiling, -1 if that failed (the error is logged)
     """
     try:
         return self.net.profile_macs()
     except Exception as err:
         # only the first 500 characters of the error message are logged
         reason = str(err)[:500]
         LoggerManager().get_logger().error("Failed profiling macs:\n%s\n..." % reason)
         return -1
コード例 #11
0
ファイル: tabular.py プロジェクト: Light-Reflection/uninas
def explore(mini: MiniNASTabularBenchmark, logger: Logger = None, n=-1, sort_by='acc1', maximize=True):
    """
    Log a table of the benchmark results, in the order given by `mini.get_all_sorted`.

    :param mini: benchmark to explore
    :param logger: where to log, a default logger is fetched if None
    :param n: maximum number of rows, all results are listed if <= 0
    :param sort_by: key that is passed to `get_all_sorted`
    :param maximize: passed to `get_all_sorted` together with the sort key
    """
    if logger is None:
        logger = LoggerManager().get_logger()

    headline = "highest acc1 topologies (%s, %s, %s)" % (
        mini.get_name(), mini.get_default_data_set(), mini.get_default_result_type())
    log_headline(logger, headline)

    rows = [("%s rank" % sort_by, "acc1", "loss", "params", "flops", "latency", "tuple")]
    for rank, result in enumerate(mini.get_all_sorted([sort_by], [maximize])):
        # a limit <= 0 lists everything
        if 0 < n <= rank:
            break
        row = (rank,
               result.get_acc1(),
               result.get_loss(),
               result.get_params(),
               result.get_flops(),
               result.get_latency(),
               result.arch_tuple)
        rows.append(row)
    log_in_columns(logger, rows)
コード例 #12
0
ファイル: abstract.py プロジェクト: Light-Reflection/uninas
 def __init__(self,
              profile_fun: AbstractProfileFunction = None,
              is_test_run=False,
              **__):
     """
     :param profile_fun: function used for profiling, may be None
     :param is_test_run: stored in the profiler data, but only when a profile function is given
     :param __: further keyword arguments are ignored
     """
     super().__init__()
     # the meta info stores the name of this class
     own_name = self.__class__.__name__
     self.data = {'meta': {'cls': own_name}}
     if profile_fun is not None:
         fun_name = profile_fun.__class__.__name__
         self.set_all(profile_cls=fun_name, is_test_run=is_test_run)
     self.profile_fun = profile_fun
     self.logger = LoggerManager().get_logger()
コード例 #13
0
ファイル: abstract2.py プロジェクト: Light-Reflection/uninas
 def get_logger(cls,
                logger=None,
                is_test_run=False,
                save_dir='/tmp/',
                suffix=''):
     """
     Return the given logger, or fetch one if none is given.
     A fetched logger uses the DEBUG level in test runs (INFO otherwise) and
     the file '<save_dir>log_trainer<suffix>.txt'.
     """
     if logger is None:
         if is_test_run:
             level = logging.DEBUG
         else:
             level = logging.INFO
         log_file = '%slog_trainer%s.txt' % (save_dir, suffix)
         logger = LoggerManager().get_logger(default_level=level, save_file=log_file)
     return logger
コード例 #14
0
 def get_weights_criterion(self) -> (list, AbstractCriterion):
     """
     Create the criterion from the 'cls_criterion' meta argument, using the head weightings of the network.

     :return: (head weights, criterion)
     """
     head_weights = self.net.get_head_weightings()
     criterion_cls = self._parsed_meta_argument(
         Register.criteria, 'cls_criterion', self.hparams, index=None)
     criterion = criterion_cls(head_weights, self.hparams, self.data_set)
     # the weighting is only logged when there is more than one head
     if len(head_weights) > 1:
         message = "Weighting model heads: %s" % str(head_weights)
         LoggerManager().get_logger().info(message)
     return head_weights, criterion
コード例 #15
0
    def __init__(self, hparams: Namespace):
        """
        :param hparams: argparse namespace, the 'mask_indices' argument lists the indices to mask
        """
        super().__init__(hparams)
        self.update_architecture_weights = True
        self.train_loader = None

        # mask every index listed in the 'mask_indices' argument
        mask_indices = split(self._parsed_argument('mask_indices', self.hparams), cast_fun=int)
        for masked_idx in mask_indices:
            self.strategy_manager.mask_index(masked_idx)
            message = "Globally masking arc choices with index %d" % masked_idx
            LoggerManager().get_logger().info(message)
コード例 #16
0
 def _cb_add_str(self, *_, **__):
     if self._config_enabled:
         try:
             self.interactive.add_meta_value(self.node.name,
                                             self.meta.argument.name,
                                             self.var_add_str.get())
         except Exception as e:
             LoggerManager().get_logger().error(str(e), exc_info=e)
             tkm.showwarning(message=str(e))
         self.var_add_str.set(self.str_add)
     self.update_content()
     self.update()
コード例 #17
0
ファイル: abstract.py プロジェクト: Light-Reflection/uninas
 def from_file(cls, file_path: str):
     """
     Create a profiler of the class named in the save file and load the file into it.
     A warning is logged if the loaded data is marked as coming from a test run.

     :param file_path: profiler save file, has to exist
     :return: the loaded profiler
     """
     file_path = replace_standard_paths(file_path)
     assert os.path.isfile(file_path), "File does not exist: %s" % str(file_path)

     # the meta info of the file names the profiler class
     meta = torch.load(file_path).get('meta')
     profiler_cls = Register.profilers.get(meta.get('cls'))
     profiler = profiler_cls()
     profiler.load(file_path)

     if profiler.get('is_test_run'):
         LoggerManager().get_logger().warning(
             "Loading profiler data from a file created in a test run!")
     return profiler
コード例 #18
0
 def run(self) -> 'AbstractTask':
     """
     Execute the task and flush the logging of all methods.
     The logger manager is cleaned up afterwards, also when the task raised an exception.

     :return: this task
     """
     try:
         self._run()
         for method in self.methods:
             method.flush_logging()
         self.logger.info("Done!")
     finally:
         # exceptions simply propagate, the cleanup happens either way
         LoggerManager().cleanup()
     return self
コード例 #19
0
 def add_cells_from_config(self, config: dict):
     """
     Add all cell types in the given config.

     A cell type with a known name replaces the existing one (a cell partial of that name is dropped).
     Every cell type is logged as either 'Replaced' or 'Added'.

     :param config: network config, the cell configs are read from config['kwargs']['cell_configs']
     """
     for name, cfg in config.get('kwargs').get('cell_configs').items():
         replaced = False
         if name in self._cell_partials:
             replaced = True
             self._cell_partials.pop(name)
         if name in self.cell_configs:
             replaced = True
         # log in both cases; a guard on `replaced` here would make the 'Added' message unreachable
         action = 'Replaced' if replaced else 'Added'
         LoggerManager().get_logger().info('%s cell type "%s" from given config' % (action, name))
         self.cell_configs[name] = cfg
コード例 #20
0
 def save_config(cls, config: dict, config_dir: str,
                 config_name: str) -> str:
     """
     Write a network config as json file into a directory.

     :param config: config to save, nothing is written if None
     :param config_dir: target directory (may contain standard path placeholders), created if missing
     :param config_name: file name without extension, `cls.extension_net_config` is appended
     :return: absolute path of the written file, an empty string if there is no config
     """
     if config is None:
         return ''
     # expand the placeholders first, so that the created directory is the one the file is written to
     config_dir = replace_standard_paths(config_dir)
     os.makedirs(config_dir, exist_ok=True)
     path = '%s/%s%s' % (config_dir, config_name, cls.extension_net_config)
     path = os.path.abspath(path)
     # non-ascii characters are written as they are (ensure_ascii=False), so the encoding is fixed to utf-8
     with open(path, 'w+', encoding='utf-8') as outfile:
         json.dump(config, outfile, ensure_ascii=False, indent=2)
     LoggerManager().get_logger().info('Wrote net config to %s' % path)
     return path
コード例 #21
0
 def _cb_rem_str(self, *_, **__):
     if self._config_enabled:
         splits = self.var_rem_str.get().split('#')
         idx = 0 if len(splits) < 2 else int(splits[1])
         try:
             self.interactive.remove_meta_index(self.node.name,
                                                self.meta.argument.name,
                                                idx)
         except Exception as e:
             LoggerManager().get_logger().error(str(e), exc_info=e)
             tkm.showwarning(message=str(e))
     self.update_content()
     self.update()
コード例 #22
0
 def save_as_config(self, path=None):
     """
     Save the current interactive config as json.

     :param path: target file, a file dialog asks for it if None.
                  Nothing happens if it is not a string, is empty, or is a directory.
     """
     if path is None:
         file_types = (("run config files", "*.run_config"), )
         path = filedialog.asksaveasfilename(initialdir=self._init_dir,
                                             initialfile=name_task_config,
                                             title="Select file",
                                             filetypes=file_types)
     # no usable file path (e.g. the dialog was cancelled)
     if not isinstance(path, str) or os.path.isdir(path) or len(path) == 0:
         return
     try:
         self.interactive.to_json(path)
     except Exception as err:
         message = str(err)
         LoggerManager().get_logger().error(message, exc_info=err)
         tkm.showwarning(message=message)
     # remembered even if saving failed
     self._init_dir = os.path.dirname(path)
     self._save_path = path
コード例 #23
0
ファイル: retrain.py プロジェクト: Light-Reflection/uninas
    def _build2(self, s_in: Shape, s_out: Shape) -> ShapeList:
        """
        build the network:
        load and build the search network from its config, fix its architecture to the gene
        (or to a randomly sampled one if the gene is 'random'), then build the finalized
        network from the resulting config

        :param s_in: input shape
        :param s_out: output shape
        :return: output shapes of the finalized network
        """

        # find the search config
        if not os.path.isfile(self.search_config_path):
            self.search_config_path = Builder.find_net_config_path(
                self.search_config_path, pattern='search')

        # create a temporary search strategy
        # NOTE(review): the search network below presumably picks up this strategy through a
        # shared StrategyManager state - confirm that StrategyManager() is a singleton
        tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__')
        sm = StrategyManager()
        assert len(sm.get_strategies_list(
        )) == 0, "can not load when there already is a search network"
        sm.add_strategy(tmp_s)
        sm.set_fixed_strategy_name('__tmp__')

        # create a search network
        search_net = Register.builder.load_from_config(self.search_config_path)
        assert isinstance(search_net, SearchUninasNetwork)
        search_net.build(s_in, s_out)
        search_net.set_forward_strategy(False)

        # set the architecture, get the config
        req_gene = ""
        if self.gene == 'random':
            # let the strategy pick an architecture and read the picked indices back as gene
            search_net.forward_strategy()
            gene = sm.get_all_finalized_indices(unique=True, flat=True)
            self.model_name = "random(%s)" % str(gene)
            req_gene = " (%s)" % self.gene
        else:
            gene = split(self.gene, int)
        l0, l1 = len(sm.get_all_finalized_indices(unique=True)), len(gene)
        assert l0 == l1, "number of unique choices in the network (%d) must match length of the gene (%d)" % (
            l0, l1)
        search_net.forward_strategy(fixed_arc=gene)
        config = search_net.config(finalize=True)

        # clean up
        # NOTE(review): if one of the asserts above fails, the '__tmp__' strategy is not deleted
        sm.delete_strategy('__tmp__')
        del sm
        del search_net

        # build the actually used finalized network
        LoggerManager().get_logger().info(
            "Extracting architecture %s%s from the super-network" %
            (gene, req_gene))
        self.net = Register.builder.from_config(config)
        return self.net.build(s_in, s_out)
コード例 #24
0
ファイル: checkpoint.py プロジェクト: Light-Reflection/uninas
 def save(cls,
          file_path: str,
          pl_module: AbstractMethod,
          update_dict: dict = None) -> dict:
     """
     save method checkpoint to file, not tracking it

     :param file_path: where to save the checkpoint, missing parent directories are created
     :param pl_module: method whose state dict is saved, its `on_save_checkpoint` may extend the checkpoint
     :param update_dict: optional further entries for the checkpoint
     :return: the saved checkpoint
     """
     # a bare file name has no directory part, and os.makedirs('') raises FileNotFoundError
     parent_dir = os.path.dirname(file_path)
     if parent_dir:
         os.makedirs(parent_dir, exist_ok=True)
     checkpoint = dict(state_dict=pl_module.state_dict())
     if isinstance(update_dict, dict):
         checkpoint.update(update_dict)
     pl_module.on_save_checkpoint(checkpoint)
     cls.atomic_save(file_path, checkpoint)
     LoggerManager().get_logger().info('Saved weights to file: %s' %
                                       file_path)
     return checkpoint
コード例 #25
0
 def load_add_config(self, add=True):
     """
     Ask for a run config file and load it into the interactive config, then refresh the view.
     Failures while loading are logged and shown in a warning dialog.

     :param add: if False, the current state is reset before the file is loaded
     """
     file_types = (("run config files", "*.run_config"), )
     path = filedialog.askopenfilename(initialdir=self._init_dir,
                                       initialfile=name_task_config,
                                       title="Select file",
                                       filetypes=file_types)
     # no existing file selected (e.g. the dialog was cancelled)
     if not isinstance(path, str) or not os.path.isfile(path) or len(path) == 0:
         return
     self.enable_config_changes(False)
     if not add:
         self.reset()
     try:
         self.interactive.from_json(path)
     except Exception as err:
         message = str(err)
         LoggerManager().get_logger().error(message, exc_info=err)
         tkm.showwarning(message=message)
     self.update_content()
     self.update()
コード例 #26
0
def maybe_download(path_or_url: str, file_type: FileType = FileType.MISC) -> Union[str, None]:
    """
    if the file does not locally exist at the given path,
    try to use a cached download, otherwise download it,
    then return the path

    :param path_or_url: local file path or url, may contain standard path placeholders
    :param file_type: its value is the directory where downloads are cached
    :return: path of the local file, None if it could neither be found nor downloaded
    """
    path_or_url = replace_standard_paths(path_or_url)
    if os.path.isfile(path_or_url):
        return path_or_url
    try:
        os.makedirs(file_type.value, exist_ok=True)
        # the cached file has the name of the last url segment
        file_path = '%s/%s' % (file_type.value, path_or_url.split('/')[-1])
        if not os.path.isfile(file_path):
            urlretrieve(path_or_url, file_path)
            LoggerManager().get_logger().info("downloaded %s to %s" % (path_or_url, file_path))
        return file_path
    except Exception as e:
        # still best effort, but KeyboardInterrupt / SystemExit are no longer swallowed
        # and the reason for the failure is reported
        LoggerManager().get_logger().warning("failed to download %s: %s" % (path_or_url, str(e)))
        return None
コード例 #27
0
    def explore(mini: MiniNATSBenchTabularBenchmark):
        """
        Log some specific results of the benchmark, then two tables:
        the ten first results when sorted by acc1, and the ten first of them without a 1 in the arch tuple.
        """
        logger = LoggerManager().get_logger()

        # some stats of specific results
        first = mini.get_by_arch_tuple((4, 3, 2, 1, 0, 2))
        logger.info(first.get_info_str('cifar10'))
        logger.info("")
        mini.get_by_arch_tuple((1, 2, 1, 2, 3, 4)).print(logger.info)
        logger.info("")
        mini.get_by_index(1554).print(logger.info)
        logger.info("")
        arch_str = '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|'
        mini.get_by_arch_str(arch_str).print(logger.info)
        logger.info("")

        # best results by acc1
        headline = "highest acc1 topologies (%s, %s, %s)" % (
            mini.get_name(), mini.get_default_data_set(), mini.get_default_result_type())
        log_headline(logger, headline)
        rows = [("acc1", "params", "arch tuple", "arch str")]
        for idx, result in enumerate(mini.get_all_sorted(['acc1'], [True])):
            rows.append((result.get_acc1(), result.get_params(), str(result.arch_tuple), result.arch_str))
            # stop after the tenth result
            if idx >= 9:
                break
        log_in_columns(logger, rows)
        logger.info("")

        # best results by acc1, skipping every result that has a 1 in its arch tuple
        headline = "highest acc1 topologies without skip (%s, %s, %s)" % (
            mini.get_name(), mini.get_default_data_set(), mini.get_default_result_type())
        log_headline(logger, headline)
        rows = [("acc1", "arch tuple", "arch str")]
        num_found = 0
        for result in mini.get_all_sorted(['acc1'], [True]):
            if 1 not in result.arch_tuple:
                rows.append((result.get_acc1(), str(result.arch_tuple), result.arch_str))
                num_found += 1
            if num_found >= 10:
                break
        log_in_columns(logger, rows)
コード例 #28
0
ファイル: abstract.py プロジェクト: Light-Reflection/uninas
 def initialize_weights(self, net: AbstractModule):
     """ log which initializer is applied, then run `_initialize_weights` on the network """
     logger = LoggerManager().get_logger()
     own_name = self.__class__.__name__
     logger.info('Initializing: %s' % own_name)
     self._initialize_weights(net, logger)
コード例 #29
0
ファイル: abstract.py プロジェクト: Light-Reflection/uninas
    def __init__(self, data_dir: str, save_dir: Union[str, None],
                 bs_train: int, bs_test: int,
                 train_transforms: transforms.Compose, test_transforms: transforms.Compose,
                 train_batch_aug: Union[BatchAugmentations, None], test_batch_aug: Union[BatchAugmentations, None],
                 num_workers: int, num_prefetch: int,
                 valid_split: Union[int, float], valid_shuffle: bool,
                 fake: bool, download: bool,
                 **additional_args):
        """
        load (or fake) the train/test data, optionally split a validation set off the training data

        :param data_dir: where to find (or download) the data set
        :param save_dir: global save dir, can store and reuse the info which data was used in the random valid split
        :param bs_train: batch size for the train loader
        :param bs_test: batch size for the test loader, <= 0 to have the same as bs_train
        :param train_transforms: train augmentations (on each data point individually)
        :param test_transforms: test augmentations (on each data point individually)
        :param train_batch_aug: train augmentations (across the entire batch)
        :param test_batch_aug: test augmentations (across the entire batch)
        :param num_workers: number of workers prefetching data
        :param num_prefetch: number of batches prefetched by every worker
        :param valid_split: absolute number of data points if int or >1, otherwise a fraction of the training set
        :param valid_shuffle: whether to shuffle validation data
        :param fake: use fake data instead (no need to provide either real data or enabling downloading)
        :param download: whether downloading is allowed
        :param additional_args: arguments that are added and used by child classes
        """
        super().__init__()
        logger = LoggerManager().get_logger()
        self.dir = data_dir
        self.bs_train = bs_train
        self.bs_test = bs_test if bs_test > 0 else self.bs_train
        self.num_workers, self.num_prefetch = num_workers, num_prefetch
        self.valid_shuffle = valid_shuffle
        self.additional_args = additional_args

        self.fake = fake
        # fake data never needs to be downloaded
        self.download = download and not self.fake
        if self.download and (not self.can_download):
            logger.warning("The dataset can not be downloaded, but may be asked to.")

        self.train_transforms = train_transforms
        self.test_transforms = test_transforms
        self.train_batch_augmentations = train_batch_aug
        self.test_batch_augmentations = test_batch_aug

        # load/create meta info dict
        if isinstance(save_dir, str) and len(save_dir) > 0:
            meta_path = '%s/data.meta.pt' % replace_standard_paths(save_dir)
            if os.path.isfile(meta_path):
                meta = torch.load(meta_path)
            else:
                meta = defaultdict(dict)
        else:
            meta, meta_path = defaultdict(dict), None

        # give subclasses a good spot to react to additional arguments
        self._before_loading()

        # data
        if self.fake:
            train_data = self._get_fake_train_data(self.train_transforms)
            self.test_data = self._get_fake_test_data(self.test_transforms)
        else:
            train_data = self._get_train_data(self.train_transforms)
            self.test_data = self._get_test_data(self.test_transforms)

        # split train into train+valid or using stand-alone valid set
        if valid_split > 0:
            s1 = int(valid_split) if valid_split >= 1 else int(len(train_data)*valid_split)
            if s1 >= len(train_data):
                logger.warning("Tried to set valid split larger than the training set size, setting to 0.5")
                s1 = len(train_data)//2
            s0 = len(train_data) - s1
            # reuse a stored permutation for these split sizes, so that the split stays the same across runs
            if meta['splits'].get((s0, s1), None) is None:
                meta['splits'][(s0, s1)] = torch.randperm(s0+s1).tolist()
            indices = meta['splits'][(s0, s1)]
            # the `np.int` alias of the builtin int was removed in NumPy 1.24, use int directly
            self.valid_data = torch.utils.data.Subset(train_data, np.array(indices[s0:]).astype(int))
            train_data = torch.utils.data.Subset(train_data, np.array(indices[0:s0]).astype(int))
            logger.info('Data Set: splitting training set, will use %s data points as validation set' % s1)
            if self.length[1] > 0:
                logger.info('Data Set: a dedicated validation set exists, but it will be replaced.')
        elif self.length[1] > 0:
            if self.fake:
                self.valid_data = self._get_fake_valid_data(self.test_transforms)
            else:
                self.valid_data = self._get_valid_data(self.test_transforms)
            logger.info('Data Set: using the dedicated validation set with test augmentations')
        else:
            self.valid_data = None
            logger.info('Data Set: not using a validation set at all.')
        self.train_data = train_data

        # shapes, taken from the first training data point
        data, label = self.train_data[0]
        self.data_shape = Shape(list(data.shape))

        # save meta info dict
        if meta_path is not None:
            torch.save(meta, meta_path)
コード例 #30
0
ファイル: pbt.py プロジェクト: Light-Reflection/uninas
 def log(cls, msg: str):
     """ log an info message, prefixed with the name of this class """
     text = '%s: %s' % (cls.__name__, msg)
     LoggerManager().get_logger().info(text)