def evaluate(args): # Context ctx = get_extension_context(args.context, device_id=args.device_id, type_config=args.type_config) nn.set_default_context(ctx) # Args latent = args.latent maps = args.maps batch_size = args.batch_size image_size = args.image_size n_classes = args.n_classes not_sn = args.not_sn # Model (Inception model) from nnp file nnp = NnpLoader(args.nnp_inception_model_load_path) x, y = get_input_and_output(nnp, args.batch_size, name=args.variable_name) if args.evaluation_metric == "IS": is_model = None compute_metric = compute_inception_score if args.evaluation_metric == "FID": di = data_iterator_imagenet(args.valid_dir, args.dirname_to_label_path, batch_size=args.batch_size, ih=args.image_size, iw=args.image_size, shuffle=True, train=False, noise=False) compute_metric = functools.partial(compute_frechet_inception_distance, di=di) # Monitor monitor = Monitor(args.monitor_path) monitor_metric = MonitorSeries("{}".format(args.evaluation_metric), monitor, interval=1) # Compute the evaluation metric for all models def cmp_func(path): return int(path.split("/")[-1].strip("params_").rstrip(".h5")) model_load_path = sorted(glob.glob("{}/*.h5".format(args.model_load_path)), key=cmp_func) \ if os.path.isdir(args.model_load_path) else \ [args.model_load_path] for path in model_load_path: # Model (SAGAN) nn.load_parameters(path) z = nn.Variable([batch_size, latent]) y_fake = nn.Variable([batch_size]) x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes, test=True, sn=not_sn)\ .apply(persistent=True) # Compute the evaluation metric score = compute_metric(z, y_fake, x_fake, x, y, args) itr = cmp_func(path) monitor_metric.add(itr, score)
def train():
    """
    Main script.

    Single-device ImageNet / Tiny ImageNet training with gradient
    accumulation, periodic validation, checkpointing (resumable via
    ``--checkpoint``) and NNP export before and after training.
    """
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = get_extension_context(
        extension_module, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    if args.tiny_mode:
        # We use Tiny ImageNet from Stanford CS231N class.
        # (Tiny ImageNet, https://tiny-imagenet.herokuapp.com/)
        # Tiny ImageNet consists of 200 categories, each category has 500 images
        # in training set. The image size is 64x64. To adapt ResNet into 64x64
        # image inputs, the input image size of ResNet is set as 56x56, and
        # the stride in the first conv and the first max pooling are removed.
        # Please check README.
        data = data_iterator_tiny_imagenet(args.batch_size, 'train')
        vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')
        num_classes = 200
    else:
        # We use ImageNet.
        # (ImageNet, https://imagenet.herokuapp.com/)
        # ImageNet consists of 1000 categories, each category has 1280 images
        # in training set. The image size is various. To adapt ResNet into
        # 320x320 image inputs, the input image size of ResNet is set as
        # 224x224. We need to get tar file and create cache file(320x320 images).
        # Please check README.
        data = data_iterator_imagenet(args.batch_size, args.train_cachefile_dir)
        vdata = data_iterator_imagenet(args.batch_size, args.val_cachefile_dir)
        num_classes = 1000
    # Training graph; pred stays persistent so the unlinked top-n error
    # graph below can read its buffer after backward.
    t_model = get_model(args, num_classes, test=False, tiny=args.tiny_mode)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    # TODO: need_grad should be passed to get_unlinked_variable after v1.0.3 fix.
    t_pred2 = t_model.pred.get_unlinked_variable()
    t_pred2.need_grad = False
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))
    # Validation graph (test mode), same pred-reuse trick.
    v_model = get_model(args, num_classes, test=True, tiny=args.tiny_mode)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    # TODO: need_grad should be passed to get_unlinked_variable after v1.0.3 fix.
    v_pred2 = v_model.pred.get_unlinked_variable()
    v_pred2.need_grad = False
    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Save_nnp_Epoch0: export the untrained network for deployment tooling.
    contents = save_nnp({'x': v_model.image}, {'y': v_model.pred},
                        args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Imagenet_result_epoch0.nnp'), contents)

    # Create Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    start_point = 0
    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=10)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time", monitor, interval=10)

    # Training loop.
    for i in range(start_point, args.max_iter):
        # Save parameters
        if i % args.model_save_interval == 0:
            # save checkpoint file (weights + solver state, resumable)
            save_checkpoint(args.model_save_path, i, solver)

        # Validation
        if i % args.val_interval == 0 and i != 0:
            # Clear all intermediate memory to save memory.
            # t_model.loss.clear_recursive()
            l = 0.0
            e = 0.0
            for j in range(args.val_iter):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                # Cast inputs on the device context — presumably images are
                # uint8 and labels int32 from the iterator; TODO confirm.
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                l += v_model.loss.d
                e += v_e.d
            monitor_vloss.add(i, l / args.val_iter)
            monitor_verr.add(i, e / args.val_iter)
            monitor_vtime.add(i)
            # Clear all intermediate memory to save memory.
            # v_model.loss.clear_recursive()

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        def accumulate_error(l, e, t_model, t_e):
            # Sum up the loss and top-n error of the last forward pass.
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop: grads sum across accum_grad minibatches
        # before a single solver.update().
        for j in range(args.accum_grad):
            images, labels = data.next()
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            l, e = accumulate_error(l, e, t_model, t_e)

        solver.weight_decay(args.weight_decay)
        solver.update()

        monitor_loss.add(i, l / args.accum_grad)
        monitor_err.add(i, e / args.accum_grad)
        monitor_time.add(i)

        # Learning rate decay at scheduled iter
        if i in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)
    nn.save_parameters(
        os.path.join(args.model_save_path, 'param_%06d.h5' % args.max_iter))

    # Save_nnp: export the trained network.
    contents = save_nnp({'x': v_model.image}, {'y': v_model.pred},
                        args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Imagenet_result.nnp'), contents)
def match(args): # Context extension_module = "cudnn" ctx = get_extension_context(extension_module, device_id=args.device_id, type_config=args.type_config) nn.set_default_context(ctx) # Args latent = args.latent maps = args.maps batch_size = 1 image_size = args.image_size n_classes = args.n_classes not_sn = args.not_sn threshold = args.truncation_threshold # Model (SAGAN) nn.load_parameters(args.model_load_path) z = nn.Variable([batch_size, latent]) y_fake = nn.Variable([batch_size]) x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes, test=True, sn=not_sn)\ .apply(persistent=True) # Model (Inception model) from nnp file nnp = NnpLoader(args.nnp_inception_model_load_path) x, h = get_input_and_output(nnp, batch_size, args.variable_name) # DataIterator for a given class_id di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path, batch_size=batch_size, n_classes=args.n_classes, noise=False, class_id=args.class_id) # Monitor monitor = Monitor(args.monitor_path) name = "Matched Image {}".format(args.class_id) monitor_image = MonitorImage(name, monitor, interval=1, num_images=batch_size, normalize_method=lambda x: (x + 1.) / 2. * 255.) name = "Matched Image Tile {}".format(args.class_id) monitor_image_tile = MonitorImageTile(name, monitor, interval=1, num_images=batch_size + args.top_n, normalize_method=lambda x: (x + 1.) / 2. * 255.) 
# Generate and p(h|x).forward # generate z_data = resample(batch_size, latent, threshold) y_data = generate_one_class(args.class_id, batch_size) z.d = z_data y_fake.d = y_data x_fake.forward(clear_buffer=True) # p(h|x).forward x_fake_d = x_fake.d.copy() x_fake_d = preprocess( x_fake_d, (args.image_size, args.image_size), args.nnp_preprocess) x.d = x_fake_d h.forward(clear_buffer=True) h_fake_d = h.d.copy() # Feature matching norm2_list = [] x_data_list = [] x_data_list.append(x_fake.d) for i in range(di.size): # forward for real data x_d, _ = di.next() x_data_list.append(x_d) x_d = preprocess( x_d, (args.image_size, args.image_size), args.nnp_preprocess) x.d = x_d h.forward(clear_buffer=True) h_real_d = h.d.copy() # norm computation axis = tuple(np.arange(1, len(h.shape)).tolist()) norm2 = np.sum((h_real_d - h_fake_d) ** 2.0, axis=axis) norm2_list.append(norm2) # Save top-n images argmins = np.argsort(norm2_list) for i in range(args.top_n): monitor_image.add(i, x_data_list[i]) matched_images = np.concatenate(x_data_list) monitor_image_tile.add(0, matched_images)
def train(args):
    """Distributed SAGAN training on ImageNet.

    One process per device; gradients are summed with all-reduce inside each
    accumulation window, and weights are re-synchronized by an averaging
    all-reduce every ``args.sync_weight_every_itr`` iterations. Rank 0 owns
    all monitoring and checkpointing.
    """
    # Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    # NOTE: "Paralell" is the actual (misspelled) nnabla API name.
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = comm.local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn

    # Model
    # workaround to start with the same weights in the distributed system.
    np.random.seed(412)
    # generator loss
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       sn=not_sn).apply(persistent=True)
    p_fake = discriminator(x_fake, y_fake, maps=maps // 16,
                           n_classes=n_classes, sn=not_sn)
    loss_gen = gan_loss(p_fake)
    # discriminator loss (shares p_fake with the generator graph)
    y_real = nn.Variable([batch_size])
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, y_real, maps=maps // 16,
                           n_classes=n_classes, sn=not_sn)
    loss_dis = gan_loss(p_fake, p_real)
    # generator with fixed value for test (fixed z/y to visualize progress)
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    y_test = nn.Variable.from_numpy_array(
        generate_random_class(n_classes, batch_size))
    x_test = generator(z_test, y_test, maps=maps,
                       n_classes=n_classes, test=True, sn=not_sn)

    # Solver: separate Adam per network, parameters split by scope.
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor (rank 0 only)
    if comm.rank == 0:
        monitor = Monitor(args.monitor_path)
        monitor_loss_gen = MonitorSeries(
            "Generator Loss", monitor, interval=10)
        monitor_loss_dis = MonitorSeries(
            "Discriminator Loss", monitor, interval=10)
        monitor_time = MonitorTimeElapsed(
            "Training Time", monitor, interval=10)
        monitor_image_tile_train = MonitorImageTile("Image Tile Train",
                                                    monitor,
                                                    num_images=args.batch_size,
                                                    interval=1,
                                                    normalize_method=normalize_method)
        monitor_image_tile_test = MonitorImageTile("Image Tile Test",
                                                   monitor,
                                                   num_images=args.batch_size,
                                                   interval=1,
                                                   normalize_method=normalize_method)

    # DataIterator: per-device rng so each worker sees a different stream.
    rng = np.random.RandomState(device_id)
    di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path,
                                args.batch_size, n_classes=args.n_classes,
                                rng=rng)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need for discriminator backward
        solver_dis.zero_grad()
        for _ in range(args.accum_grad):
            # feed x_real and y_real
            x_data, y_data = di.next()
            x_real.d, y_real.d = x_data, y_data.flatten()
            # feed z and y_fake
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_dis.forward(clear_no_need_grad=True)
            # Scale grads by 1/(accum * n_devices) since all_reduce sums.
            loss_dis.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_dis.values()])
        solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need for generator backward
        solver_gen.zero_grad()
        for _ in range(args.accum_grad):
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_gen.forward(clear_no_need_grad=True)
            loss_gen.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_gen.values()])
        solver_gen.update()

        # Synchronize by averaging the weights over devices using allreduce
        # (guards against drift from non-deterministic reductions).
        if i % args.sync_weight_every_itr == 0:
            weights = [v.data for v in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        # Save model and image
        if i % args.save_interval == 0 and comm.rank == 0:
            x_test.forward(clear_buffer=True)
            nn.save_parameters(os.path.join(
                args.monitor_path, "params_{}.h5".format(i)))
            monitor_image_tile_train.add(i, x_fake.d)
            monitor_image_tile_test.add(i, x_test.d)

        # Monitor
        if comm.rank == 0:
            monitor_loss_gen.add(i, loss_gen.d.copy())
            monitor_loss_dis.add(i, loss_dis.d.copy())
            monitor_time.add(i)

    # Final snapshot (reuses the last loop value of i).
    if comm.rank == 0:
        x_test.forward(clear_buffer=True)
        nn.save_parameters(os.path.join(
            args.monitor_path, "params_{}.h5".format(i)))
        monitor_image_tile_train.add(i, x_fake.d)
        monitor_image_tile_test.add(i, x_test.d)
def train():
    """
    Main script.

    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Computate error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * Inplace allreduce (THIS IS THE MAIN difference from a single device training)
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    args = get_args()

    if args.tiny_mode:
        n_train_samples = 100000
    else:
        n_train_samples = 1282167

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    # NOTE: "Paralell" is the actual (misspelled) nnabla API name.
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # workaround to give each worker its own data stream.
    rng = np.random.RandomState(device_id)
    if args.tiny_mode:
        # We use Tiny ImageNet from Stanford CS231N class.
        # (Tiny ImageNet, https://tiny-imagenet.herokuapp.com/)
        # Tiny ImageNet consists of 200 categories, each category has 500 images
        # in training set. The image size is 64x64. To adapt ResNet into 64x64
        # image inputs, the input image size of ResNet is set as 56x56, and
        # the stride in the first conv and the first max pooling are removed.
        # Please check README.
        data = data_iterator_tiny_imagenet(args.batch_size, 'train')
        vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')
        num_classes = 200
    else:
        # We use ImageNet.
        # (ImageNet, https://imagenet.herokuapp.com/)
        # ImageNet consists of 1000 categories, each category has 1280 images
        # in training set. The image size is various. To adapt ResNet into
        # 320x320 image inputs, the input image size of ResNet is set as
        # 224x224. We need to get tar file and create cache file(320x320 images).
        # Please check README.
        data = data_iterator_imagenet(
            args.batch_size, args.train_cachefile_dir, rng=rng)
        vdata = data_iterator_imagenet(
            args.batch_size, args.val_cachefile_dir)
        # Each device validates its own slice of the validation set.
        vdata = vdata.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
        num_classes = 1000

    # Workaround to start with the same initialized weights for all workers.
    np.random.seed(313)
    t_model = get_model(args, num_classes, test=False, tiny=args.tiny_mode)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.unlinked()
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))
    v_model = get_model(args, num_classes, test=True, tiny=args.tiny_mode)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.unlinked()
    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Add parameters to communicator.
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Setting warmup: linearly ramp lr from base_lr towards the full rate
    # over args.warmup_epoch epochs.
    base_lr = args.learning_rate / n_devices
    warmup_iter = int(1. * n_train_samples /
                      args.batch_size / args.accum_grad / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Training loop. vl/ve are scratch variables for all-reducing the
    # device-local validation statistics.
    vl = nn.Variable()
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Save parameters (rank 0 only)
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            ve_local = 0.
            vl_local = 0.
            val_iter_local = args.val_iter // n_devices
            for j in range(val_iter_local):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.d.copy()
                ve_local += v_e.d.copy()
            # Average over devices via all-reduce with division.
            vl_local /= val_iter_local
            vl.d = vl_local
            comm.all_reduce(vl.data, division=True, inplace=True)
            ve_local /= val_iter_local
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)
            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl.d.copy())
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        def accumulate_error(l, e, t_model, t_e):
            # Sum up the loss and top-n error of the last forward pass.
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels = data.next()
            if j != 0:
                # Update e and l according to previous results of forward
                # propagation.
                # The update of last iteration is performed
                # after solver update to avoid unnecessary CUDA synchronization.
                # This is performed after data.next() in order to overlap
                # the data loading and graph execution.
                # TODO: Move this to the bottom of the loop when prefetch
                # data loader is available.
                l, e = accumulate_error(l, e, t_model, t_e)
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)

        # AllReduce: sum gradients across devices.
        params = [x.grad for x in nn.get_parameters().values()]
        comm.all_reduce(params, division=False, inplace=False)

        # Update
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Accumulate errors after solver update
        l, e = accumulate_error(l, e, t_model, t_e)

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        # Synchronize by averaging the weights over devices using allreduce
        if (i + 1) % args.sync_weight_every_itr == 0:
            weights = [x.data for x in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        if device_id == 0:
            monitor_loss.add(i * n_devices, l / args.accum_grad)
            monitor_err.add(i * n_devices, e / args.accum_grad)
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter
        if i * n_devices in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'param_%06d.h5' % (args.max_iter / n_devices)))
def train():
    """Single-device ImageNet / Tiny ImageNet training entry point.

    Builds the training and validation graphs, then runs the iteration
    loop: gradient accumulation, periodic validation, parameter snapshots
    and scheduled learning-rate decay.
    """
    args = get_args()

    # Select the computation backend ('cpu' when no context was given).
    from nnabla.ext_utils import get_extension_context
    extension_module = 'cpu' if args.context is None else args.context
    logger.info("Running in %s" % extension_module)
    ctx = get_extension_context(
        extension_module, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    if args.tiny_mode:
        # Tiny ImageNet (Stanford CS231N): 200 categories x 500 training
        # images at 64x64. The ResNet variant for this mode takes 56x56
        # inputs with the first-conv/max-pool strides removed (see README).
        data = data_iterator_tiny_imagenet(args.batch_size, 'train')
        vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')
        num_classes = 200
    else:
        # Full ImageNet: 1000 categories served from 320x320 cache files,
        # ResNet input size 224x224 (see README for cache creation).
        data = data_iterator_imagenet(args.batch_size, args.train_cachefile_dir)
        vdata = data_iterator_imagenet(args.batch_size, args.val_cachefile_dir)
        num_classes = 1000

    # Training graph; pred stays persistent so the unlinked top-n error
    # graph can read its buffer after backward.
    t_model = get_model(args, num_classes, test=False, tiny=args.tiny_mode)
    t_model.pred.persistent = True
    t_pred2 = t_model.pred.unlinked()
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))

    # Validation graph (test mode), same pred-reuse trick.
    v_model = get_model(args, num_classes, test=True, tiny=args.tiny_mode)
    v_model.pred.persistent = True
    v_pred2 = v_model.pred.unlinked()
    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Monitors.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=10)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time", monitor, interval=10)

    for itr in range(args.max_iter):
        # Periodic parameter snapshot.
        if itr % args.model_save_interval == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % itr))

        # Periodic validation (skipped at iteration 0).
        if itr % args.val_interval == 0 and itr != 0:
            vloss_sum, verr_sum = 0.0, 0.0
            for _ in range(args.val_iter):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vloss_sum += v_model.loss.d
                verr_sum += v_e.d
            monitor_vloss.add(itr, vloss_sum / args.val_iter)
            monitor_verr.add(itr, verr_sum / args.val_iter)
            monitor_vtime.add(itr)

        # Training step with gradient accumulation.
        loss_sum, err_sum = 0.0, 0.0
        solver.zero_grad()
        for step in range(args.accum_grad):
            images, labels = data.next()
            if step != 0:
                # Read back the previous step's loss/error here — after
                # data.next() so data loading overlaps graph execution, and
                # never right after forward to avoid an extra CUDA sync.
                # The final step is read after solver.update() below.
                # TODO: move to the bottom of the loop once a prefetching
                # data loader is available.
                loss_sum += t_model.loss.d
                err_sum += t_e.d
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # accumulates gradients
            t_e.forward(clear_buffer=True)

        solver.weight_decay(args.weight_decay)
        solver.update()
        # Fold in the last accumulation step after the update.
        loss_sum += t_model.loss.d
        err_sum += t_e.d

        monitor_loss.add(itr, loss_sum / args.accum_grad)
        monitor_err.add(itr, err_sum / args.accum_grad)
        monitor_time.add(itr)

        # Scheduled learning-rate decay.
        if itr in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)

    nn.save_parameters(
        os.path.join(args.model_save_path, 'param_%06d.h5' % args.max_iter))