def _context(proto):
    """Build an nnabla context from a ``context`` proto message.

    Two encodings are understood:

    * New style: ``proto.backends`` is non-empty. An ``nn.Context`` is filled
      directly from ``backends``, ``array_class`` and ``device_id``.
    * Old style: ``proto.backends`` is empty. The ``|``-separated
      ``proto.backend`` / ``proto.compute_backend`` strings are translated
      into a context created by the matching ``nnabla_ext`` module. If that
      module cannot be imported, the CPU context is used instead.

    If a communicator is active, its ``local_rank`` replaces the device id
    stored in the proto. This applies to both GPU paths.

    Raises:
        ValueError: for an old-style context whose backend or compute
            backend is not recognised.
    """
    comm = current_communicator()

    if proto.backends:
        # New-style context: copy the fields over unchanged.
        new_ctx = nn.Context()
        new_ctx.backend = proto.backends
        new_ctx.array_class = str(proto.array_class)
        new_ctx.device_id = (str(comm.local_rank) if comm
                             else str(proto.device_id))
        return new_ctx

    logger.warn('Old-style context. Updating to new format.')
    # Update from old Context
    backend_names = [part.strip() for part in proto.backend.split('|')]
    compute_names = [part.strip()
                     for part in proto.compute_backend.split('|')]

    def _gpu_or_cpu(make_gpu_context):
        # Try to build the GPU context; fall back to CPU if the extension is missing.
        try:
            return make_gpu_context()
        except ImportError:
            logger.warn('Fallback to CPU context.')
            import nnabla_ext.cpu
            return nnabla_ext.cpu.context()

    if 'cuda' not in backend_names:
        if 'cpu' not in backend_names:
            raise ValueError('Invalid context {}'.format(proto))
        import nnabla_ext.cpu
        ctx = nnabla_ext.cpu.context()
    else:
        dev = str(comm.local_rank) if comm else str(proto.device_id)

        def _cudnn_context():
            import nnabla_ext.cudnn
            return nnabla_ext.cudnn.context(device_id=dev)

        def _cuda_context():
            import nnabla_ext.cuda
            return nnabla_ext.cuda.context(device_id=dev)

        if 'cudnn' in compute_names:
            ctx = _gpu_or_cpu(_cudnn_context)
        elif 'default' in compute_names:
            ctx = _gpu_or_cpu(_cuda_context)
        else:
            raise ValueError(
                'Invalid compute_backend {}'.format(proto.compute_backend))

    ctx.array_class = str(proto.array_class)
    return ctx
