示例#1
0
    def __init__(self, args):
        """Build the inference configuration, NNabla context, MLP model,
        and label list.

        Exits the process (status -1) when the model parameter file or the
        labels file does not exist.
        """
        class Config:
            pass

        cfg = Config()
        cfg.context = 'cpu'
        cfg.device_id = 0
        cfg.columns_size = 2
        cfg.x_length = args.x_length
        cfg.mlp_model_params_path = args.mlp_model_params_path
        cfg.labels_path = args.labels_path
        self._config = cfg

        # Validate required input files up front; a missing file is fatal.
        required_files = (
            (cfg.mlp_model_params_path,
             "Model params path {} is not found.",
             "Path of the model parameters file is {}."),
            (cfg.labels_path,
             "Labels path {} is not found.",
             "Path of the labels file is {}."),
        )
        for path, missing_fmt, found_fmt in required_files:
            if not os.path.isfile(path):
                logger.error(missing_fmt.format(path))
                sys.exit(-1)
            logger.info(found_fmt.format(path))

        seed(0)
        logger.info("Running in %s" % cfg.context)
        self._ctx = get_extension_context(cfg.context, device_id=cfg.device_id)
        nn.set_default_context(self._ctx)
        nn.clear_parameters()
        self._mlp = MLP(cfg)
        self._labels = None
        with open(cfg.labels_path) as f:
            self._labels = f.readlines()
        self._points_buf = pointsbuffer.PointsBuffer()
示例#2
0
    def __init__(self, celeb_name=None, data_dir=None, mode="all", shuffle=True, rng=None, resize_size=(64, 64), line_thickness=3, gaussian_kernel=(5, 5), gaussian_sigma=3):
        """Configure the data source for one celebrity and load its annotation."""
        # Heatmap-drawing / preprocessing parameters.
        self.resize_size = resize_size
        self.line_thickness = line_thickness
        self.gaussian_kernel = gaussian_kernel
        self.gaussian_sigma = gaussian_sigma

        # Only these identities exist in the CelebV dataset.
        known_names = ['Donald_Trump', 'Emmanuel_Macron',
                       'Jack_Ma', 'Kathleen', 'Theresa_May']
        assert celeb_name in known_names
        self.celeb_name = celeb_name
        self.data_dir = data_dir
        self.mode = mode
        self._shuffle = shuffle

        self.imgs_root_path = os.path.join(self.data_dir, self.celeb_name)
        # NOTE(review): a missing image directory is only logged here, not
        # raised — the annotation lookup below will surface the failure.
        if not os.path.exists(self.imgs_root_path):
            logger.error('{} is not exists.'.format(self.imgs_root_path))

        # use an annotation file to know how many images are needed.
        self.ant, self._size = self.get_ant_and_size(
            self.imgs_root_path, self.mode)
        logger.info(f'the number of images for {self.mode}: {self._size}')

        self._variables = []
        self._ref_files = {}

        self.reset()
示例#3
0
def import_extension_module(ext_name):
    """
    Import an extension module by name.

    The extension modules are installed under the `nnabla_ext` package as
    namespace packages. All extension modules provide a unified set of APIs.

    Args:
        ext_name(str): Extension name. e.g. 'cpu', 'cuda', 'cudnn' etc.

    Returns: module
        A Python module of a particular NNabla extension.

    Example:

        .. code-block:: python

            ext = import_extension_module('cudnn')
            available_devices = ext.get_devices()
            print(available_devices)
            ext.device_synchronize(available_devices[0])
            ext.clear_memory_cache()

    """
    import importlib
    try:
        # Relative import resolved against the `nnabla_ext` namespace package.
        return importlib.import_module('.' + ext_name, 'nnabla_ext')
    except ImportError as err:
        from nnabla import logger
        logger.error('Extension `{}` does not exist.'.format(ext_name))
        raise err
