Example #1
def _context(proto):
    comm = current_communicator()
    if not proto.backends:
        logger.warn('Old-style context. Updating to new format.')
        # Update from old Context
        backends = [x.strip() for x in proto.backend.split('|')]
        compute_backends = [x.strip()
                            for x in proto.compute_backend.split('|')]
        if 'cuda' in backends:
            device_id = str(proto.device_id)
            if comm:
                device_id = str(comm.local_rank)

            if 'cudnn' in compute_backends:
                try:
                    import nnabla_ext.cudnn
                    ctx = nnabla_ext.cudnn.context(device_id=device_id)
                except ImportError:
                    logger.warn('Fallback to CPU context.')
                    import nnabla_ext.cpu
                    ctx = nnabla_ext.cpu.context()
            elif 'default' in compute_backends:
                try:
                    import nnabla_ext.cuda
                    ctx = nnabla_ext.cuda.context(device_id=device_id)
                except ImportError:
                    logger.warn('Fallback to CPU context.')
                    import nnabla_ext.cpu
                    ctx = nnabla_ext.cpu.context()
            else:
                raise ValueError(
                    'Invalid compute_backend {}'.format(proto.compute_backend))
        elif 'cpu' in backends:
            import nnabla_ext.cpu
            ctx = nnabla_ext.cpu.context()
        else:
            raise ValueError('Invalid context {}'.format(proto))
        ctx.array_class = str(proto.array_class)
        return ctx
    ctx = nn.Context()
    ctx.backend = proto.backends
    ctx.array_class = str(proto.array_class)

    if comm:
        ctx.device_id = str(comm.local_rank)
    else:
        ctx.device_id = str(proto.device_id)

    return ctx
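
The function above returns an nn.Context either by upgrading an old-style proto (single backend/compute_backend strings) or by copying the new-style fields directly. A minimal usage sketch, assuming the Context message in nnabla_pb2 exposes the backends, array_class, and device_id fields that the function reads (the concrete values and field types are assumptions):

# Usage sketch; field values are hypothetical, field types assumed from the reads above.
from nnabla.utils import nnabla_pb2

ctx_proto = nnabla_pb2.Context()
ctx_proto.backends.extend(['cpu:float'])  # new-style repeated backend strings
ctx_proto.array_class = 'CpuCachedArray'
ctx_proto.device_id = '0'

ctx = _context(ctx_proto)  # nn.Context with backend, array_class and device_id set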
Example #2
def load(filenames,
         prepare_data_iterator=True,
         batch_size=None,
         exclude_parameter=False,
         parameter_only=False,
         extension=".nntxt"):
    '''load
    Load network information from files.

    Args:
        filenames (list): A file-like object or a list of filenames.
        extension (str): If `filenames` is a file-like object, one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
    Returns:
        dict: Network information.
    '''
    class Info:
        pass

    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()

    # optimizer checkpoint
    opti_proto = nnabla_pb2.NNablaProtoBuf()
    OPTI_BUF_EXT = ['.optimizer']
    opti_h5_files = {}
    tmpdir = tempfile.mkdtemp()

    if isinstance(filenames, (list, tuple)):
        pass
    elif isinstance(filenames, str) or hasattr(filenames, 'read'):
        filenames = [filenames]

    for filename in filenames:
        if isinstance(filename, str):
            _, ext = os.path.splitext(filename)
        else:
            ext = extension

        # TODO: There are some known problems here.
        #   - Even when a protobuf file includes the network structure,
        #     it will not be loaded.
        #   - Even when a prototxt file includes parameters,
        #     they will not be loaded.

        if ext in ['.nntxt', '.prototxt']:
            if not parameter_only:
                with get_file_handle_load(filename, ext) as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            'Double-byte characters may have been used in the file or folder name.'
                        )
                        raise
            if len(proto.parameter) > 0:
                if not exclude_parameter:
                    nn.load_parameters(filename, extension=ext)
        elif ext in ['.protobuf', '.h5']:
            if not exclude_parameter:
                nn.load_parameters(filename, extension=ext)
            else:
                logger.info('Skip loading parameter.')

        elif ext == '.nnp':
            with get_file_handle_load(filename, ext) as nnp:
                for name in nnp.namelist():
                    _, ext = os.path.splitext(name)
                    if name == 'nnp_version.txt':
                        pass  # TODO currently do nothing with version.
                    elif ext in ['.nntxt', '.prototxt']:
                        if not parameter_only:
                            with nnp.open(name, 'r') as f:
                                text_format.Merge(f.read(), proto)
                        if len(proto.parameter) > 0:
                            if not exclude_parameter:
                                with nnp.open(name, 'r') as f:
                                    nn.load_parameters(f, extension=ext)
                    elif ext in ['.protobuf', '.h5']:
                        if not exclude_parameter:
                            with nnp.open(name, 'r') as f:
                                nn.load_parameters(f, extension=ext)
                        else:
                            logger.info('Skip loading parameter.')
                    elif ext in OPTI_BUF_EXT:
                        buf_type = get_buf_type(name)
                        if buf_type == 'protobuf':
                            with nnp.open(name, 'r') as f:
                                with get_file_handle_load(
                                        f, '.protobuf') as opti_p:
                                    opti_proto.MergeFromString(opti_p.read())
                        elif buf_type == 'h5':
                            nnp.extract(name, tmpdir)
                            opti_h5_files[name] = os.path.join(tmpdir, name)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            try:
                import nnabla_ext.cudnn
            except:
                pass
        try:
            x = nn.Variable()
            y = nn.Variable()
            func = F.ReLU(default_context, inplace=True)
            func.setup([x], [y])
            func.forward([x], [y])
        except:
            logger.warn('Fallback to CPU context.')
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()

    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.local_rank)
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    info.datasets = _datasets(
        proto, prepare_data_iterator if prepare_data_iterator is not None else
        info.training_config.max_epoch > 0)

    info.networks = _networks(proto, default_context, batch_size)

    info.optimizers = _optimizers(proto, default_context, info.networks,
                                  info.datasets)
    _load_optimizer_checkpoint(opti_proto, opti_h5_files, info)
    shutil.rmtree(tmpdir)

    info.monitors = _monitors(proto, default_context, info.networks,
                              info.datasets)

    info.executors = _executors(proto, info.networks)

    return info
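
A typical call, sketched under the assumption that 'model.nnp' is a hypothetical archive bundling an .nntxt network definition and parameters, and that the _networks/_executors helpers return name-keyed containers as in nnabla's loader:

# Usage sketch; 'model.nnp' is a hypothetical file name.
info = load(['model.nnp'])
print(list(info.networks))   # network names parsed from the archive
print(list(info.executors))  # executors defined alongside the networks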
Example #3
def generate():
    conf = get_config()

    # batch_size is forced to be 1
    conf.train.batch_size = 1

    image_shape = (conf.train.batch_size,) + \
        tuple(x * conf.model.g_n_scales for x in [512, 1024])

    # set context
    comm = init_nnabla(conf.nnabla_context)

    # find all test data
    if conf.train.data_set == "cityscapes":
        data_list = get_cityscape_datalist(conf.cityscapes,
                                           data_type="val",
                                           save_file=comm.rank == 0)
        conf.model.n_label_ids = conf.cityscapes.n_label_ids
    else:
        raise NotImplementedError(
            "Currently dataset {} is not supported.".format(conf.train.data_set))

    if comm.n_procs > 1:
        data_list = get_data_lists_for_each_process(data_list,
                                                    comm.n_procs)[comm.rank]

    # define generator
    generator = Generator(image_shape=image_shape, mconf=conf.model)

    # load parameters
    if not os.path.exists(conf.load_path):
        logger.warn(
            "Path to load params is not found."
            " Loading params is skipped and generated result will be unreasonable. ({})"
            .format(conf.load_path))

    nn.load_parameters(conf.load_path)

    progress_iterator = trange(len(data_list) // conf.train.batch_size,
                               desc="[Generating Images]",
                               disable=comm.rank > 0)

    # for label2color
    label2color = Colorize(conf.model.n_label_ids)

    save_path = os.path.join(conf.train.save_path, "generated")
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)

    output_str = []
    for i in progress_iterator:
        paths = data_list[i]
        image, instance_id, object_id = cityscapes_load_function(
            paths[0], paths[1], paths[2], image_shape[1:])
        gen = generator(instance_id, object_id)
        gen = (gen - gen.min()) / (gen.max() - gen.min())
        id_colorized = label2color(object_id).astype(np.uint8)

        gen_image_path = os.path.join(save_path,
                                      "res{}_{}.png".format(comm.rank, i))
        input_image_path = os.path.join(save_path,
                                        "input_{}_{}.png".format(comm.rank, i))

        imsave(gen_image_path, gen[0], channel_first=True)
        imsave(input_image_path, id_colorized)
        output_str.append(" ".join(paths + [gen_image_path, input_image_path]))

    if comm.rank == 0:
        with open(os.path.join(save_path, "in_out_pairs.txt"), "w") as f:
            f.write("\n".join(output_str))
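
The `(gen - gen.min()) / (gen.max() - gen.min())` step above is plain min-max normalization: it linearly rescales the generator output (typically in the tanh range [-1, 1]) into [0, 1] so `imsave` can write it. A self-contained illustration:

import numpy as np

def minmax_normalize(x):
    # Linearly rescale so the minimum maps to 0 and the maximum to 1.
    return (x - x.min()) / (x.max() - x.min())

print(minmax_normalize(np.array([-1.0, 0.0, 1.0])))  # [0.  0.5 1. ]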
Example #4
def lms_scheduler(ctx, use_lms, gpu_memory_size=None, window_length=None):
    _check_list = [x.split(":")[0] for x in ctx.backend]
    if "cudnn" not in _check_list and "cuda" not in _check_list:
        logger.warn(
            "ctx passed to scheduler doesn't have cuda/cudnn backend. lms scheduler will not be used."
        )
        use_lms = False

    comm = current_communicator()
    if comm:
        logger.log(99,
                   f'[OoC] Currently OoC is disabled for Multi-GPU training.')
        use_lms = False

    if use_lms:
        gpu_index = 0
        if 'cuda' in str(ctx.backend):
            gpu_index = int(ctx.device_id)
        else:
            logger.log(99, f'[OoC] OoC is only enabled for GPU training.')
            raise Exception

        if gpu_memory_size is None or gpu_memory_size == 0:
            try:
                handle = nvml.nvmlDeviceGetHandleByIndex(gpu_index)
                total_memory = nvml.nvmlDeviceGetMemoryInfo(handle).total
                gpu_memory_size = int(total_memory * 0.7)
            except:
                logger.log(
                    99,
                    '[OoC] Could not get GPU memory size; using default value (6 GB).'
                )
                gpu_memory_size = 6e9  # default 6 GB

        if window_length is None or window_length == 0:
            window_length = int(gpu_memory_size * 1.5)

        logger.log(
            99,
            f'[OoC] gpu_memory_limit: {gpu_memory_size / 1e9}GB, prefetch_window_length: {window_length / 1e9}GB'
        )
        # Change array preference so that lms works well.
        # import nnabla_ext.cuda.init as cuda_init
        # cuda_init.prefer_cpu_pinned_array()
        # cuda_init.prefer_cuda_virtual_array()
        from nnabla.ext_utils import get_extension_context
        be, tc = ctx.backend[0].split(":")
        cpu_ctx = get_extension_context("cpu", device_id="", type_config=tc)
        return SwapInOutScheduler(cpu_ctx, ctx, gpu_memory_size, window_length)
    else:

        class DummyScheduler(object):
            function_pre_hook = None
            function_post_hook = None
            update_pre_hook = None
            update_post_hook = None

            def start_scheduling(self):
                return None

            def end_scheduling(self):
                return None

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                pass

        return DummyScheduler()
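
Both branches expose the same hook interface (function_pre_hook, update_post_hook, and the context-manager protocol), so a training loop does not need to know whether LMS is active. A usage sketch following nnabla's documented SwapInOutScheduler pattern; `loss`, `solver`, and `num_iters` are hypothetical:

# Usage sketch; `loss`, `solver` and `num_iters` are hypothetical objects.
scheduler = lms_scheduler(ctx, use_lms=True)
for _ in range(num_iters):
    with scheduler:
        loss.forward(clear_no_need_grad=True,
                     function_pre_hook=scheduler.function_pre_hook,
                     function_post_hook=scheduler.function_post_hook)
        solver.zero_grad()
        loss.backward(clear_buffer=True,
                      function_pre_hook=scheduler.function_pre_hook,
                      function_post_hook=scheduler.function_post_hook)
        solver.update(update_pre_hook=scheduler.update_pre_hook,
                      update_post_hook=scheduler.update_post_hook)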
Example #5
def load(filenames, prepare_data_iterator=True, batch_size=None, exclude_parameter=False, parameter_only=False, extension=".nntxt", context=None):
    '''load
    Load network information from files.

    Args:
        filenames (list): A file-like object or a list of filenames.
        extension (str): If `filenames` is a file-like object, one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
    Returns:
        dict: Network information.
    '''
    class Info:
        pass
    info = Info()

    info.prepare_data_iterator = prepare_data_iterator
    info.batch_size = batch_size
    info.exclude_parameter = exclude_parameter
    info.parameter_only = parameter_only
    info.proto = nnabla_pb2.NNablaProtoBuf()

    # first stage file loaders
    file_loaders = get_initial_file_loader()

    # Use the global parameter scope to keep consistency with the legacy
    # implementation and avoid surprising existing developers, though a
    # stand-alone OrderedDict() instance would be better.
    info.parameter_scope = nn.parameter.get_current_parameter_scope()
    load_files(info, file_loaders, filenames, extension)

    default_context = None
    if context:
        if context == 'cpu':
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
        else:
            cs = context.split(':')
            if cs[0] == 'cudnn':
                if len(cs) == 1:
                    devid = 0
                else:
                    devid = int(cs[1])
                import nnabla_ext.cudnn
                default_context = nnabla_ext.cudnn.context(device_id=devid)
        if default_context is None:
            logger.warn('Invalid context [{}]'.format(context))
        elif info.proto.HasField('global_config'):
            info.global_config = _global_config(info.proto)
            info.global_config.default_context = default_context

    if default_context is None:
        if info.proto.HasField('global_config'):
            info.global_config = _global_config(info.proto)
            default_context = info.global_config.default_context
            if 'cuda' in default_context.backend:
                import nnabla_ext.cudnn
            elif 'cuda:float' in default_context.backend:
                try:
                    import nnabla_ext.cudnn
                except:
                    pass
        else:
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
            info.global_config = _global_config(
                None, default_context=default_context)

    default_context = _check_context(default_context)
    logger.log(99, 'Using context "{}"'.format(default_context))
    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.local_rank)
    if info.proto.HasField('training_config'):
        info.training_config = _training_config(info.proto)

    info.default_context = default_context
    info.datasets = _datasets(
        info.proto, prepare_data_iterator if prepare_data_iterator is not None else info.training_config.max_epoch > 0)

    info.renamed_variables = {}
    info.networks = _networks(info, nn.graph_def.ProtoGraph.from_proto(info.proto, param_scope=info.parameter_scope,
                                                                       rng=numpy.random.RandomState(0)))

    info.optimizers = _optimizers(info)
    info.monitors = _monitors(info)
    info.executors = _executors(info)

    return info
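
Compared with Example #2, this variant also accepts an explicit context string; a hypothetical call forcing the cuDNN backend on device 0:

# Usage sketch; 'model.nnp' is a hypothetical file name.
info = load(['model.nnp'], context='cudnn:0')  # parsed as backend 'cudnn', device 0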
Example #6
    def train(self):
        real = nn.Variable(shape=(self.bs, 3) + self.image_shape)
        inst_label = nn.Variable(shape=(self.bs,) + self.image_shape)
        id_label = nn.Variable(shape=(self.bs,) + self.image_shape)

        id_onehot, bm = encode_inputs(
            inst_label, id_label, n_ids=self.model_conf.n_label_ids, use_encoder=self.use_encoder)

        x = F.concatenate(id_onehot, bm, axis=1)

        # generator
        # Note that only global generator would be used in the case of g_scales = 1.
        generator = LocalGenerator()
        fake, _ = generator(x,
                            lg_channels=self.model_conf.lg_channels,
                            gg_channels=self.model_conf.gg_channels,
                            n_scales=self.model_conf.g_n_scales,
                            lg_n_residual_layers=self.model_conf.lg_num_residual_loop,
                            gg_n_residual_layers=self.model_conf.gg_num_residual_loop)
        unlinked_fake = fake.get_unlinked_variable(need_grad=True)

        # discriminator
        discriminator = PatchGAN(
            n_scales=self.model_conf.d_n_scales, use_spectral_normalization=False)
        d_real_out, d_real_feats = discriminator(
            F.concatenate(real, x, axis=1))
        d_fake_out, d_fake_feats = discriminator(
            F.concatenate(unlinked_fake, x, axis=1))
        g_gan, g_feat, d_real, d_fake = discriminator.get_loss(d_real_out, d_real_feats,
                                                               d_fake_out, d_fake_feats,
                                                               use_fm=True,
                                                               fm_lambda=self.train_conf.lambda_feat,
                                                               gan_loss_type="ls")

        g_vgg = vgg16_perceptual_loss(
            real, unlinked_fake) * self.train_conf.lambda_perceptual

        set_persistent_all(bm, fake, fake, g_gan,
                           g_feat, g_vgg, d_real, d_fake)

        g_loss = g_gan + g_feat + g_vgg
        d_loss = 0.5 * (d_real + d_fake)

        # load parameters
        if self.load_path:
            if not os.path.exists(self.load_path):
                logger.warn("Path to load params is not found."
                            " Loading params is skipped. ({})".format(self.load_path))
            else:
                nn.load_parameters(self.load_path)

        # Setup Solvers
        g_solver = S.Adam(beta1=0.5)
        g_solver.set_parameters(get_params_startswith("generator/local"))

        d_solver = S.Adam(beta1=0.5)
        d_solver.set_parameters(get_params_startswith("discriminator"))

        # lr scheduler
        lr_scheduler = LinearDecayScheduler(self.train_conf.base_lr, 0.,
                                            start_iter=self.train_conf.lr_decay_starts,
                                            end_iter=self.train_conf.max_epochs)

        # Setup Reporter
        losses = {"g_gan": g_gan, "g_feat": g_feat,
                  "g_vgg": g_vgg, "d_real": d_real, "d_fake": d_fake}
        reporter = Reporter(self.comm, losses, self.train_conf.save_path)

        # for label2color
        label2color = Colorize(self.model_conf.n_label_ids)

        for epoch in range(self.train_conf.max_epochs):
            if epoch == self.fix_global_epoch:
                g_solver.set_parameters(get_params_startswith(
                    "generator"), reset=False, retain_state=True)

            # update learning rate for current epoch
            lr = lr_scheduler(epoch)
            g_solver.set_learning_rate(lr)
            d_solver.set_learning_rate(lr)

            progress_iterator = trange(self.data_iter._size // self.bs,
                                       desc="[epoch {}]".format(epoch), disable=self.comm.rank > 0)

            reporter.start(progress_iterator)

            for i in progress_iterator:
                image, instance_id, object_id = self.data_iter.next()

                real.d = image
                inst_label.d = instance_id
                id_label.d = object_id

                # create fake
                fake.forward()

                # update discriminator
                d_solver.zero_grad()
                d_loss.forward()
                d_loss.backward(clear_buffer=True)

                if self.comm.n_procs > 1:
                    params = [
                        x.grad for x in d_solver.get_parameters().values()]
                    self.comm.all_reduce(params, division=False, inplace=False)
                d_solver.update()

                # update generator
                unlinked_fake.grad.zero()
                g_solver.zero_grad()
                g_loss.forward()
                g_loss.backward(clear_buffer=True)

                # backward generator
                fake.backward(grad=None, clear_buffer=True)

                if self.comm.n_procs > 1:
                    params = [
                        x.grad for x in g_solver.get_parameters().values()]
                    self.comm.all_reduce(params, division=False, inplace=False)
                g_solver.update()

                # report iteration progress
                reporter()

            # report epoch progress
            show_images = {"InputImage": label2color(id_label.data.get_data("r")).astype(np.uint8),
                           # "InputBoundary": bm.data.get_data("r").transpose((0, 2, 3, 1)),
                           "GeneratedImage": fake.data.get_data("r").transpose((0, 2, 3, 1)),
                           "RealImagse": real.data.get_data("r").transpose((0, 2, 3, 1))}
            reporter.step(epoch, show_images)

            if (epoch % 10) == 0 and self.comm.rank == 0:
                nn.save_parameters(os.path.join(
                    self.train_conf.save_path, 'param_{:03d}.h5'.format(epoch)))

        if self.comm.rank == 0:
            nn.save_parameters(os.path.join(
                self.train_conf.save_path, 'param_final.h5'))
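
The `fake.get_unlinked_variable(need_grad=True)` trick above is what allows the discriminator and generator updates to be separated: the unlinked variable shares its data and grad buffers with `fake` but cuts the graph, so `g_loss.backward()` stops there, and the later `fake.backward(grad=None)` pushes the accumulated gradient on into the generator. A minimal sketch of the same two-stage backward, assuming only core nnabla APIs:

import numpy as np
import nnabla as nn
import nnabla.functions as F

x = nn.Variable((1, 3), need_grad=True)
h = F.tanh(x)                                   # stands in for the generator output
h_unlinked = h.get_unlinked_variable(need_grad=True)
loss = F.sum(h_unlinked ** 2)                   # stands in for the downstream loss

x.d = np.full((1, 3), 0.5)
h.forward()
x.grad.zero()
h_unlinked.grad.zero()
loss.forward()
loss.backward()        # fills h_unlinked.grad (shared with h.grad); x.grad untouched
h.backward(grad=None)  # grad=None reuses h.grad instead of resetting it to ones
print(x.g)             # the gradient has now reached the generator input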
Example #7
def generate():
    rng = np.random.RandomState(803)

    conf = get_config()

    # set context
    comm = init_nnabla(conf)

    # find all test data
    if conf.dataset == "cityscapes":
        data_list = get_cityscape_datalist(conf.cityscapes,
                                           data_type="val",
                                           save_file=comm.rank == 0)
        conf.n_class = conf.cityscapes.n_label_ids
        use_inst = True

        data_iter = create_cityscapes_iterator(conf.batch_size,
                                               data_list,
                                               comm=comm,
                                               image_shape=conf.image_shape,
                                               rng=rng,
                                               flip=False)
    elif conf.dataset == "ade20k":
        data_list = get_ade20k_datalist(conf.ade20k,
                                        data_type="val",
                                        save_file=comm.rank == 0)
        conf.n_class = conf.ade20k.n_label_ids + 1  # class id + unknown
        use_inst = False

        load_shape = tuple(
            x + 30
            for x in conf.image_shape) if conf.use_crop else conf.image_shape
        data_iter = create_ade20k_iterator(conf.batch_size,
                                           data_list,
                                           comm=comm,
                                           load_shape=load_shape,
                                           crop_shape=conf.image_shape,
                                           rng=rng,
                                           flip=False)
    else:
        raise NotImplementedError(
            "Currently dataset {} is not supported.".format(conf.dataset))

    # define generator
    generator = Generator(conf, use_inst)

    # load parameters
    if not os.path.exists(conf.load_params):
        logger.warn(
            "Path to load params is not found."
            " Loading params is skipped and generated result will be unreasonable. ({})"
            .format(conf.load_params))

    else:
        print("load parameters from {}".format(conf.load_params))
        nn.load_parameters(conf.load_params)

    niter = get_iteration_per_epoch(data_iter._size,
                                    conf.batch_size,
                                    round="ceil")

    progress_iterator = trange(niter,
                               desc="[Generating Images]",
                               disable=comm.rank > 0)

    # for label2color
    label2color = Colorize(conf.n_class)

    save_path = os.path.join(conf.save_path, "generated")
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)

    logger.info("Generated images will be saved on '{}'.".format(save_path))

    cnt = 0
    for i in progress_iterator:
        if conf.dataset == "cityscapes":
            _, instance_id, object_id = data_iter.next()
        elif conf.dataset == "ade20k":
            _, object_id = data_iter.next()
            instance_id = None
        else:
            raise NotImplementedError()

        gen = generator(instance_id, object_id)
        id_colorized = label2color(object_id).astype(np.uint8)

        valid = conf.batch_size
        if cnt > data_iter._size:
            valid = data_iter._size - conf.batch_size * (i - 1)

        for j in range(valid):
            gen_image_path = os.path.join(
                save_path, "res_{}_{}.png".format(comm.rank, cnt + j))
            input_image_path = os.path.join(
                save_path, "input_{}_{}.png".format(comm.rank, cnt + j))

            imsave(gen_image_path, gen[j], channel_first=True)
            imsave(input_image_path, id_colorized[j])

        cnt += conf.batch_size
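
The round="ceil" argument to get_iteration_per_epoch presumably rounds the iteration count up so the final partial batch is also visited; the equivalent arithmetic, as a quick check with hypothetical numbers:

import math

# Hypothetical numbers: 500 validation images, batch_size 8.
# ceil(500 / 8) = 63 iterations: 62 full batches plus one partial batch.
print(math.ceil(500 / 8))  # 63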
Example #8
def lms_scheduler(ctx, use_lms, gpu_memory_size=None, window_length=None):
    _check_list = [x.split(":")[0] for x in ctx.backend]
    if "cudnn" not in _check_list and "cuda" not in _check_list:
        logger.warn(
            "ctx passed to scheduler doesn't have cuda/cudnn backend. lms scheduler will not be used."
        )
        use_lms = False

    comm = current_communicator()
    if comm:
        logger.log(99,
                   f'[OoC] Currently OoC is disabled for Multi-GPU training.')
        use_lms = False

    if use_lms:
        gpu_index = 0
        if 'cuda' in str(ctx.backend):
            gpu_index = int(ctx.device_id)
        else:
            logger.log(99, f'[OoC] OoC is only enabled for GPU training.')
            raise Exception

        # It would be better to use NVML to get GPU information, but due to a
        # problem on Windows we temporarily get it via `nvidia-smi`.
        if gpu_memory_size is None or gpu_memory_size == 0:
            try:
                import subprocess
                gpu_memory_size = int(
                    int(
                        subprocess.check_output(
                            'nvidia-smi --query-gpu=index,memory.total --format=csv'
                        ).decode().splitlines()[1:][gpu_index].split(',')
                        [1].strip().split()[0]) * (1024**2) * 0.7)
            except:
                logger.log(
                    99,
                    '[OoC] Could not get GPU memory size; using default value (6 GB).'
                )
                gpu_memory_size = 6e9  # default 6 GB

        if window_length is None or window_length == 0:
            window_length = int(gpu_memory_size * 1.5)

        logger.log(
            99,
            f'[OoC] gpu_memory_limit: {gpu_memory_size / 1e9}GB, prefetch_window_length: {window_length / 1e9}GB'
        )
        # Change array preference so that lms works well.
        # import nnabla_ext.cuda.init as cuda_init
        # cuda_init.prefer_cpu_pinned_array()
        # cuda_init.prefer_cuda_virtual_array()
        from nnabla.ext_utils import get_extension_context
        be, tc = ctx.backend[0].split(":")
        cpu_ctx = get_extension_context("cpu", device_id="", type_config=tc)
        return SwapInOutScheduler(cpu_ctx, ctx, gpu_memory_size, window_length)
    else:

        class DummyScheduler(object):
            function_pre_hook = None
            function_post_hook = None
            update_pre_hook = None
            update_post_hook = None

            def start_scheduling(self):
                return None

            def end_scheduling(self):
                return None

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                pass

        return DummyScheduler()
Example #9
def generate():
    config = get_config()

    # batch_size is forced to be 1
    config.train.batch_size = 1

    image_shape = (config.train.batch_size, 3) + \
        tuple(x * config.model.g_n_scales for x in [512, 512])

    # set context
    comm = init_nnabla(config.nnabla_context)

    img_path_list = [
        os.path.join(config.test_input_dir, path)
        for path in os.listdir(config.test_input_dir)
    ]

    test_image = nn.Variable(shape=image_shape)
    # define generator
    generator = LocalGenerator()
    generated_image, _ = generator(
        test_image,
        lg_channels=config.model.lg_channels,
        gg_channels=config.model.gg_channels,
        n_scales=config.model.g_n_scales,
        lg_n_residual_layers=config.model.lg_num_residual_loop,
        gg_n_residual_layers=config.model.gg_num_residual_loop)

    # load parameters
    if not os.path.exists(config.load_path):
        logger.warn(
            "Path to load params is not found."
            " Loading params is skipped and generated result will be unreasonable. ({})"
            .format(config.load_path))

    nn.load_parameters(config.load_path)

    progress_iterator = trange(len(img_path_list) // config.train.batch_size,
                               desc="[Generating Images]",
                               disable=comm.rank > 0)

    save_path = os.path.join(config.test_output_dir)
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)

    for i in progress_iterator:
        path = img_path_list[i]
        test_image_data = get_var(path, image_shape[2:])
        test_image.d = test_image_data

        generated_image.forward(clear_buffer=True)

        generated_image_data = (generated_image.d - generated_image.d.min()) / \
            (generated_image.d.max() - generated_image.d.min())
        test_image_data = test_image_data * 0.5 + 0.5

        gen_image_path = os.path.join(save_path,
                                      "res{}_{}.png".format(comm.rank, i))
        input_image_path = os.path.join(save_path,
                                        "input_{}_{}.png".format(comm.rank, i))

        imsave(gen_image_path, generated_image_data[0], channel_first=True)
        imsave(input_image_path, test_image_data[0], channel_first=True)
Example #10
def load(filenames,
         prepare_data_iterator=True,
         batch_size=None,
         exclude_parameter=False,
         parameter_only=False):
    '''load
    Load network information from files.

    Args:
        filenames (list): List of filenames.
    Returns:
        dict: Network information.
    '''
    class Info:
        pass

    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()
    for filename in filenames:
        _, ext = os.path.splitext(filename)

        # TODO: There are some known problems here.
        #   - Even when a protobuf file includes the network structure,
        #     it will not be loaded.
        #   - Even when a prototxt file includes parameters,
        #     they will not be loaded.

        if ext in ['.nntxt', '.prototxt']:
            if not parameter_only:
                with open(filename, 'rt') as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            'Double-byte characters may have been used in the file or folder name.'
                        )
                        raise
            if len(proto.parameter) > 0:
                if not exclude_parameter:
                    nn.load_parameters(filename)
        elif ext in ['.protobuf', '.h5']:
            if not exclude_parameter:
                nn.load_parameters(filename)
            else:
                logger.info('Skip loading parameter.')

        elif ext == '.nnp':
            tmpdir = tempfile.mkdtemp()
            try:
                with zipfile.ZipFile(filename, 'r') as nnp:
                    for name in nnp.namelist():
                        _, ext = os.path.splitext(name)
                        if name == 'nnp_version.txt':
                            nnp.extract(name, tmpdir)
                            with open(os.path.join(tmpdir, name), 'rt') as f:
                                pass  # TODO currently do nothing with version.
                        elif ext in ['.nntxt', '.prototxt']:
                            nnp.extract(name, tmpdir)
                            if not parameter_only:
                                with open(os.path.join(tmpdir, name),
                                          'rt') as f:
                                    text_format.Merge(f.read(), proto)
                            if len(proto.parameter) > 0:
                                if not exclude_parameter:
                                    nn.load_parameters(
                                        os.path.join(tmpdir, name))
                        elif ext in ['.protobuf', '.h5']:
                            nnp.extract(name, tmpdir)
                            if not exclude_parameter:
                                nn.load_parameters(os.path.join(tmpdir, name))
                            else:
                                logger.info('Skip loading parameter.')
            finally:
                shutil.rmtree(tmpdir)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            try:
                import nnabla_ext.cudnn
            except:
                pass
        try:
            x = nn.Variable()
            y = nn.Variable()
            func = F.ReLU(default_context, inplace=True)
            func.setup([x], [y])
            func.forward([x], [y])
        except:
            logger.warn('Fallback to CPU context.')
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()

    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.rank)
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    info.datasets = _datasets(
        proto, prepare_data_iterator if prepare_data_iterator is not None else
        info.training_config.max_epoch > 0)

    info.networks = _networks(proto, default_context, batch_size)

    info.optimizers = _optimizers(proto, default_context, info.networks,
                                  info.datasets)

    info.monitors = _monitors(proto, default_context, info.networks,
                              info.datasets)

    info.executors = _executors(proto, info.networks)

    return info