def load(filenames, prepare_data_iterator=True, batch_size=None, exclude_parameter=False, parameter_only=False, extension=".nntxt"):
    '''load
    Load network information from files.

    Args:
        filenames (list): file-like object or List of filenames.
        extension: if filenames is file-like object, extension is one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
    Returns:
        Info: object whose attributes hold the network information
        (``networks``, ``datasets``, ``optimizers``, ``monitors``,
        ``executors`` and, when present in the files, ``global_config`` and
        ``training_config``).

    Parameters found in the files are loaded via ``nn.load_parameters``
    unless ``exclude_parameter`` is set. Network structure is skipped when
    ``parameter_only`` is set. ``*.optimizer`` entries of an ``.nnp``
    archive are collected and handed to ``_load_optimizer_checkpoint``.
    '''
    class Info:
        pass
    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()

    # optimizer checkpoint
    opti_proto = nnabla_pb2.NNablaProtoBuf()
    OPTI_BUF_EXT = ['.optimizer']
    opti_h5_files = {}
    # Scratch directory for h5 optimizer states extracted from .nnp archives.
    # It must outlive _load_optimizer_checkpoint(). It is removed in the
    # `finally` below, so it no longer leaks when loading raises.
    tmpdir = tempfile.mkdtemp()
    try:
        # A single filename or file-like object is treated as a one-element list.
        if not isinstance(filenames, (list, tuple)) and (
                isinstance(filenames, str) or hasattr(filenames, 'read')):
            filenames = [filenames]

        for filename in filenames:
            if isinstance(filename, str):
                _, ext = os.path.splitext(filename)
            else:
                # File-like objects carry no name; use the caller-supplied extension.
                ext = extension

            # TODO: Here is some known problems.
            #   - Even when protobuf file includes network structure,
            #     it will not loaded.
            #   - Even when prototxt file includes parameter,
            #     it will not loaded.

            if ext in ['.nntxt', '.prototxt']:
                if not parameter_only:
                    with get_file_handle_load(filename, ext) as f:
                        try:
                            text_format.Merge(f.read(), proto)
                        except Exception:
                            logger.critical('Failed to read {}.'.format(filename))
                            logger.critical(
                                '2 byte characters may be used for file name or folder name.')
                            raise
                if len(proto.parameter) > 0:
                    if not exclude_parameter:
                        nn.load_parameters(filename, extension=ext)
            elif ext in ['.protobuf', '.h5']:
                if not exclude_parameter:
                    nn.load_parameters(filename, extension=ext)
                else:
                    logger.info('Skip loading parameter.')

            elif ext == '.nnp':
                with get_file_handle_load(filename, ext) as nnp:
                    for name in nnp.namelist():
                        _, ext = os.path.splitext(name)
                        if name == 'nnp_version.txt':
                            pass  # TODO currently do nothing with version.
                        elif ext in ['.nntxt', '.prototxt']:
                            if not parameter_only:
                                with nnp.open(name, 'r') as f:
                                    text_format.Merge(f.read(), proto)
                            if len(proto.parameter) > 0:
                                if not exclude_parameter:
                                    with nnp.open(name, 'r') as f:
                                        nn.load_parameters(f, extension=ext)
                        elif ext in ['.protobuf', '.h5']:
                            if not exclude_parameter:
                                with nnp.open(name, 'r') as f:
                                    nn.load_parameters(f, extension=ext)
                            else:
                                logger.info('Skip loading parameter.')
                        elif ext in OPTI_BUF_EXT:
                            buf_type = get_buf_type(name)
                            if buf_type == 'protobuf':
                                with nnp.open(name, 'r') as f:
                                    with get_file_handle_load(f, '.protobuf') as opti_p:
                                        opti_proto.MergeFromString(opti_p.read())
                            elif buf_type == 'h5':
                                # h5 states are read later from disk, so extract them.
                                nnp.extract(name, tmpdir)
                                opti_h5_files[name] = os.path.join(tmpdir, name)

        default_context = None
        if proto.HasField('global_config'):
            info.global_config = _global_config(proto)
            default_context = info.global_config.default_context
            if 'cuda' in default_context.backend:
                import nnabla_ext.cudnn
            elif 'cuda:float' in default_context.backend:
                try:
                    import nnabla_ext.cudnn
                except Exception:
                    # Best effort: the probe below falls back to CPU if unusable.
                    pass
            # Probe the context with a tiny ReLU; fall back to CPU if it cannot run.
            try:
                x = nn.Variable()
                y = nn.Variable()
                func = F.ReLU(default_context, inplace=True)
                func.setup([x], [y])
                func.forward([x], [y])
            except Exception:
                logger.warn('Fallback to CPU context.')
                import nnabla_ext.cpu
                default_context = nnabla_ext.cpu.context()
        else:
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()

        comm = current_communicator()
        if comm:
            default_context.device_id = str(comm.local_rank)

        if proto.HasField('training_config'):
            info.training_config = _training_config(proto)

        info.datasets = _datasets(
            proto, prepare_data_iterator if prepare_data_iterator is not None else info.training_config.max_epoch > 0)

        info.networks = _networks(proto, default_context, batch_size)

        info.optimizers = _optimizers(
            proto, default_context, info.networks, info.datasets)
        _load_optimizer_checkpoint(opti_proto, opti_h5_files, info)

        info.monitors = _monitors(
            proto, default_context, info.networks, info.datasets)

        info.executors = _executors(proto, info.networks)
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)

    return info
def generate():
    """Generate images for the Cityscapes validation split with a trained generator.

    The batch size is forced to 1. Each process handles its own share of the
    data list when more than one process is running. For every sample it
    writes ``res{rank}_{i}.png`` (the generator output, min-max normalised)
    and ``input_{rank}_{i}.png`` (the colorised label map) under
    ``<conf.train.save_path>/generated``. Rank 0 also writes
    ``in_out_pairs.txt``, which lists the source paths and the two output
    paths per line.

    Raises:
        NotImplementedError: if ``conf.train.data_set`` is not "cityscapes".
    """
    conf = get_config()

    # batch_size is forced to be 1
    conf.train.batch_size = 1
    image_shape = (conf.train.batch_size,) + \
        tuple(x * conf.model.g_n_scales for x in [512, 1024])

    # set context
    comm = init_nnabla(conf.nnabla_context)

    # find all test data
    if conf.train.data_set == "cityscapes":
        data_list = get_cityscape_datalist(conf.cityscapes,
                                           data_type="val",
                                           save_file=comm.rank == 0)
        conf.model.n_label_ids = conf.cityscapes.n_label_ids
    else:
        # Report the field that was actually checked (was `conf.dataset`).
        raise NotImplementedError(
            "Currently dataset {} is not supported.".format(conf.train.data_set))

    if comm.n_procs > 1:
        data_list = get_data_lists_for_each_process(data_list, comm.n_procs)[comm.rank]

    # define generator
    generator = Generator(image_shape=image_shape, mconf=conf.model)

    # load parameters
    if not os.path.exists(conf.load_path):
        logger.warn(
            "Path to load params is not found."
            " Loading params is skipped and generated result will be unreasonable. ({})"
            .format(conf.load_path))
    else:
        # Only load when the file exists; loading unconditionally crashed
        # right after the "skipped" warning above.
        nn.load_parameters(conf.load_path)

    progress_iterator = trange(len(data_list) // conf.train.batch_size,
                               desc="[Generating Images]",
                               disable=comm.rank > 0)

    # for label2color
    label2color = Colorize(conf.model.n_label_ids)

    save_path = os.path.join(conf.train.save_path, "generated")
    os.makedirs(save_path, exist_ok=True)

    output_str = []
    for i in progress_iterator:
        paths = data_list[i]
        _, instance_id, object_id = cityscapes_load_function(
            paths[0], paths[1], paths[2], image_shape[1:])
        gen = generator(instance_id, object_id)
        # Min-max normalise the generator output before saving.
        gen = (gen - gen.min()) / (gen.max() - gen.min())
        id_colorized = label2color(object_id).astype(np.uint8)
        gen_image_path = os.path.join(save_path,
                                      "res{}_{}.png".format(comm.rank, i))
        input_image_path = os.path.join(save_path,
                                        "input_{}_{}.png".format(comm.rank, i))
        imsave(gen_image_path, gen[0], channel_first=True)
        imsave(input_image_path, id_colorized)
        # list() so this works whether `paths` is a list or a tuple.
        output_str.append(" ".join(list(paths) + [gen_image_path, input_image_path]))

    # NOTE(review): only rank 0 writes this file, so with several processes it
    # lists rank 0's samples only.
    if comm.rank == 0:
        with open(os.path.join(save_path, "in_out_pairs.txt"), "w") as f:
            f.write("\n".join(output_str))
