def cli_main(parser, args):
    """Run the sub-command selected on the command line.

    The module-level ``return_value`` is reset to ``False`` on entry and then
    set to whatever ``args.func(args)`` returns.

    Args:
        parser: Parser whose help text is written to ``sys.stderr`` when no
            sub-command handler is present in ``args``.
        args: Parsed arguments. ``args.func`` is the handler to call and
            ``args.mpi`` selects between the MPI path and the plain path.

    Any exception raised by the handler has its traceback printed. The plain
    path then sets ``return_value`` to ``False`` and exits with status -1; the
    MPI path logs "ABORTED" and sends signal 9 to the current process.
    """
    global return_value
    return_value = False

    # Nothing to dispatch: show usage and stop.
    if 'func' not in args:
        parser.print_help(sys.stderr)
        sys.exit(-1)

    # Plain (single process) execution.
    if not args.mpi:
        try:
            return_value = args.func(args)
        except BaseException:
            import traceback
            print(traceback.format_exc())
            return_value = False
            sys.exit(-1)
        return

    # MPI execution: the communicator is created before the handler runs.
    from nnabla.utils.communicator_util import create_communicator
    comm = create_communicator()
    try:
        return_value = args.func(args)
    except BaseException:
        import traceback
        print(traceback.format_exc())
        logger.log(99, "ABORTED")
        # The process kills itself here; comm.abort() is not called.
        os.kill(os.getpid(), 9)
def cli_main():
    """Build the NNabla command line parser and run the chosen sub-command.

    Registers every sub-command on a fresh ``argparse`` parser, prints the
    version banner, parses the process arguments and calls the selected
    handler. The handler's result is stored in the module-level
    ``return_value`` (reset to ``False`` on entry).

    When no sub-command is selected the help text is written to
    ``sys.stderr`` and the function returns. If the handler raises, the
    traceback is printed; the MPI path then calls ``comm.abort()`` and the
    plain path sets ``return_value`` to ``False``.
    """
    global return_value
    return_value = False

    import nnabla
    parser = argparse.ArgumentParser(
        description='Command line interface for NNabla({})'.format(
            _nnabla_version()))
    parser.add_argument('-m', '--mpi',
                        help='exec with mpi.',
                        action='store_true')
    subparsers = parser.add_subparsers()

    # Sub-command registration functions, imported in registration order.
    from nnabla.utils.cli.train import add_train_command
    from nnabla.utils.cli.forward import (add_infer_command,
                                          add_forward_command)
    from nnabla.utils.cli.encode_decode_param import (
        add_decode_param_command, add_encode_param_command)
    from nnabla.utils.cli.profile import add_profile_command
    from nnabla.utils.cli.conv_dataset import add_conv_dataset_command
    from nnabla.utils.cli.compare_with_cpu import (
        add_compare_with_cpu_command)
    from nnabla.utils.cli.create_image_classification_dataset import (
        add_create_image_classification_dataset_command)
    from nnabla.utils.cli.uploader import (add_upload_command,
                                           add_create_tar_command)
    from nnabla.utils.cli.convert import add_convert_command

    registrars = (
        add_train_command,
        add_infer_command,
        add_forward_command,
        add_encode_param_command,
        add_decode_param_command,
        add_profile_command,
        add_conv_dataset_command,
        add_compare_with_cpu_command,
        add_create_image_classification_dataset_command,
        add_upload_command,
        add_create_tar_command,
        add_convert_command,
    )
    for register in registrars:
        register(subparsers)

    # Version
    version_parser = subparsers.add_parser(
        'version', help='Print version and build number.')
    version_parser.set_defaults(func=version_command)

    banner = 'NNabla command line interface (Version {}, Build {})'.format(
        nnabla.__version__, nnabla.__build_number__)
    print(banner)

    args = parser.parse_args()

    # No sub-command given: show usage and leave return_value as False.
    if 'func' not in args:
        parser.print_help(sys.stderr)
        return

    # Plain (single process) execution.
    if not args.mpi:
        try:
            return_value = args.func(args)
        except BaseException:
            import traceback
            print(traceback.format_exc())
            return_value = False
        return

    # MPI execution: abort through the communicator if the handler fails.
    from nnabla.utils.communicator_util import create_communicator
    comm = create_communicator()
    try:
        return_value = args.func(args)
    except BaseException:
        import traceback
        print(traceback.format_exc())
        comm.abort()