示例#4
0
def main():
    """Entry point: optionally train the pix2pix model, then optionally
    generate images with the trained parameters, per CLI flags."""
    args = get_args()

    # Pick the compute backend (cpu/cuda/...) and install it as the default.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    nn.set_default_context(
        get_extension_context(args.context, device_id=args.device_id))

    model_path = args.model

    if args.train:
        # Build training/validation data iterators.
        logger.info("Initialing DataSource.")
        train_data = facade.facade_data_iterator(
            args.traindir,
            args.batchsize,
            shuffle=True,
            with_memory_cache=False)
        val_data = facade.facade_data_iterator(
            args.valdir,
            args.batchsize,
            random_crop=False,
            shuffle=False,
            with_memory_cache=False)

        monitor = nm.Monitor(args.logdir)
        gen_solver = S.Adam(alpha=args.lrate, beta1=args.beta1)
        dis_solver = S.Adam(alpha=args.lrate, beta1=args.beta1)

        # Training returns the path of the saved parameter file.
        model_path = train(unet.generator, unet.discriminator, args.patch_gan,
                           gen_solver, dis_solver,
                           args.weight_l1, train_data, val_data,
                           args.epoch, monitor, args.monitor_interval)

    if args.generate:
        if model_path is None:
            # Neither --model nor a training run supplied parameters.
            logger.error("Trained model was NOT given.")
        else:
            logger.info("Generating from DataSource.")
            test_data = facade.facade_data_iterator(
                args.testdir,
                args.batchsize,
                shuffle=False,
                with_memory_cache=False)
            generate(unet.generator, model_path, test_data, args.logdir)
示例#5
0
def celebv_data_iterator(dataset_mode=None, celeb_name=None, data_dir=None, ref_dir=None,
                         mode="all", batch_size=1, shuffle=False, rng=None,
                         with_memory_cache=False, with_file_cache=False,
                         resize_size=(64, 64), line_thickness=3, gaussian_kernel=(5, 5), gaussian_sigma=3
                         ):
    """Create a CelebV data iterator for the Transformer or Decoder stage.

    Args:
        dataset_mode (str): 'transformer' (heatmaps only) or 'decoder'
            (images + heatmaps + resized heatmaps). Any other value logs an
            error and exits the process with a non-zero status.
        ref_dir (str): When given, data is read from precomputed reference
            .npz files; the drawing parameters (resize_size, line_thickness,
            gaussian_*) are then not used.

    Returns:
        A data iterator wrapping the appropriate CelebV data source.
    """
    # Which arrays each pipeline stage consumes:
    # (need_image, need_resized_heatmap); heatmaps are always needed.
    stage_needs = {
        'transformer': (False, False),
        'decoder': (True, True),
    }
    if dataset_mode not in stage_needs:
        # Typo fixed ("Dataitaretor" -> "Dataiterator") and exit status made
        # non-zero: plain sys.exit() exits with 0, hiding the failure.
        logger.error(
            'Specified Dataiterator is wrong?  given: {}'.format(dataset_mode))
        import sys
        sys.exit(1)

    need_image, need_resized_heatmap = stage_needs[dataset_mode]
    stage_name = 'Transformer' if dataset_mode == 'transformer' else 'Decoder'

    if ref_dir:
        assert os.path.exists(ref_dir), f'{ref_dir} not found.'
        logger.info(
            'CelebV Dataiterator using reference .npz file for {} is created.'.format(stage_name))
        source = CelebVDataRefSource(
            celeb_name=celeb_name, data_dir=data_dir, ref_dir=ref_dir,
            need_image=need_image, need_heatmap=True,
            need_resized_heatmap=need_resized_heatmap,
            mode=mode, shuffle=shuffle, rng=rng)
    else:
        logger.info('CelebV Dataiterator for {} is created.'.format(stage_name))
        source = CelebVDataSource(
            celeb_name=celeb_name, data_dir=data_dir,
            need_image=need_image, need_heatmap=True,
            need_resized_heatmap=need_resized_heatmap,
            mode=mode, shuffle=shuffle, rng=rng,
            resize_size=resize_size, line_thickness=line_thickness,
            gaussian_kernel=gaussian_kernel, gaussian_sigma=gaussian_sigma)

    return data_iterator(source, batch_size, rng,
                         with_memory_cache, with_file_cache)