def lms_scheduler(ctx, use_lms, gpu_memory_size=None, window_length=None):
    """Return an out-of-core (OoC / LMS) swap scheduler for ``ctx``, or a no-op one.

    OoC is turned off (and a ``DummyScheduler`` with ``None`` hooks and no-op
    ``start_scheduling``/``end_scheduling``/context-manager methods is
    returned) in three cases: ``use_lms`` is false, ``ctx.backend`` has no
    ``cuda``/``cudnn`` entry, or a communicator is active (multi-GPU).

    Otherwise a ``SwapInOutScheduler(cpu_ctx, ctx, gpu_memory_size,
    window_length)`` is returned. ``cpu_ctx`` is a CPU context with the same
    type config as ``ctx.backend[0]``.

    Args:
        ctx: nnabla context used for computation.
        use_lms (bool): request out-of-core execution.
        gpu_memory_size (int, optional): GPU memory budget in bytes. If
            missing or 0, it is 70% of the device's total memory as reported
            by nvml, or 6e9 bytes if that query fails.
        window_length (int, optional): prefetch window in bytes. If missing
            or 0, it is 1.5 x ``gpu_memory_size``.
    """
    _check_list = [x.split(":")[0] for x in ctx.backend]
    if "cudnn" not in _check_list and "cuda" not in _check_list:
        logger.warn(
            "ctx passed to scheduler doesn't have cuda/cudnn backend. lms scheduler will not be used."
        )
        use_lms = False

    comm = current_communicator()
    if comm:
        logger.log(99, f'[OoC] Currently OoC is disabled for Multi-GPU training.')
        use_lms = False

    if use_lms:
        # use_lms can only still be True when the backend list has a cuda or
        # cudnn entry (checked above). The former substring test
        # `'cuda' in str(ctx.backend)` wrongly rejected cudnn-only backends
        # with a bare Exception.
        gpu_index = int(ctx.device_id)

        if not gpu_memory_size:
            try:
                handle = nvml.nvmlDeviceGetHandleByIndex(gpu_index)
                total_memory = nvml.nvmlDeviceGetMemoryInfo(handle).total
                gpu_memory_size = int(total_memory * 0.7)
            except Exception:
                logger.log(
                    99, f'[OoC] Could not get GPU memory size using default value(6GB).'
                )
                # Default 6 GB (6 * 10**9 bytes), kept as an int byte count.
                gpu_memory_size = int(6e9)

        if not window_length:
            window_length = int(gpu_memory_size * 1.5)

        logger.log(
            99, f'[OoC] gpu_memory_limit: {gpu_memory_size / 1e9}GB, prefetch_window_length: {window_length / 1e9}GB'
        )

        # Change array preference so that lms works well.
        # import nnabla_ext.cuda.init as cuda_init
        # cuda_init.prefer_cpu_pinned_array()
        # cuda_init.prefer_cuda_virtual_array()

        from nnabla.ext_utils import get_extension_context
        _, tc = ctx.backend[0].split(":")
        cpu_ctx = get_extension_context("cpu", device_id="", type_config=tc)
        return SwapInOutScheduler(cpu_ctx, ctx, gpu_memory_size, window_length)
    else:
        class DummyScheduler(object):
            function_pre_hook = None
            function_post_hook = None
            update_pre_hook = None
            update_post_hook = None

            def start_scheduling(self):
                return None

            def end_scheduling(self):
                return None

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                pass

        return DummyScheduler()