def cli_main():
    """Build the NNabla command line parser and run the chosen sub-command.

    Registers every sub-command on a fresh ``argparse`` parser, prints the
    version banner, parses the process arguments and calls the selected
    handler. The handler's result is stored in the module-level
    ``return_value`` (reset to ``False`` on entry).

    When no sub-command is selected the help text is written to
    ``sys.stderr`` and the process exits with status -1. If the handler
    raises, the traceback is printed; the plain path then sets
    ``return_value`` to ``False`` and exits with status -1, while the MPI
    path logs "ABORTED" and sends signal 9 to the current process.
    """
    global return_value
    return_value = False

    parser = argparse.ArgumentParser(
        description='Command line interface for NNabla({})'.format(
            _nnabla_version()))
    parser.add_argument('-m', '--mpi',
                        help='exec with mpi.',
                        action='store_true')
    subparsers = parser.add_subparsers()

    # Sub-command registration functions, imported in registration order.
    from nnabla.utils.cli.train import add_train_command
    from nnabla.utils.cli.forward import (add_infer_command,
                                          add_forward_command)
    from nnabla.utils.cli.encode_decode_param import (
        add_decode_param_command, add_encode_param_command)
    from nnabla.utils.cli.profile import add_profile_command
    from nnabla.utils.cli.conv_dataset import add_conv_dataset_command
    from nnabla.utils.cli.compare_with_cpu import (
        add_compare_with_cpu_command)
    from nnabla.utils.cli.create_image_classification_dataset import (
        add_create_image_classification_dataset_command)
    from nnabla.utils.cli.create_object_detection_dataset import (
        add_create_object_detection_dataset_command)
    from nnabla.utils.cli.uploader import (add_upload_command,
                                           add_create_tar_command)
    from nnabla.utils.cli.convert import add_convert_command
    from nnabla.utils.cli.func_info import add_function_info_command
    from nnabla.utils.cli.optimize_pb_model import (
        add_optimize_pb_model_command)
    from nnabla.utils.cli.plot import (add_plot_series_command,
                                       add_plot_timer_command)
    from nnabla.utils.cli.draw_graph import add_draw_graph_command

    registrars = (
        add_train_command,
        add_infer_command,
        add_forward_command,
        add_encode_param_command,
        add_decode_param_command,
        add_profile_command,
        add_conv_dataset_command,
        add_compare_with_cpu_command,
        add_create_image_classification_dataset_command,
        add_create_object_detection_dataset_command,
        add_upload_command,
        add_create_tar_command,
        add_convert_command,
        add_function_info_command,
        add_optimize_pb_model_command,
        add_plot_series_command,
        add_plot_timer_command,
        add_draw_graph_command,
    )
    for register in registrars:
        register(subparsers)

    # Version
    version_parser = subparsers.add_parser(
        'version', help='Print version and build number.')
    version_parser.set_defaults(func=version_command)

    banner = 'NNabla command line interface ({})'.format(_nnabla_version())
    print(banner)

    args = parser.parse_args()

    # No sub-command given: show usage and stop.
    if 'func' not in args:
        parser.print_help(sys.stderr)
        sys.exit(-1)

    # Plain (single process) execution.
    if not args.mpi:
        try:
            return_value = args.func(args)
        except BaseException:
            import traceback
            print(traceback.format_exc())
            return_value = False
            sys.exit(-1)
        return

    # MPI execution: the communicator is created before the handler runs.
    from nnabla.utils.communicator_util import create_communicator
    comm = create_communicator()
    try:
        return_value = args.func(args)
    except BaseException:
        import traceback
        print(traceback.format_exc())
        logger.log(99, "ABORTED")
        # The process kills itself here; comm.abort() is not called.
        os.kill(os.getpid(), 9)