def main():
    """Load trained MNIST parameters and save an inference-ready NNP file.

    Expects classification.py (in the vision/mnist example directory) to
    have been run first so that the trained parameter .h5 file exists;
    exits with status 1 otherwise.
    """
    HERE = os.path.dirname(__file__)
    # Import MNIST data
    sys.path.append(
        os.path.realpath(os.path.join(HERE, '..', '..', 'vision', 'mnist')))
    from mnist_data import data_iterator_mnist
    from args import get_args
    from classification import mnist_lenet_prediction, mnist_resnet_prediction

    args = get_args(description=__doc__)

    # Pick the prediction network matching the trained parameters.
    mnist_cnn_prediction = mnist_lenet_prediction
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction

    # Infer parameter file name and read it.
    model_save_path = os.path.join('../../vision/mnist',
                                   args.model_save_path)
    parameter_file = os.path.join(
        model_save_path,
        '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    try:
        nn.load_parameters(parameter_file)
    except IOError:
        # Fixed typo ("runnning"); use sys.exit instead of the site-provided
        # exit() builtin, which is not guaranteed to exist.
        logger.error("Run classification.py before running this script.")
        sys.exit(1)

    # Create a computation graph to be saved.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    pred = mnist_cnn_prediction(image, test=True)

    # Save NNP file (used in C++ inference later.).
    nnp_file = '{}_{:06}.nnp'.format(args.net, args.max_iter)
    runtime_contents = {
        'networks': [
            {'name': 'runtime',
             'batch_size': args.batch_size,
             'outputs': {'y': pred},
             'names': {'x': image}}],
        'executors': [
            {'name': 'runtime',
             'network': 'runtime',
             'data': ['x'],
             'output': ['y']}]}
    nn.utils.save.save(nnp_file, runtime_contents)
def main():
    """Load MNIST parameters trained by nnabla-examples and save an NNP file.

    Locates the local nnabla-examples checkout via the `NNABLA_EXAMPLES_ROOT`
    environment variable, loads the trained parameter file, rebuilds the
    inference graph, and saves it as a .nnp archive for C++ inference.
    Exits with status 1 when the parameter file is missing.
    """

    # Read envvar `NNABLA_EXAMPLES_ROOT` to identify the path to your local
    # nnabla-examples directory.
    HERE = os.path.dirname(__file__)
    nnabla_examples_root = os.environ.get(
        'NNABLA_EXAMPLES_ROOT',
        os.path.join(HERE, '../../../../nnabla-examples'))
    mnist_examples_root = os.path.realpath(
        os.path.join(nnabla_examples_root, 'mnist-collection'))
    sys.path.append(mnist_examples_root)
    nnabla_examples_git_url = 'https://github.com/sony/nnabla-examples'

    # Check if nnabla-examples found.
    try:
        from args import get_args
    except ImportError:
        print('An envvar `NNABLA_EXAMPLES_ROOT`'
              ' which locates the local path to '
              '[nnabla-examples]({})'
              ' repository must be set correctly.'.format(
                  nnabla_examples_git_url),
              file=sys.stderr)
        raise

    # Import MNIST data
    from mnist_data import data_iterator_mnist
    from classification import mnist_lenet_prediction, mnist_resnet_prediction

    args = get_args(description=__doc__)

    # Pick the prediction network matching the trained parameters.
    mnist_cnn_prediction = mnist_lenet_prediction
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction

    # Infer parameter file name and read it.
    model_save_path = os.path.join(mnist_examples_root, args.model_save_path)
    parameter_file = os.path.join(
        model_save_path, '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    try:
        nn.load_parameters(parameter_file)
    except IOError:
        logger.error("Run classification.py before running this script.")
        # sys.exit instead of the site-provided exit() builtin, which is not
        # guaranteed to exist (e.g. under `python -S` or when frozen).
        sys.exit(1)

    # Create a computation graph to be saved.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    pred = mnist_cnn_prediction(image, test=True)

    # Save NNP file (used in C++ inference later.).
    nnp_file = '{}_{:06}.nnp'.format(args.net, args.max_iter)
    runtime_contents = {
        'networks': [{
            'name': 'runtime',
            'batch_size': args.batch_size,
            'outputs': {
                'y': pred
            },
            'names': {
                'x': image
            }
        }],
        'executors': [{
            'name': 'runtime',
            'network': 'runtime',
            'data': ['x'],
            'output': ['y']
        }]
    }
    nn.utils.save.save(nnp_file, runtime_contents)
示例#8
0
    def __init__(self, args):
        """Set up the inference configuration and build the requested network.

        Depending on ``args.net`` this instantiates an MLP, a LeNet, or an
        LSTM followed by an MLP, then loads the class-label file.

        Args:
            args: Parsed command-line arguments providing the context,
                model/label file paths, and network hyper-parameters read
                below.

        Raises:
            ValueError: If ``args.net`` is not one of 'mlp', 'lenet', or
                'mlp-with-lstm'.
        """
        class Config:
            pass

        self._config = Config()
        self._config.context = args.context
        self._config.device_id = args.device_id
        self._config.process = 'infer'
        self._config.columns_size = 2
        self._config.x_length = args.x_length
        self._config.x_input_length = args.x_input_length
        self._config.x_output_length = args.x_output_length
        self._config.x_split_step = args.x_split_step
        self._config.width = args.width
        self._config.height = args.height
        self._config.lstm_unit_name = args.lstm_unit_name
        self._config.lstm_units = args.lstm_units
        # Inference-only run: training-related settings are zeroed out.
        self._config.batch_size = 1
        self._config.max_iter = 0
        self._config.learning_rate = 0.0
        self._config.weight_decay = 0.0
        self._config.val_interval = 0
        self._config.val_iter = 0
        self._config.monitor_path = '.'
        self._config.training_dataset_path = None
        self._config.validation_dataset_path = None
        self._config.evaluation_dataset_path = None

        seed(0)
        if self._config.context is None:
            self._config.context = 'cpu'
        logger.info("Running in %s" % self._config.context)
        self._ctx = get_extension_context(self._config.context,
                                          device_id=self._config.device_id)
        nn.set_default_context(self._ctx)

        self._net_type = args.net
        logger.info("Network type is {}.".format(self._net_type))
        self._mlp = None
        self._lenet = None
        self._lstm = None

        def log_model_params_path(path):
            # A missing parameter file is only logged, not fatal here;
            # loading will fail later when the network reads it.
            # (Replaces four verbatim copies of this check.)
            if not os.path.isfile(path):
                logger.error("Model params path {} is not found.".format(path))
            else:
                logger.info(
                    "Path of the model parameters file is {}.".format(path))

        nn.clear_parameters()
        if self._net_type == 'mlp':
            self._config.model_params_path = args.mlp_model_params_path
            log_model_params_path(self._config.model_params_path)
            self._mlp = MLP(self._config)
            self._mlp.init_for_infer()
        elif self._net_type == 'lenet':
            self._config.model_params_path = args.lenet_model_params_path
            log_model_params_path(self._config.model_params_path)
            self._lenet = LeNet(self._config)
            self._lenet.init_for_infer()
        elif self._net_type == 'mlp-with-lstm':
            # The LSTM is initialized first, then the MLP, each from its own
            # parameter file.
            self._config.model_params_path = args.lstm_model_params_path
            log_model_params_path(self._config.model_params_path)
            self._lstm = LSTM(self._config)
            self._lstm.init_for_infer()
            self._config.model_params_path = args.mlp_model_params_path
            log_model_params_path(self._config.model_params_path)
            self._mlp = MLP(self._config)
            self._mlp.init_for_infer()
        else:
            raise ValueError("Unknown network type {}".format(self._net_type))

        self._labels = None
        self._labels_path = args.labels_path
        if not os.path.isfile(self._labels_path):
            logger.error("Labels path {} is not found.".format(
                self._labels_path))
        else:
            logger.info("Path of the labels file is {}.".format(
                self._labels_path))
            with open(self._labels_path) as f:
                self._labels = f.readlines()
        self._points_buf = pointsbuffer.PointsBuffer()