def load(filenames, prepare_data_iterator=True, batch_size=None, exclude_parameter=False, parameter_only=False, extension=".nntxt", context=None):
    '''load
    Load network information from files.

    Args:
        filenames (list): file-like object or List of filenames.
        extension: if filenames is file-like object, extension is one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
        context (str, optional): overrides the context stored in the files.
            Accepted values are ``'cpu'``, ``'cudnn'`` and
            ``'cudnn:<device id>'``. Any other value is reported as invalid,
            and the context from the files (or CPU) is used instead.
    Returns:
        Info: object whose attributes hold the network information
        (``global_config``, ``default_context``, ``networks``, ``datasets``,
        ``optimizers``, ``monitors``, ``executors`` and, when present,
        ``training_config``).
    '''
    class Info:
        pass
    info = Info()
    info.prepare_data_iterator = prepare_data_iterator
    info.batch_size = batch_size
    info.exclude_parameter = exclude_parameter
    info.parameter_only = parameter_only
    info.proto = nnabla_pb2.NNablaProtoBuf()

    # first stage file loaders
    file_loaders = get_initial_file_loader()

    # using global parameter scope, keep consistency with legacy implementation.
    # To avoid to surprise previous developers, but it is better using
    # stand-alone OrderedDict() instance.
    info.parameter_scope = nn.parameter.get_current_parameter_scope()
    load_files(info, file_loaders, filenames, extension)

    default_context = None
    if context:
        if context == 'cpu':
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
        else:
            cs = context.split(':')
            if cs[0] == 'cudnn':
                # Device ids are strings everywhere else in this loader;
                # int() still validates the user-supplied id.
                devid = str(int(cs[1])) if len(cs) > 1 else '0'
                import nnabla_ext.cudnn
                default_context = nnabla_ext.cudnn.context(device_id=devid)
        if default_context is None:
            logger.warn('Invalid context [{}]'.format(context))
        elif info.proto.HasField('global_config'):
            # Was `_global_config(proto)`: `proto` is undefined here (NameError).
            info.global_config = _global_config(info.proto)
            info.global_config.default_context = default_context
        else:
            # Without this, info.global_config was never set for an explicit
            # context on files lacking a global_config.
            info.global_config = _global_config(
                None, default_context=default_context)

    if default_context is None:
        if info.proto.HasField('global_config'):
            info.global_config = _global_config(info.proto)
            default_context = info.global_config.default_context
            if 'cuda' in default_context.backend:
                import nnabla_ext.cudnn
            elif 'cuda:float' in default_context.backend:
                try:
                    import nnabla_ext.cudnn
                except Exception:
                    # Best effort; _check_context() below validates the context.
                    pass
        else:
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
            info.global_config = _global_config(
                None, default_context=default_context)

    default_context = _check_context(default_context)
    logger.log(99, 'Using context "{}"'.format(default_context))
    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.local_rank)
    if info.proto.HasField('training_config'):
        info.training_config = _training_config(info.proto)

    info.default_context = default_context
    info.datasets = _datasets(
        info.proto, prepare_data_iterator if prepare_data_iterator is not None else info.training_config.max_epoch > 0)

    info.renamed_variables = {}
    info.networks = _networks(
        info, nn.graph_def.ProtoGraph.from_proto(info.proto,
                                                 param_scope=info.parameter_scope,
                                                 rng=numpy.random.RandomState(0)))

    info.optimizers = _optimizers(info)
    info.monitors = _monitors(info)
    info.executors = _executors(info)

    return info
