def train():
    """Train a CIFAR-10 ResNet-23 classifier with SSL regularization.

    Steps:

    * Parse command line arguments and set the default computation context.
    * Build a training graph whose loss is the mean softmax cross entropy
      plus the ``ssl_regularization`` term, and a separate validation graph.
    * Run the Adam training loop. Every ``args.val_interval`` iterations the
      validation error is measured and the parameters are saved whenever the
      error improves on the best value seen so far.
    * Run one more validation pass and save the final parameters.

    Raises:
        ValueError: If ``args.net`` is not a supported network name.
    """
    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = 'cpu' if args.context is None else args.context
    logger.info("Running in %s", extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Select the network. Fail early instead of hitting an unbound local
    # further down when the name is not recognized.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction
    else:
        raise ValueError("Unsupported network: %s" % args.net)

    maps = 64
    data_iterator = data_iterator_cifar10
    c = 3
    h = w = 32
    n_valid = 10000
    n_val_batches = n_valid // args.batch_size

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    label = nn.Variable([args.batch_size, 1])
    # Create model_prediction graph.
    pred = model_prediction(image, maps=maps, test=False)
    # pred.d is read after backward(clear_buffer=True) below, so its buffer
    # is marked persistent.
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))
    # SSL Regularization
    loss += ssl_regularization(nn.get_parameters(),
                               args.filter_decay, args.channel_decay)

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = model_prediction(vimage, maps=maps, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)

    def validation_error():
        """Return the mean categorical error over ``n_val_batches`` batches."""
        total = 0.0
        for _ in range(n_val_batches):
            vimage.d, vlabel.d = vdata.next()
            vpred.forward(clear_buffer=True)
            total += categorical_error(vpred.d, vlabel.d)
        return total / n_val_batches

    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = validation_error()
            monitor_verr.add(i, ve)
            # Keep a checkpoint of the best model seen so far.
            if ve < best_ve:
                nn.save_parameters(os.path.join(
                    args.model_save_path, 'params_%06d.h5' % i))
                best_ve = ve

        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    # Final validation, reported at the index of the last training iteration.
    # Computed from max_iter so this also works when the loop never ran.
    final_step = max(args.max_iter - 1, 0)
    monitor_verr.add(final_step, validation_error())

    parameter_file = os.path.join(
        args.model_save_path, 'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
def classification_svd():
    """Fine-tune the "slim" MNIST LeNet after network decomposition.

    Steps:

    * Parse command line arguments and set the default computation context.
    * Build the training graph in the ``"slim"`` parameter scope with a
      reduction rate of 0.5.
    * Call ``decompose_network_and_set_params`` with the model path, the
      ``"reference"`` and ``"slim"`` scope names and the reduction rate
      (presumably this initializes the slim parameters from the reference
      model; confirm against that helper).
    * Train only the parameters in the ``"slim"`` scope with Adam, validating
      every ``args.val_interval`` iterations and saving the parameters
      whenever the mean validation error improves.
    * Run a final validation pass and save the final parameters.
    """
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s", args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_lenet_prediction_slim

    # TRAIN
    reference = "reference"
    slim = "slim"
    rrate = 0.5  # reduction rate
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create "slim" prediction graph.
    model_load_path = args.model_load_path
    pred = mnist_cnn_prediction(image, scope=slim, rrate=rrate, test=False)
    # pred.d is read after backward(clear_buffer=True) below, so its buffer
    # is marked persistent.
    pred.persistent = True

    # Decompose and set parameters
    decompose_network_and_set_params(model_load_path, reference, slim, rrate)
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create the validation graph for the slim network.
    vpred = mnist_cnn_prediction(vimage, scope=slim, rrate=rrate, test=True)

    # Create Solver. Only parameters under the slim scope are registered.
    solver = S.Adam(args.learning_rate)
    with nn.parameter_scope(slim):
        solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)

    def validation_error():
        """Return the mean categorical error over ``args.val_iter`` batches."""
        total = 0.0
        for _ in range(args.val_iter):
            vimage.d, vlabel.d = vdata.next()
            vpred.forward(clear_buffer=True)
            total += categorical_error(vpred.d, vlabel.d)
        return total / args.val_iter

    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = validation_error()
            monitor_verr.add(i, ve)
            # Compare the *mean* error with the best mean error. The previous
            # code compared the un-normalized sum over val_iter batches with
            # the initial rate of 1.0, so checkpoints were only written once
            # the mean error dropped below 1 / val_iter.
            if ve < best_ve:
                nn.save_parameters(os.path.join(
                    args.model_save_path, 'params_%06d.h5' % i))
                best_ve = ve

        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    # Final validation, reported at the index of the last training iteration.
    # Computed from max_iter so this also works when the loop never ran.
    final_step = max(args.max_iter - 1, 0)
    monitor_verr.add(final_step, validation_error())

    parameter_file = os.path.join(
        args.model_save_path, 'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify contexts for computation.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Initialize DataIterator.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Save parameters (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error

    Raises:
        ValueError: If ``args.net`` is not one of the supported networks.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size

    # Context
    extension_module = args.context
    ctx = get_extension_context(
        extension_module, device_id=args.device_id,
        type_config=args.type_config)
    nn.set_default_context(ctx)

    # Network and dataset selection. Fail early instead of hitting an
    # unbound local further down when the name is not recognized.
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(
            resnet23_prediction, ncls=10, nmaps=64, act=F.relu)
        data_iterator = data_iterator_cifar10
    elif args.net == "cifar100_resnet23":
        prediction = functools.partial(
            resnet23_prediction, ncls=100, nmaps=384, act=F.elu)
        data_iterator = data_iterator_cifar100
    else:
        raise ValueError("Unsupported network: %s" % args.net)

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    # NOTE(review): Adam is created with its default hyper-parameters; no
    # learning rate from args is applied here. Confirm this is intended.
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Data Iterator
    tdata = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)

    # Validate once per pass over the training set. Clamp to 1 so that a
    # batch size larger than the dataset does not cause a modulo by zero.
    val_interval = max(1, n_train_samples // args.batch_size)

    # Training-loop
    for i in range(args.max_iter):
        # Validation
        if i % val_interval == 0:
            ve = 0.
            for _ in range(args.val_iter):
                image, label = vdata.next()
                input_image_valid["image"].d = image
                pred_valid.forward()
                ve += categorical_error(pred_valid.d, label)
            ve /= args.val_iter
            monitor_verr.add(i, ve)
        if int(i % args.model_save_interval) == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Solvers update
        solver.update()

        e = categorical_error(
            pred_train.d, input_image_train["label"].d)
        monitor_loss.add(i, loss_train.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    nn.save_parameters(os.path.join(
        args.model_save_path, 'params_%06d.h5' % (args.max_iter)))
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    Steps:

    * Parse command line arguments.
    * Instantiate a communicator and specify contexts for computation.
    * Construct a computation graph for training and one for validation.
    * Register the context and parameters with the communicator.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Initialize DataIterator.
    * Training loop
      * Compute error rate for validation data (periodically, device 0 only)
      * Save parameters (periodically, device 0 only)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * Allreduce across devices.
      * Solver updates parameters by using gradients computed by backprop.
      * Linear learning-rate warmup.
      * Compute training error (device 0 only)

    Raises:
        ValueError: If ``args.net`` is not one of the supported networks.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size

    # This generator is handed to the network constructor; it is seeded
    # identically in every process.
    param_rng = np.random.RandomState(313)
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(
            resnet23_prediction, rng=param_rng, ncls=10, nmaps=64, act=F.relu)
        data_iterator = data_iterator_cifar10
    elif args.net == "cifar100_resnet23":
        prediction = functools.partial(
            resnet23_prediction, rng=param_rng, ncls=100, nmaps=384,
            act=F.elu)
        data_iterator = data_iterator_cifar100
    else:
        # Fail early instead of hitting an unbound local further down.
        raise ValueError("Unsupported network: %s" % args.net)

    # Communicator and Context
    extension_module = "cuda.cudnn"
    ctx = extension_context(extension_module)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    # Re-create the context now that this process's device id is known.
    ctx = extension_context(extension_module, device_id=device_id)
    nn.set_default_context(ctx)

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # add parameters to communicator
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    iters_per_epoch = int(
        1. * n_train_samples / args.batch_size / n_devices)
    warmup_iter = iters_per_epoch * args.warmup_epoch
    # The learning rate is ramped linearly from base_lr so that it reaches
    # base_lr * n_devices at iteration warmup_iter. With no warmup
    # iterations the slope is undefined, so warmup is disabled and the rate
    # stays at base_lr instead of raising ZeroDivisionError.
    if warmup_iter > 0:
        warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    else:
        warmup_slope = 0.0
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Data Iterator
    # The training iterator gets a generator seeded with the device id, so
    # each local device uses a different seed.
    data_rng = np.random.RandomState(device_id)
    tdata = data_iterator(args.batch_size, True, data_rng)
    vdata = data_iterator(args.batch_size, False)

    # Intervals are expressed in per-process iterations. Clamp to 1 so that
    # small values do not cause a modulo by zero.
    val_interval = max(1, iters_per_epoch)
    save_interval = max(1, int(args.model_save_interval / n_devices))
    n_iter = int(args.max_iter / n_devices)

    # Training-loop
    for i in range(n_iter):
        # Validation
        if device_id == 0:
            if i % val_interval == 0:
                ve = 0.
                for _ in range(args.val_iter):
                    image, label = vdata.next()
                    input_image_valid["image"].d = image
                    pred_valid.forward()
                    ve += categorical_error(pred_valid.d, label)
                ve /= args.val_iter
                monitor_verr.add(i * n_devices, ve)
            if i % save_interval == 0:
                nn.save_parameters(os.path.join(
                    args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Allreduce
        comm.allreduce(division=False, inplace=False)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:
            e = categorical_error(
                pred_train.d, input_image_train["label"].d)
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, e)
            monitor_time.add(i * n_devices)

    if device_id == 0:
        nn.save_parameters(os.path.join(
            args.model_save_path, 'params_%06d.h5' % n_iter))
def train(args):
    """Train a CIFAR-10 ResNet-23 variant selected by ``args.net``.

    Steps:

    * Set the default computation context from ``args``.
    * Select the prediction function; the quantized variants are bound to
      their extra hyper-parameters taken from ``args``.
    * Build the training graph (mean softmax cross entropy) and a separate
      validation graph.
    * Run the Adam training loop. Every ``args.val_interval`` iterations the
      validation error is measured and the parameters are saved whenever the
      error improves on the best value seen so far.
    * Run one more validation pass and save the final parameters.

    Args:
        args: Parsed command line arguments. The attributes read here
            include ``context``, ``device_id``, ``type_config``, ``net``,
            ``batch_size``, ``learning_rate``, ``weight_decay``,
            ``monitor_path``, ``model_save_path``, ``max_iter`` and
            ``val_interval``, plus the network-specific options.

    Raises:
        ValueError: If ``args.net`` is not a supported network name.
    """
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s", args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction
    elif args.net == 'cifar10_binary_connect_resnet23_prediction':
        model_prediction = cifar10_binary_connect_resnet23_prediction
    elif args.net == 'cifar10_binary_net_resnet23_prediction':
        model_prediction = cifar10_binary_net_resnet23_prediction
    elif args.net == 'cifar10_binary_weight_resnet23_prediction':
        model_prediction = cifar10_binary_weight_resnet23_prediction
    elif args.net == 'cifar10_fp_connect_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_fp_connect_resnet23_prediction,
            n=args.bit_width, delta=args.delta)
    elif args.net == 'cifar10_fp_net_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_fp_net_resnet23_prediction,
            n=args.bit_width, delta=args.delta)
    elif args.net == 'cifar10_pow2_connect_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_pow2_connect_resnet23_prediction,
            n=args.bit_width, m=args.upper_bound)
    elif args.net == 'cifar10_pow2_net_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_pow2_net_resnet23_prediction,
            n=args.bit_width, m=args.upper_bound)
    elif args.net == 'cifar10_inq_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_inq_resnet23_prediction, num_bits=args.bit_width)
    elif args.net == 'cifar10_min_max_resnet23_prediction':
        model_prediction = functools.partial(
            cifar10_min_max_resnet23_prediction,
            ql_min=args.ql_min, ql_max=args.ql_max,
            p_min_max=args.p_min_max, a_min_max=args.a_min_max,
            a_ema=args.a_ema, ste_fine_grained=args.ste_fine_grained)
    else:
        # Fail early instead of hitting an unbound local further down.
        raise ValueError("Unsupported network: %s" % args.net)

    # TRAIN
    maps = 64
    data_iterator = data_iterator_cifar10
    c = 3
    h = w = 32
    n_valid = 10000
    n_val_batches = n_valid // args.batch_size

    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    label = nn.Variable([args.batch_size, 1])
    # Create model_prediction graph.
    pred = model_prediction(image, maps=maps, test=False)
    # pred.d is read after backward(clear_buffer=True) below, so its buffer
    # is marked persistent.
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = model_prediction(vimage, maps=maps, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)

    def validation_error():
        """Return the mean categorical error over ``n_val_batches`` batches."""
        total = 0.0
        for _ in range(n_val_batches):
            vimage.d, vlabel.d = vdata.next()
            vpred.forward(clear_buffer=True)
            total += categorical_error(vpred.d, vlabel.d)
        return total / n_val_batches

    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = validation_error()
            monitor_verr.add(i, ve)
            # Keep a checkpoint of the best model seen so far.
            if ve < best_ve:
                nn.save_parameters(os.path.join(
                    args.model_save_path, 'params_%06d.h5' % i))
                best_ve = ve

        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    # Final validation, reported at the index of the last training iteration.
    # Computed from max_iter so this also works when the loop never ran.
    final_step = max(args.max_iter - 1, 0)
    monitor_verr.add(final_step, validation_error())

    parameter_file = os.path.join(
        args.model_save_path, 'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
def distil():
    """Distil a CIFAR-10 teacher network into a smaller student network.

    Steps:

    * Parse command line arguments and set the default computation context.
    * Load parameters from ``args.model_load_path``.
    * Build the teacher graph (``net="teacher"``, ``args.maps`` maps) and the
      student graph (``net="student"``) on the same input; the student uses
      ``int(maps * (1 - args.reduction_rate))`` maps.
    * Train only the parameters under the ``"student"`` scope with Adam on
      ``args.weight_ce * cross_entropy + args.weight_ce_soft * ce_soft``,
      validating every ``args.val_interval`` iterations and saving the
      parameters whenever the validation error improves.
    * Run a final validation pass and save the final parameters.

    Raises:
        ValueError: If ``args.net`` is not a supported network name.
    """
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s", args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction
        data_iterator = data_iterator_cifar10
        c = 3
        h = w = 32
        n_valid = 10000
    else:
        # Fail early instead of hitting an unbound local further down.
        raise ValueError("Unsupported network: %s" % args.net)
    n_val_batches = n_valid // args.batch_size

    # TRAIN
    teacher = "teacher"
    student = "student"
    maps = args.maps
    rrate = args.reduction_rate
    student_maps = int(maps * (1. - rrate))
    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    image.persistent = True  # not clear the intermediate buffer re-used
    label = nn.Variable([args.batch_size, 1])
    label.persistent = True  # not clear the intermediate buffer re-used
    # Create `teacher` and "student" prediction graph.
    model_load_path = args.model_load_path
    nn.load_parameters(model_load_path)
    pred_label = model_prediction(
        image, net=teacher, maps=maps, test=not args.use_batch)
    pred_label.need_grad = False  # no need backward through teacher graph
    pred = model_prediction(image, net=student, maps=student_maps, test=False)
    pred.persistent = True  # not clear the intermediate buffer used
    loss_ce = F.mean(F.softmax_cross_entropy(pred, label))
    loss_ce_soft = ce_soft(pred, pred_label)
    loss = args.weight_ce * loss_ce + args.weight_ce_soft * loss_ce_soft

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create the student validation graph.
    vpred = model_prediction(vimage, net=student, maps=student_maps, test=True)

    # Create Solver. Only parameters under the student scope are registered.
    solver = S.Adam(args.learning_rate)
    with nn.parameter_scope(student):
        solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator.
    # NOTE(review): the iterators are indexed with [1] before .next(), which
    # implies the data_iterator used here returns a sequence whose second
    # element is the iterator. Confirm against the data_iterator_cifar10
    # this file imports.
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)

    def validation_error():
        """Return the mean categorical error over ``n_val_batches`` batches."""
        total = 0.0
        for _ in range(n_val_batches):
            vimage.d, vlabel.d = vdata[1].next()
            vpred.forward(clear_buffer=True)
            total += categorical_error(vpred.d, vlabel.d)
        return total / n_val_batches

    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = validation_error()
            monitor_verr.add(i, ve)
            # Keep a checkpoint of the best model seen so far.
            if ve < best_ve:
                nn.save_parameters(os.path.join(
                    args.model_save_path, 'params_%06d.h5' % i))
                best_ve = ve

        # Training forward
        image.d, label.d = data[1].next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    # Final validation, reported at the index of the last training iteration.
    # Computed from max_iter so this also works when the loop never ran.
    final_step = max(args.max_iter - 1, 0)
    monitor_verr.add(final_step, validation_error())

    parameter_file = os.path.join(
        args.model_save_path, 'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
def distil():
    """Distil an MNIST teacher network into a smaller student network.

    Steps:

    * Parse command line arguments and set the default computation context.
    * Load parameters from ``args.model_load_path``.
    * Build the teacher graph (``net="teacher"``, 64 maps) and the student
      graph (``net="student"``, 32 maps) on the same input.
    * Train only the parameters under the ``"student"`` scope with Adam on
      ``args.weight_ce * cross_entropy + args.weight_kl * kl_divergence``,
      validating every ``args.val_interval`` iterations and saving the
      parameters whenever the mean validation error improves.
    * Run a final validation pass and save the final parameters.
    """
    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = 'cpu' if args.context is None else args.context
    logger.info("Running in %s", extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_resnet_prediction

    # TRAIN
    teacher = "teacher"
    student = "student"
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    image.persistent = True  # not clear the intermediate buffer re-used
    label = nn.Variable([args.batch_size, 1])
    label.persistent = True  # not clear the intermediate buffer re-used
    # Create `teacher` and "student" prediction graph.
    model_load_path = args.model_load_path
    nn.load_parameters(model_load_path)
    pred_label = mnist_cnn_prediction(image, net=teacher, maps=64, test=False)
    pred_label.need_grad = False  # no need backward through teacher graph
    pred = mnist_cnn_prediction(image, net=student, maps=32, test=False)
    pred.persistent = True  # not clear the intermediate buffer used
    loss_ce = F.mean(F.softmax_cross_entropy(pred, label))
    loss_kl = kl_divergence(pred, pred_label)
    loss = args.weight_ce * loss_ce + args.weight_kl * loss_kl

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create the student validation graph.
    vpred = mnist_cnn_prediction(vimage, net=student, maps=32, test=True)

    # Create Solver. Only parameters under the student scope are registered.
    solver = S.Adam(args.learning_rate)
    with nn.parameter_scope(student):
        solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)

    def validation_error():
        """Return the mean categorical error over ``args.val_iter`` batches."""
        total = 0.0
        for _ in range(args.val_iter):
            vimage.d, vlabel.d = vdata.next()
            vpred.forward(clear_buffer=True)
            total += categorical_error(vpred.d, vlabel.d)
        return total / args.val_iter

    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = validation_error()
            monitor_verr.add(i, ve)
            # Compare the *mean* error with the best mean error. The previous
            # code compared the un-normalized sum over val_iter batches with
            # the initial rate of 1.0, so checkpoints were only written once
            # the mean error dropped below 1 / val_iter.
            if ve < best_ve:
                nn.save_parameters(os.path.join(
                    args.model_save_path, 'params_%06d.h5' % i))
                best_ve = ve

        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    # Final validation, reported at the index of the last training iteration.
    # Computed from max_iter so this also works when the loop never ran.
    final_step = max(args.max_iter - 1, 0)
    monitor_verr.add(final_step, validation_error())

    parameter_file = os.path.join(
        args.model_save_path, 'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)