def train(args):
    """
    Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    Steps:
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Load checkpoint to resume previous training.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * AllReduce for gradients
      * Solver updates parameters by using gradients computed by backprop and all reduce.
      * Compute training error

    Args:
        args: Parsed options. The attributes read here are ``context``,
            ``device_id``, ``batch_size``, ``sync_bn``, ``net``,
            ``learning_rate``, ``warmup_epoch``, ``use_latest_checkpoint``,
            ``model_save_path``, ``monitor_path``, ``epochs``,
            ``model_save_interval`` and ``with_all_reduce_callback``.

    Raises:
        ValueError: If ``args.net`` is neither ``"cifar10_resnet23"`` nor
            ``"cifar100_resnet23"``.
    """
    # Create Communicator and Context
    comm = create_communicator(ignore_error=True)
    if comm:
        n_devices = comm.size
        mpi_rank = comm.rank
        device_id = comm.local_rank
    else:
        # No communicator available: run as a single process.
        n_devices = 1
        mpi_rank = 0
        device_id = args.device_id

    if args.context == 'cpu':
        import nnabla_ext.cpu
        context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cudnn
        context = nnabla_ext.cudnn.context(device_id=device_id)
    nn.set_default_context(context)

    n_train_samples = 50000
    n_valid_samples = 10000
    bs_valid = args.batch_size
    iter_per_epoch = int(n_train_samples / args.batch_size / n_devices)

    # Model
    rng = np.random.RandomState(313)
    comm_syncbn = comm if args.sync_bn else None
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar10
    elif args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar100
    else:
        # Fail with a clear message instead of an unbound `prediction` later.
        raise ValueError(f'unsupported network: {args.net!r}')

    # Create training graphs
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test=False)
    pred_train.persistent = True
    # The loss is divided by the number of devices before the all-reduce.
    loss_train = (loss_function(pred_train, label_train) /
                  n_devices).apply(persistent=True)
    error_train = F.mean(F.top_n_error(pred_train, label_train,
                                       axis=1)).apply(persistent=True)
    loss_error_train = F.sink(loss_train, error_train)

    # Create validation graphs
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid = prediction(image_valid, test=True)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = iter_per_epoch * args.warmup_epoch
    # NOTE(review): warmup_iter is 0 when args.warmup_epoch is 0, which makes
    # this division raise ZeroDivisionError - confirm whether zero warmup
    # needs to be supported.
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # load checkpoint if file exist.
    start_point = 0
    if args.use_latest_checkpoint:
        files = glob.glob(f'{args.model_save_path}/checkpoint_*.json')
        if files:
            # Pick the checkpoint with the largest iteration number.
            index = max([
                int(n) for n in
                [re.sub(r'.*checkpoint_(\d+).json', '\\1', f) for f in files]
            ])
            # load weights and solver state info from specified checkpoint file.
            start_point = load_checkpoint(
                f'{args.model_save_path}/checkpoint_{index}.json', solver)
        print(f'checkpoint is loaded. start iteration from {start_point}')

    # Create monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data Iterator

    # If the data does not exist, it will try to download it from the server
    # and prepare it. When executing multiple processes on the same host, it is
    # necessary to execute initial data preparation by the representative
    # process (rank is 0) on the host.

    # Download dataset by rank-0 process
    if single_or_rankzero():
        rng = np.random.RandomState(mpi_rank)
        _, tdata = data_iterator(args.batch_size, True, rng)
        vsource, vdata = data_iterator(bs_valid, False)

    # Wait for data to be prepared without watchdog
    if comm:
        comm.barrier()

    # Prepare dataset for remaining process
    if not single_or_rankzero():
        rng = np.random.RandomState(mpi_rank)
        _, tdata = data_iterator(args.batch_size, True, rng)
        vsource, vdata = data_iterator(bs_valid, False)

    # Training-loop
    ve = nn.Variable()
    for i in range(start_point // n_devices, args.epochs * iter_per_epoch):
        # Validation
        if i % iter_per_epoch == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            # Each rank walks its own slice of the shuffled validation set.
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image) != bs_valid:
                    # note that smaller batch is ignored
                    continue
                image_valid.d = image
                label_valid.d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            if comm:
                comm.all_reduce(ve.data, division=True, inplace=True)

            # Monitoring error and elapsed time
            if single_or_rankzero():
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Save model
        if single_or_rankzero():
            if i % (args.model_save_interval // n_devices) == 0:
                iteration = i * n_devices
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 'params_%06d.h5' % iteration))
                if args.use_latest_checkpoint:
                    save_checkpoint(args.model_save_path, iteration, solver)

        # Forward/Zerograd
        image, label = tdata.next()
        image_train.d = image
        label_train.d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward/AllReduce
        backward_and_all_reduce(
            loss_error_train,
            comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        # Monitoring loss, error and elapsed time
        if single_or_rankzero():
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

    # Save nnp last epoch
    if single_or_rankzero():
        runtime_contents = {
            'networks': [{
                'name': 'Validation',
                'batch_size': args.batch_size,
                'outputs': {
                    'y': pred_valid
                },
                'names': {
                    'x': image_valid
                }
            }],
            'executors': [{
                'name': 'Runtime',
                'network': 'Validation',
                'data': ['x'],
                'output': ['y']
            }]
        }
        iteration = args.epochs * iter_per_epoch
        nn.save_parameters(
            os.path.join(args.model_save_path, 'params_%06d.h5' % iteration))
        nnabla.utils.save.save(
            os.path.join(args.model_save_path, f'{args.net}_result.nnp'),
            runtime_contents)
    if comm:
        comm.barrier()