def train(self):
    """Run the adversarial training loop for the generator and the PatchGAN discriminator.

    Builds the graph from the generator output and three loss terms:
    GAN + feature matching (``discriminator.get_loss`` with ``use_fm=True``
    and ``gan_loss_type="ls"``) and a VGG16 perceptual loss. If
    ``self.load_path`` is set and the file exists, parameters are loaded from
    it. Training then runs for ``self.train_conf.max_epochs`` epochs, each
    iteration alternating a discriminator update and a generator update.
    Rank 0 saves ``param_XXX.h5`` every 10 epochs and ``param_final.h5`` at
    the end, under ``self.train_conf.save_path``.
    """
    # Graph inputs. Shapes come from self.bs and self.image_shape.
    real = nn.Variable(shape=(self.bs, 3) + self.image_shape)
    inst_label = nn.Variable(shape=(self.bs,) + self.image_shape)
    id_label = nn.Variable(shape=(self.bs,) + self.image_shape)
    id_onehot, bm = encode_inputs(
        inst_label, id_label, n_ids=self.model_conf.n_label_ids, use_encoder=self.use_encoder)
    # Generator/discriminator condition: one-hot ids and `bm`, concatenated on axis 1.
    x = F.concatenate(id_onehot, bm, axis=1)

    # generator
    # Note that only global generator would be used in the case of g_scales = 1.
    generator = LocalGenerator()
    fake, _, = generator(x,
                         lg_channels=self.model_conf.lg_channels,
                         gg_channels=self.model_conf.gg_channels,
                         n_scales=self.model_conf.g_n_scales,
                         lg_n_residual_layers=self.model_conf.lg_num_residual_loop,
                         gg_n_residual_layers=self.model_conf.gg_num_residual_loop)
    # Parent-less copy of `fake`. The losses below are built on it, so their
    # backward passes stop at `unlinked_fake`. The generator itself is
    # back-propagated explicitly with `fake.backward()` in the loop
    # (presumably through a shared grad buffer — see nnabla
    # Variable.get_unlinked_variable).
    unlinked_fake = fake.get_unlinked_variable(need_grad=True)

    # discriminator
    discriminator = PatchGAN(
        n_scales=self.model_conf.d_n_scales, use_spectral_normalization=False)
    d_real_out, d_real_feats = discriminator(
        F.concatenate(real, x, axis=1))
    d_fake_out, d_fake_feats = discriminator(
        F.concatenate(unlinked_fake, x, axis=1))
    g_gan, g_feat, d_real, d_fake = discriminator.get_loss(d_real_out, d_real_feats,
                                                           d_fake_out, d_fake_feats,
                                                           use_fm=True,
                                                           fm_lambda=self.train_conf.lambda_feat,
                                                           gan_loss_type="ls")
    g_vgg = vgg16_perceptual_loss(
        real, unlinked_fake) * self.train_conf.lambda_perceptual

    # NOTE(review): `fake` is listed twice; another variable may have been intended.
    set_persistent_all(bm, fake, fake, g_gan, g_feat,
                       g_vgg, d_real, d_fake)

    g_loss = g_gan + g_feat + g_vgg
    d_loss = 0.5 * (d_real + d_fake)

    # load parameters
    if self.load_path:
        if not os.path.exists(self.load_path):
            logger.warn("Path to load params is not found."
                        " Loading params is skipped. ({})".format(self.load_path))
        else:
            nn.load_parameters(self.load_path)

    # Setup Solvers
    # At first only the local-generator parameters are optimised; the full
    # generator is added at epoch self.fix_global_epoch (see the loop).
    g_solver = S.Adam(beta1=0.5)
    g_solver.set_parameters(get_params_startswith("generator/local"))

    d_solver = S.Adam(beta1=0.5)
    d_solver.set_parameters(get_params_startswith("discriminator"))

    # lr scheduler
    # Linear decay from base_lr to 0 between lr_decay_starts and max_epochs (indexed by epoch).
    lr_schduler = LinearDecayScheduler(self.train_conf.base_lr, 0.,
                                       start_iter=self.train_conf.lr_decay_starts,
                                       end_iter=self.train_conf.max_epochs)

    # Setup Reporter
    losses = {"g_gan": g_gan, "g_feat": g_feat, "g_vgg": g_vgg,
              "d_real": d_real, "d_fake": d_fake}
    reporter = Reporter(self.comm, losses, self.train_conf.save_path)

    # for label2color
    label2color = Colorize(self.model_conf.n_label_ids)

    for epoch in range(self.train_conf.max_epochs):
        if epoch == self.fix_global_epoch:
            # Start optimising every generator parameter while keeping the
            # solver state already accumulated for the local generator.
            g_solver.set_parameters(get_params_startswith(
                "generator"), reset=False, retain_state=True)

        # update learning rate for current epoch
        lr = lr_schduler(epoch)
        g_solver.set_learning_rate(lr)
        d_solver.set_learning_rate(lr)

        progress_iterator = trange(self.data_iter._size // self.bs,
                                   desc="[epoch {}]".format(epoch), disable=self.comm.rank > 0)

        reporter.start(progress_iterator)

        for i in progress_iterator:
            image, instance_id, object_id = self.data_iter.next()

            real.d = image
            inst_label.d = instance_id
            id_label.d = object_id

            # create fake
            # Must run first: the losses read the generator output through
            # `unlinked_fake`, which has no parent and is not recomputed by
            # the loss forwards.
            fake.forward()

            # update discriminator
            d_solver.zero_grad()
            d_loss.forward()
            d_loss.backward(clear_buffer=True)

            if self.comm.n_procs > 1:
                # division=False: gradients are combined across processes
                # without averaging (presumably summed — confirm in comm API).
                params = [
                    x.grad for x in d_solver.get_parameters().values()]
                self.comm.all_reduce(params, division=False, inplace=False)
            d_solver.update()

            # update generator
            # d_loss.backward also wrote into unlinked_fake.grad
            # (need_grad=True), so clear it before the generator pass.
            unlinked_fake.grad.zero()
            g_solver.zero_grad()
            g_loss.forward()
            g_loss.backward(clear_buffer=True)

            # backward generator
            # grad=None: back-propagate from the gradient already stored on `fake`.
            fake.backward(grad=None, clear_buffer=True)

            if self.comm.n_procs > 1:
                params = [
                    x.grad for x in g_solver.get_parameters().values()]
                self.comm.all_reduce(params, division=False, inplace=False)
            g_solver.update()

            # report iteration progress
            reporter()

        # report epoch progress
        # Images come from the last batch of the epoch. The generated and real
        # tensors are transposed with (0, 2, 3, 1) — presumably NCHW -> NHWC.
        show_images = {"InputImage": label2color(id_label.data.get_data("r")).astype(np.uint8),
                       # "InputBoundary": bm.data.get_data("r").transpose((0, 2, 3, 1)),
                       "GeneratedImage": fake.data.get_data("r").transpose((0, 2, 3, 1)),
                       "RealImagse": real.data.get_data("r").transpose((0, 2, 3, 1))}
        reporter.step(epoch, show_images)

        if (epoch % 10) == 0 and self.comm.rank == 0:
            nn.save_parameters(os.path.join(
                self.train_conf.save_path, 'param_{:03d}.h5'.format(epoch)))

    if self.comm.rank == 0:
        nn.save_parameters(os.path.join(
            self.train_conf.save_path, 'param_final.h5'))
