Code example #1
File: verify.py Project: Light-Reflection/uninas
def verify():
    logger = LoggerManager().get_logger()

    parser = argparse.ArgumentParser('get_network')
    parser.add_argument('--config_path', type=str, default='FairNasC')
    parser.add_argument('--weights_path', type=str, default='{path_tmp}/s3/')
    parser.add_argument('--data_dir',
                        type=str,
                        default='{path_data}/ImageNet_ILSVRC2012/')
    parser.add_argument('--data_batch_size', type=int, default=128)
    parser.add_argument('--data_num_workers', type=int, default=8)
    parser.add_argument('--num_batches',
                        type=int,
                        default=-1,
                        help='>0 to stop early')
    args, _ = parser.parse_known_args()

    # ImageNet with default augmentations / cropping
    data_set = get_imagenet(
        data_dir=args.data_dir,
        num_workers=args.data_num_workers,
        batch_size=args.data_batch_size,
        aug_dict={
            "cls_augmentations": "TimmImagenetAug",
            "DartsImagenetAug#0.crop_size": 224,
        },
    )

    # network
    network = get_network(args.config_path, data_set.get_data_shape(),
                          data_set.get_label_shape(), args.weights_path)
    network.eval()
    network = network.cuda()

    # measure the accuracy
    top1, top5, num_samples = 0, 0, 0
    with torch.no_grad():
        for i, (data, targets) in enumerate(data_set.test_loader()):
            # chained comparison: stop early only if a positive --num_batches was given
            if i >= args.num_batches > 0:
                break
            outputs = network(data.cuda())
            t1, t5 = accuracy(outputs, targets.cuda(), topk=(1, 5))
            n = data.size(0)
            top1 += t1 * n
            top5 += t5 * n
            num_samples += n

    logger.info('results:')
    logger.info('\ttested images: %d' % num_samples)
    logger.info('\ttop1: %.4f (%d/%d)' %
                (top1 / num_samples, top1, num_samples))
    logger.info('\ttop5: %.4f (%d/%d)' %
                (top5 / num_samples, top5, num_samples))
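
Note on the loop above: each batch's top-1/top-5 value is weighted by the batch size n and summed, then divided by the total sample count, so a smaller final batch does not skew the average. Below is a minimal, self-contained sketch of the same bookkeeping; the random tensors, the small linear model and the topk_correct helper are illustrative stand-ins for uninas' data loader, network and accuracy() helper, not part of the project.

import torch

def topk_correct(outputs, targets, k):
    """Count samples whose target is among the k highest-scoring classes."""
    topk = outputs.topk(k, dim=1).indices
    return (topk == targets.unsqueeze(1)).any(dim=1).sum().item()

# dummy batches standing in for data_set.test_loader()
batches = [(torch.randn(8, 10), torch.randint(0, 10, (8,))) for _ in range(3)]
model = torch.nn.Linear(10, 10)

correct1, correct5, num_samples = 0, 0, 0
with torch.no_grad():
    for data, targets in batches:
        outputs = model(data)
        correct1 += topk_correct(outputs, targets, 1)
        correct5 += topk_correct(outputs, targets, 5)
        num_samples += data.size(0)
print('top1: %.4f  top5: %.4f' % (correct1 / num_samples, correct5 / num_samples))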
Code example #2
File: abstract.py Project: Light-Reflection/uninas
    def initialize_weights(self, net: AbstractModule):
        logger = LoggerManager().get_logger()
        logger.info('Initializing: %s' % self.__class__.__name__)
        self._initialize_weights(net, logger)
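
initialize_weights above only fetches a logger and delegates to _initialize_weights, so concrete initializers just override the underscore method (a template-method pattern). A hypothetical sketch of such a subclass follows; the stub base class and the KaimingSketch name are assumptions used to keep the snippet self-contained, not uninas classes.

import torch.nn as nn

class _Base:  # stand-in for the abstract class shown above
    def initialize_weights(self, net):
        self._initialize_weights(net, logger=None)

class KaimingSketch(_Base):  # hypothetical initializer, not part of uninas
    def _initialize_weights(self, net, logger):
        # Kaiming-normal init for every conv weight, zeros for biases
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

KaimingSketch().initialize_weights(nn.Sequential(nn.Conv2d(3, 8, 3)))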
Code example #3
File: abstract.py Project: Light-Reflection/uninas
    def __init__(self, data_dir: str, save_dir: Union[str, None],
                 bs_train: int, bs_test: int,
                 train_transforms: transforms.Compose, test_transforms: transforms.Compose,
                 train_batch_aug: Union[BatchAugmentations, None], test_batch_aug: Union[BatchAugmentations, None],
                 num_workers: int, num_prefetch: int,
                 valid_split: Union[int, float], valid_shuffle: bool,
                 fake: bool, download: bool,
                 **additional_args):
        """

        :param data_dir: where to find (or download) the data set
        :param save_dir: global save dir, can store and reuse the info which data was used in the random valid split
        :param bs_train: batch size for the train loader
        :param bs_test: batch size for the test loader, <= 0 to have the same as bs_train
        :param train_transforms: train augmentations (on each data point individually)
        :param test_transforms: test augmentations (on each data point individually)
        :param train_batch_aug: train augmentations (across the entire batch)
        :param test_batch_aug: test augmentations (across the entire batch)
        :param num_workers: number of workers prefetching data
        :param num_prefetch: number of batches prefetched by every worker
        :param valid_split: absolute number of data points if int or >1, otherwise a fraction of the training set
        :param valid_shuffle: whether to shuffle validation data
        :param fake: use fake data instead (no need to provide real data or to enable downloading)
        :param download: whether downloading is allowed
        :param additional_args: arguments that are added and used by child classes
        """
        super().__init__()
        logger = LoggerManager().get_logger()
        self.dir = data_dir
        self.bs_train = bs_train
        self.bs_test = bs_test if bs_test > 0 else self.bs_train
        self.num_workers, self.num_prefetch = num_workers, num_prefetch
        self.valid_shuffle = valid_shuffle
        self.additional_args = additional_args

        self.fake = fake
        self.download = download and not self.fake
        if self.download and (not self.can_download):
            logger.warning("Downloading was requested, but this data set can not be downloaded automatically.")

        self.train_transforms = train_transforms
        self.test_transforms = test_transforms
        self.train_batch_augmentations = train_batch_aug
        self.test_batch_augmentations = test_batch_aug

        # load/create meta info dict
        if isinstance(save_dir, str) and len(save_dir) > 0:
            meta_path = '%s/data.meta.pt' % replace_standard_paths(save_dir)
            if os.path.isfile(meta_path):
                meta = torch.load(meta_path)
            else:
                meta = defaultdict(dict)
        else:
            meta, meta_path = defaultdict(dict), None

        # give subclasses a good spot to react to additional arguments
        self._before_loading()

        # data
        if self.fake:
            train_data = self._get_fake_train_data(self.train_transforms)
            self.test_data = self._get_fake_test_data(self.test_transforms)
        else:
            train_data = self._get_train_data(self.train_transforms)
            self.test_data = self._get_test_data(self.test_transforms)

        # split train into train+valid or using stand-alone valid set
        if valid_split > 0:
            s1 = int(valid_split) if valid_split >= 1 else int(len(train_data)*valid_split)
            if s1 >= len(train_data):
                logger.warning("Tried to set valid split larger than the training set size, setting to 0.5")
                s1 = len(train_data)//2
            s0 = len(train_data) - s1
            if meta['splits'].get((s0, s1), None) is None:
                meta['splits'][(s0, s1)] = torch.randperm(s0+s1).tolist()
            indices = meta['splits'][(s0, s1)]
            self.valid_data = torch.utils.data.Subset(train_data, np.array(indices[s0:]).astype(int))
            train_data = torch.utils.data.Subset(train_data, np.array(indices[0:s0]).astype(int))
            logger.info('Data Set: splitting training set, will use %s data points as validation set' % s1)
            if self.length[1] > 0:
                logger.info('Data Set: a dedicated validation set exists, but it will be replaced.')
        elif self.length[1] > 0:
            if self.fake:
                self.valid_data = self._get_fake_valid_data(self.test_transforms)
            else:
                self.valid_data = self._get_valid_data(self.test_transforms)
            logger.info('Data Set: using the dedicated validation set with test augmentations')
        else:
            self.valid_data = None
            logger.info('Data Set: not using a validation set at all.')
        self.train_data = train_data

        # shapes
        data, label = self.train_data[0]
        self.data_shape = Shape(list(data.shape))

        # save meta info dict
        if meta_path is not None:
            torch.save(meta, meta_path)
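
The valid_split logic above accepts either an absolute number of samples (any value >= 1) or a fraction of the training set, and it stores the random permutation in the meta dict so the same split is reproduced on later runs. A small self-contained sketch of just that arithmetic, with an illustrative split_sizes helper that is not part of uninas:

import torch
from collections import defaultdict

def split_sizes(n_train, valid_split):
    """Return (train_size, valid_size) following the rules in __init__ above."""
    s1 = int(valid_split) if valid_split >= 1 else int(n_train * valid_split)
    if s1 >= n_train:  # never hand the entire training set to validation
        s1 = n_train // 2
    return n_train - s1, s1

meta = defaultdict(dict)
s0, s1 = split_sizes(50000, 0.1)             # fraction -> (45000, 5000)
if meta['splits'].get((s0, s1)) is None:     # create the permutation once, reuse it later
    meta['splits'][(s0, s1)] = torch.randperm(s0 + s1).tolist()
indices = meta['splits'][(s0, s1)]
train_indices, valid_indices = indices[:s0], indices[s0:]
print(len(train_indices), len(valid_indices))  # 45000 5000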
Code example #4
    def explore(mini: MiniNATSBenchTabularBenchmark):
        logger = LoggerManager().get_logger()

        # some stats of specific results
        logger.info(
            mini.get_by_arch_tuple((4, 3, 2, 1, 0, 2)).get_info_str('cifar10'))
        logger.info("")
        mini.get_by_arch_tuple((1, 2, 1, 2, 3, 4)).print(logger.info)
        logger.info("")
        mini.get_by_index(1554).print(logger.info)
        logger.info("")
        mini.get_by_arch_str(
            '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|'
        ).print(logger.info)
        logger.info("")

        # best results by acc1
        rows = [("acc1", "params", "arch tuple", "arch str")]
        log_headline(
            logger, "highest acc1 topologies (%s, %s, %s)" %
            (mini.get_name(), mini.get_default_data_set(),
             mini.get_default_result_type()))
        for i, r in enumerate(mini.get_all_sorted(['acc1'], [True])):
            rows.append(
                (r.get_acc1(), r.get_params(), str(r.arch_tuple), r.arch_str))
            if i > 8:
                break
        log_in_columns(logger, rows)
        logger.info("")

        # best results by acc1, excluding topologies that contain the skip op (index 1)
        rows = [("acc1", "arch tuple", "arch str")]
        c = 0
        log_headline(
            logger, "highest acc1 topologies without skip (%s, %s, %s)" %
            (mini.get_name(), mini.get_default_data_set(),
             mini.get_default_result_type()))
        for i, r in enumerate(mini.get_all_sorted(['acc1'], [True])):
            if 1 not in r.arch_tuple:
                rows.append((r.get_acc1(), str(r.arch_tuple), r.arch_str))
                c += 1
            if c > 9:
                break
        log_in_columns(logger, rows)
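
The second loop walks the benchmark entries in descending acc1 order and keeps only topologies whose arch tuple does not contain operation index 1 (the skip connection, per the headline), stopping once ten matches are collected. The same take-the-first-k-matches pattern in isolation, with dummy entries standing in for mini.get_all_sorted():

# dummy (acc1, arch_tuple) pairs, illustrative values only
entries = [(94.37, (4, 3, 2, 1, 0, 2)), (94.29, (4, 3, 2, 2, 0, 2)), (94.11, (1, 2, 1, 2, 3, 4))]
rows, c = [("acc1", "arch tuple")], 0
for acc1, arch_tuple in entries:
    if 1 not in arch_tuple:   # drop topologies that use the skip op (index 1)
        rows.append((acc1, str(arch_tuple)))
        c += 1
    if c > 9:                 # keep at most 10 matching rows
        break
for row in rows:
    print(row)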
Code example #5
from uninas.utils.paths import replace_standard_paths
from uninas.register import Register


def example_export_network(path: str) -> AbstractUninasNetwork:
    """ create a new network and export it, does not require to have onnx installed """
    network = get_network("FairNasC",
                          Shape([3, 224, 224]),
                          Shape([1000]),
                          weights_path=None)
    network = network.cuda()
    network.export_onnx(path, export_params=True)
    return network


try:
    import onnx

    if __name__ == '__main__':
        logger = LoggerManager().get_logger()
        export_path = replace_standard_paths("{path_tmp}/onnx/FairNasC.onnx")
        net1 = example_export_network(export_path)

        log_headline(logger, "onnx graph")
        net2 = onnx.load(export_path)
        onnx.checker.check_model(net2)
        logger.info(onnx.helper.printable_graph(net2.graph))

except ImportError as e:
    Register.missing_import(e)
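
If onnxruntime is available, the exported file can also be executed directly. The following is a minimal sketch, assuming the export above succeeded, that onnxruntime is installed, and that the exported network takes a single 3x224x224 image batch and produces one output; none of this is uninas code.

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(export_path, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name                   # input name chosen by the exporter
dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)  # random stand-in image batch
outputs = session.run(None, {input_name: dummy})            # compute all model outputs
print(outputs[0].shape)                                     # expected: (1, 1000)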