def generate():
    """Generate images for the validation split of Cityscapes or ADE20K.

    Uses a fixed RNG seed (803) and no flipping. For each of the
    ``data_iter._size`` samples seen by this process, it writes
    ``res_{rank}_{k}.png`` (generator output) and ``input_{rank}_{k}.png``
    (colorised label map) under ``<conf.save_path>/generated``.

    Raises:
        NotImplementedError: if ``conf.dataset`` is neither "cityscapes" nor
            "ade20k".
    """
    rng = np.random.RandomState(803)

    conf = get_config()

    # set context
    comm = init_nnabla(conf)

    # find all test data
    if conf.dataset == "cityscapes":
        data_list = get_cityscape_datalist(conf.cityscapes,
                                           data_type="val",
                                           save_file=comm.rank == 0)
        conf.n_class = conf.cityscapes.n_label_ids
        use_inst = True

        data_iter = create_cityscapes_iterator(conf.batch_size, data_list, comm=comm,
                                               image_shape=conf.image_shape,
                                               rng=rng, flip=False)
    elif conf.dataset == "ade20k":
        data_list = get_ade20k_datalist(conf.ade20k,
                                        data_type="val",
                                        save_file=comm.rank == 0)
        conf.n_class = conf.ade20k.n_label_ids + 1  # class id + unknown
        use_inst = False

        # With cropping enabled, load 30 px larger on each axis than the crop.
        load_shape = tuple(
            x + 30 for x in conf.image_shape) if conf.use_crop else conf.image_shape
        data_iter = create_ade20k_iterator(conf.batch_size, data_list, comm=comm,
                                           load_shape=load_shape, crop_shape=conf.image_shape,
                                           rng=rng, flip=False)
    else:
        raise NotImplementedError(
            "Currently dataset {} is not supported.".format(conf.dataset))

    # define generator
    generator = Generator(conf, use_inst)

    # load parameters
    if not os.path.exists(conf.load_params):
        logger.warn(
            "Path to load params is not found."
            " Loading params is skipped and generated result will be unreasonable. ({})"
            .format(conf.load_params))
    else:
        print("load parameters from {}".format(conf.load_params))
        nn.load_parameters(conf.load_params)

    niter = get_iteration_per_epoch(data_iter._size, conf.batch_size, round="ceil")

    progress_iterator = trange(niter, desc="[Generating Images]", disable=comm.rank > 0)

    # for label2color
    label2color = Colorize(conf.n_class)

    save_path = os.path.join(conf.save_path, "generated")
    os.makedirs(save_path, exist_ok=True)
    logger.info("Generated images will be saved on '{}'.".format(save_path))

    # cnt: number of samples already consumed (= i * batch_size).
    cnt = 0
    for i in progress_iterator:
        if conf.dataset == "cityscapes":
            _, instance_id, object_id = data_iter.next()
        elif conf.dataset == "ade20k":
            _, object_id = data_iter.next()
            instance_id = None
        else:
            # Was `raise NotImplemented()`, which itself raises TypeError.
            raise NotImplementedError(
                "Currently dataset {} is not supported.".format(conf.dataset))

        gen = generator(instance_id, object_id)
        id_colorized = label2color(object_id).astype(np.uint8)

        # niter is rounded up, so the last batch may be padded past _size.
        # Save only the real samples. The old test `cnt > _size` could never
        # be true, and its formula was off by one.
        valid = min(conf.batch_size, data_iter._size - cnt)

        for j in range(valid):
            gen_image_path = os.path.join(
                save_path, "res_{}_{}.png".format(comm.rank, cnt + j))
            input_image_path = os.path.join(
                save_path, "input_{}_{}.png".format(comm.rank, cnt + j))

            imsave(gen_image_path, gen[j], channel_first=True)
            imsave(input_image_path, id_colorized[j])

        cnt += conf.batch_size
def lms_scheduler(ctx, use_lms, gpu_memory_size=None, window_length=None):
    """Return an out-of-core (OoC / LMS) swap scheduler for ``ctx``, or a no-op one.

    OoC is turned off (and a ``DummyScheduler`` with ``None`` hooks and no-op
    ``start_scheduling``/``end_scheduling``/context-manager methods is
    returned) in three cases: ``use_lms`` is false, ``ctx.backend`` has no
    ``cuda``/``cudnn`` entry, or a communicator is active (multi-GPU).

    Otherwise a ``SwapInOutScheduler(cpu_ctx, ctx, gpu_memory_size,
    window_length)`` is returned. ``cpu_ctx`` is a CPU context with the same
    type config as ``ctx.backend[0]``.

    Args:
        ctx: nnabla context used for computation.
        use_lms (bool): request out-of-core execution.
        gpu_memory_size (int, optional): GPU memory budget in bytes. If
            missing or 0, it is 70% of the device's total memory as reported
            by ``nvidia-smi``, or 6e9 bytes if that query fails.
        window_length (int, optional): prefetch window in bytes. If missing
            or 0, it is 1.5 x ``gpu_memory_size``.
    """
    _check_list = [x.split(":")[0] for x in ctx.backend]
    if "cudnn" not in _check_list and "cuda" not in _check_list:
        logger.warn(
            "ctx passed to scheduler doesn't have cuda/cudnn backend. lms scheduler will not be used."
        )
        use_lms = False

    comm = current_communicator()
    if comm:
        logger.log(99, f'[OoC] Currently OoC is disabled for Multi-GPU training.')
        use_lms = False

    if use_lms:
        # use_lms can only still be True when the backend list has a cuda or
        # cudnn entry (checked above). The former substring test
        # `'cuda' in str(ctx.backend)` wrongly rejected cudnn-only backends
        # with a bare Exception.
        gpu_index = int(ctx.device_id)

        # It is better to use nvml to get GPU information, but due to a
        # Windows problem, the information is temporarily read with `nvidia-smi`.
        if not gpu_memory_size:
            try:
                import subprocess
                # Pass an argument list: on POSIX a single string without
                # shell=True is taken as the program name and always fails.
                smi_output = subprocess.check_output(
                    ['nvidia-smi', '--query-gpu=index,memory.total', '--format=csv'])
                # Skip the CSV header. Rows look like "0, 16160 MiB".
                gpu_rows = smi_output.decode().splitlines()[1:]
                total_mib = int(
                    gpu_rows[gpu_index].split(',')[1].strip().split()[0])
                gpu_memory_size = int(total_mib * (1024**2) * 0.7)
            except Exception:
                logger.log(
                    99, f'[OoC] Could not get GPU memory size using default value(6GB).'
                )
                # Default 6 GB (6 * 10**9 bytes), kept as an int byte count.
                gpu_memory_size = int(6e9)

        if not window_length:
            window_length = int(gpu_memory_size * 1.5)

        logger.log(
            99, f'[OoC] gpu_memory_limit: {gpu_memory_size / 1e9}GB, prefetch_window_length: {window_length / 1e9}GB'
        )

        # Change array preference so that lms works well.
        # import nnabla_ext.cuda.init as cuda_init
        # cuda_init.prefer_cpu_pinned_array()
        # cuda_init.prefer_cuda_virtual_array()

        from nnabla.ext_utils import get_extension_context
        _, tc = ctx.backend[0].split(":")
        cpu_ctx = get_extension_context("cpu", device_id="", type_config=tc)
        return SwapInOutScheduler(cpu_ctx, ctx, gpu_memory_size, window_length)
    else:
        class DummyScheduler(object):
            function_pre_hook = None
            function_post_hook = None
            update_pre_hook = None
            update_post_hook = None

            def start_scheduling(self):
                return None

            def end_scheduling(self):
                return None

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                pass

        return DummyScheduler()
def generate():
    """Run the trained generator over every image in ``config.test_input_dir``.

    The batch size is forced to 1, and the input spatial size is
    512 x 512 scaled by ``config.model.g_n_scales``. For the i-th image (in
    sorted filename order) it writes ``res{rank}_{i}.png`` (generator output,
    min-max normalised) and ``input_{rank}_{i}.png`` (input mapped with
    ``x * 0.5 + 0.5``) into ``config.test_output_dir``.
    """
    config = get_config()
    # batch_size is forced to be 1
    config.train.batch_size = 1
    image_shape = (config.train.batch_size, 3) + \
        tuple(x * config.model.g_n_scales for x in [512, 512])

    # set context
    comm = init_nnabla(config.nnabla_context)

    # Sorted so the res/input index -> source image mapping is reproducible
    # (os.listdir order is arbitrary).
    img_path_list = [
        os.path.join(config.test_input_dir, path)
        for path in sorted(os.listdir(config.test_input_dir))
    ]

    test_image = nn.Variable(shape=image_shape)

    # define generator
    generator = LocalGenerator()
    generated_image, _, = generator(
        test_image,
        lg_channels=config.model.lg_channels,
        gg_channels=config.model.gg_channels,
        n_scales=config.model.g_n_scales,
        lg_n_residual_layers=config.model.lg_num_residual_loop,
        gg_n_residual_layers=config.model.gg_num_residual_loop)

    # load parameters
    if not os.path.exists(config.load_path):
        logger.warn(
            "Path to load params is not found."
            " Loading params is skipped and generated result will be unreasonable. ({})"
            .format(config.load_path))
    else:
        # Only load when the file exists; loading unconditionally crashed
        # right after the "skipped" warning above.
        nn.load_parameters(config.load_path)

    progress_iterator = trange(len(img_path_list) // config.train.batch_size,
                               desc="[Generating Images]",
                               disable=comm.rank > 0)

    save_path = config.test_output_dir
    os.makedirs(save_path, exist_ok=True)

    for i in progress_iterator:
        path = img_path_list[i]
        test_image_data = get_var(path, image_shape[2:])
        test_image.d = test_image_data
        generated_image.forward(clear_buffer=True)
        # Read the output once instead of on every use.
        out = generated_image.d
        generated_image_data = (out - out.min()) / (out.max() - out.min())
        test_image_data = test_image_data * 0.5 + 0.5

        gen_image_path = os.path.join(save_path,
                                      "res{}_{}.png".format(comm.rank, i))
        input_image_path = os.path.join(save_path,
                                        "input_{}_{}.png".format(comm.rank, i))

        imsave(gen_image_path, generated_image_data[0], channel_first=True)
        imsave(input_image_path, test_image_data[0], channel_first=True)
def load(filenames, prepare_data_iterator=True, batch_size=None, exclude_parameter=False, parameter_only=False):
    '''load
    Load network information from files.

    Args:
        filenames (list): List of filenames. A single filename string is
            also accepted and treated as a one-element list.
    Returns:
        Info: object whose attributes hold the network information
        (``networks``, ``datasets``, ``optimizers``, ``monitors``,
        ``executors`` and, when present in the files, ``global_config`` and
        ``training_config``).
    '''
    class Info:
        pass
    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()

    if isinstance(filenames, str):
        # Iterating a bare string would visit it character by character.
        filenames = [filenames]

    for filename in filenames:
        _, ext = os.path.splitext(filename)

        # TODO: Here is some known problems.
        #   - Even when protobuf file includes network structure,
        #     it will not loaded.
        #   - Even when prototxt file includes parameter,
        #     it will not loaded.

        if ext in ['.nntxt', '.prototxt']:
            if not parameter_only:
                with open(filename, 'rt') as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except Exception:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            '2 byte characters may be used for file name or folder name.')
                        raise
            if len(proto.parameter) > 0:
                if not exclude_parameter:
                    nn.load_parameters(filename)
        elif ext in ['.protobuf', '.h5']:
            if not exclude_parameter:
                nn.load_parameters(filename)
            else:
                logger.info('Skip loading parameter.')

        elif ext == '.nnp':
            # Create the scratch directory before `try`, so the `finally`
            # never refers to an unbound name if mkdtemp() itself fails.
            tmpdir = tempfile.mkdtemp()
            try:
                with zipfile.ZipFile(filename, 'r') as nnp:
                    for name in nnp.namelist():
                        _, ext = os.path.splitext(name)
                        if name == 'nnp_version.txt':
                            pass  # TODO currently do nothing with version.
                        elif ext in ['.nntxt', '.prototxt']:
                            nnp.extract(name, tmpdir)
                            if not parameter_only:
                                with open(os.path.join(tmpdir, name), 'rt') as f:
                                    text_format.Merge(f.read(), proto)
                            if len(proto.parameter) > 0:
                                if not exclude_parameter:
                                    nn.load_parameters(
                                        os.path.join(tmpdir, name))
                        elif ext in ['.protobuf', '.h5']:
                            nnp.extract(name, tmpdir)
                            if not exclude_parameter:
                                nn.load_parameters(os.path.join(tmpdir, name))
                            else:
                                logger.info('Skip loading parameter.')
            finally:
                shutil.rmtree(tmpdir)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            try:
                import nnabla_ext.cudnn
            except Exception:
                # Best effort: the probe below falls back to CPU if unusable.
                pass
        # Probe the context with a tiny ReLU; fall back to CPU if it cannot run.
        try:
            x = nn.Variable()
            y = nn.Variable()
            func = F.ReLU(default_context, inplace=True)
            func.setup([x], [y])
            func.forward([x], [y])
        except Exception:
            logger.warn('Fallback to CPU context.')
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()

    comm = current_communicator()
    if comm:
        # Device ids are per node: use the node-local rank, not the global
        # rank, which exceeds the GPU count on multi-node runs.
        default_context.device_id = str(comm.local_rank)

    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    info.datasets = _datasets(
        proto, prepare_data_iterator if prepare_data_iterator is not None else info.training_config.max_epoch > 0)

    info.networks = _networks(proto, default_context, batch_size)

    info.optimizers = _optimizers(
        proto, default_context, info.networks, info.datasets)

    info.monitors = _monitors(
        proto, default_context, info.networks, info.datasets)

    info.executors = _executors(proto, info.networks)

    return info