def validate(args):
    """Evaluate a trained model on the validation data and report mean IOU.

    Loads parameters from ``args.model_load_path``, builds the model in test
    mode, runs ``args.val_samples // args.batch_size`` batches through it and
    accumulates the per-batch mean IOU returned by ``train.compute_miou``.
    The average is written to the "mean IOU" monitor series (index 0) under
    ``args.monitor_path`` and each per-batch value is printed.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``model_load_path``, ``val_samples``, ``batch_size``, ``val_dir``,
            ``val_label_dir``, ``num_class`` and ``monitor_path`` (plus
            whatever ``train.get_model`` reads).

    Returns:
        The mean of the per-batch mean-IOU values over all evaluated batches,
        i.e. the same value that is logged to the monitor. ``0.0`` when no
        full batch is available.
    """
    # load trained param file
    nn.load_parameters(args.model_load_path)

    # get data iterator
    vdata = data_iterator_segmentation(
        args.val_samples, args.batch_size, args.val_dir, args.val_label_dir)

    # get deeplabv3plus model
    v_model = train.get_model(args, test=True)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward

    # Create monitor
    monitor = M.Monitor(args.monitor_path)
    monitor_miou = M.MonitorSeries("mean IOU", monitor, interval=1)

    # Only full batches are evaluated; the same count is used for averaging
    # so the logged and returned values always agree with the loop below.
    num_batches = args.val_samples // args.batch_size

    vmiou = 0.
    # Evaluation loop
    for j in range(num_batches):
        images, labels, masks = vdata.next()
        v_model.image.d = images
        v_model.label.d = labels
        v_model.mask.d = masks
        v_model.pred.forward(clear_buffer=True)
        # Predicted class per pixel = argmax over axis 1 of the prediction.
        miou = train.compute_miou(
            args.num_class, labels, np.argmax(v_model.pred.d, axis=1), masks)
        vmiou += miou
        print(j, miou)

    # vmiou is a sum of per-batch values, so it must be divided by the number
    # of batches (not the number of samples) to obtain the mean.
    mean_miou = vmiou / num_batches if num_batches > 0 else 0.0
    monitor_miou.add(0, mean_miou)

    return mean_miou
def train():
    """
    Main script.

    Steps, in order:

    * Parse arguments and load the pretrained parameters; when
      ``args.fine_tune`` is set, drop the final logits conv weights/bias so
      they are re-created.
    * Set up the default context (a multi-process data-parallel communicator
      when ``args.distributed`` is set, otherwise a single device).
    * Build training/validation data iterators (sliced per device when
      distributed), the training and validation graphs and a Momentum solver.
    * Optionally resume from ``args.checkpoint``.
    * Run the training loop with gradient accumulation, periodic checkpoint
      saving, periodic validation and monitor logging.
    * Save the final parameters (device 0 only) and the NNP file.
    """
    from nnabla.ext_utils import get_extension_context
    import nnabla.monitor as M

    args = get_args()
    _ = nn.load_parameters(args.pretrained_model_path)
    if args.fine_tune:
        # Remove the last layer's parameters so they are initialized afresh.
        nnabla.parameter.pop_parameter('decoder/logits/affine/conv/W')
        nnabla.parameter.pop_parameter('decoder/logits/affine/conv/b')

    n_train_samples = args.train_samples
    n_val_samples = args.val_samples
    distributed = args.distributed
    compute_acc = args.compute_acc

    if distributed:
        # Communicator and Context
        extension_module = "cudnn"
        ctx = get_extension_context(
            extension_module, type_config=args.type_config)
        comm = C.MultiProcessDataParalellCommunicator(ctx)
        comm.init()
        n_devices = comm.size
        mpi_rank = comm.rank
        device_id = mpi_rank
        ctx.device_id = str(device_id)
        nn.set_default_context(ctx)
    else:
        # Get context.
        extension_module = args.context
        if args.context is None:
            extension_module = 'cpu'
        logger.info("Running in %s" % extension_module)
        ctx = get_extension_context(
            extension_module, device_id=args.device_id,
            type_config=args.type_config)
        nn.set_default_context(ctx)
        n_devices = 1
        device_id = 0

    # training data
    data = data_iterator_segmentation(
        args.train_samples, args.batch_size, args.train_dir,
        args.train_label_dir, target_width=args.image_width,
        target_height=args.image_height)
    # validation data
    vdata = data_iterator_segmentation(
        args.val_samples, args.batch_size, args.val_dir, args.val_label_dir,
        target_width=args.image_width, target_height=args.image_height)

    if distributed:
        data = data.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
        vdata = vdata.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
    num_classes = args.num_class

    # Workaround to start with the same initialized weights for all workers.
    np.random.seed(313)
    t_model = get_model(args, test=False)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.unlinked()
    # Masked error rate: per-pixel top-n error weighted by the mask and
    # normalized by the mask sum.
    t_e = F.sum(F.top_n_error(t_pred2, t_model.label, axis=1)
                * t_model.mask) / F.sum(t_model.mask)

    v_model = get_model(args, test=True)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.unlinked()
    # Normalize by the *validation* mask; using the training mask here would
    # scale the validation error by an unrelated batch's mask sum.
    v_e = F.sum(F.top_n_error(v_pred2, v_model.label, axis=1)
                * v_model.mask) / F.sum(v_model.mask)

    # Create Solver
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Load checkpoint
    start_point = 0
    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Setting warmup.
    base_lr = args.learning_rate / n_devices
    warmup_iter = int(1. * n_train_samples /
                      args.batch_size / args.accum_grad / n_devices) * args.warmup_epoch
    # Guard against warmup_iter == 0 (e.g. warmup_epoch == 0), which would
    # otherwise raise ZeroDivisionError; a zero slope means "no warmup".
    if warmup_iter > 0:
        warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    else:
        warmup_slope = 0.0
    solver.set_learning_rate(base_lr)

    # Create monitor
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_miou = M.MonitorSeries("mean IOU", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed(
        "Validation time", monitor, interval=1)

    # save_nnp
    contents = save_nnp({'x': v_model.image}, {
                        'y': v_model.pred}, args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Deeplabv3plus_result_epoch0.nnp'),
              contents, variable_batch_size=False)

    # Training loop
    for i in range(start_point, int(args.max_iter / n_devices)):
        # Save parameters
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            save_checkpoint(args.model_save_path, i, solver)

        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            vmiou_local = 0.
            val_iter_local = n_val_samples // args.batch_size
            vl_local = nn.NdArray()
            vl_local.zero()
            ve_local = nn.NdArray()
            ve_local.zero()
            for j in range(val_iter_local):
                images, labels, masks = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.mask.d = masks
                v_model.image.data.cast(np.float32, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.data
                ve_local += v_e.data
                # Mean IOU computation
                if compute_acc:
                    vmiou_local += compute_miou(
                        num_classes, labels,
                        np.argmax(v_model.pred.d, axis=1), masks)

            vl_local /= val_iter_local
            ve_local /= val_iter_local
            if compute_acc:
                vmiou_local /= val_iter_local
                vmiou_ndarray = nn.NdArray.from_numpy_array(
                    np.array(vmiou_local))
            if distributed:
                comm.all_reduce(vl_local, division=True, inplace=True)
                comm.all_reduce(ve_local, division=True, inplace=True)
                if compute_acc:
                    comm.all_reduce(vmiou_ndarray, division=True, inplace=True)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl_local.data.copy())
                monitor_verr.add(i * n_devices, ve_local.data.copy())
                if compute_acc:
                    # Log the (possibly all-reduced) value rather than this
                    # worker's local one, consistent with loss and error.
                    monitor_miou.add(
                        i * n_devices, vmiou_ndarray.data.copy())
                monitor_vtime.add(i * n_devices)

        # Training
        solver.zero_grad()

        e_acc = nn.NdArray(t_e.shape)
        e_acc.zero()
        l_acc = nn.NdArray(t_model.loss.shape)
        l_acc.zero()
        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels, masks = data.next()
            t_model.image.d = images
            t_model.label.d = labels
            t_model.mask.d = masks
            t_model.image.data.cast(np.float32, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            e_acc += t_e.data
            l_acc += t_model.loss.data

        # AllReduce
        if distributed:
            params = [x.grad for x in nn.get_parameters().values()]
            comm.all_reduce(params, division=False, inplace=False)
            comm.all_reduce(l_acc, division=True, inplace=True)
            comm.all_reduce(e_acc, division=True, inplace=True)
        # Gradients were summed over accum_grad mini-batches; rescale them.
        solver.scale_grad(1./args.accum_grad)
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Linear Warmup
        # NOTE(review): the poly decay call at the end of this iteration sets
        # the learning rate again before the next solver.update(), so as
        # written this warmup value appears to be overwritten - confirm intent.
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if distributed:
            # Synchronize by averaging the weights over devices using allreduce
            if (i+1) % args.sync_weight_every_itr == 0:
                weights = [x.data for x in nn.get_parameters().values()]
                comm.all_reduce(weights, division=True, inplace=True)

        if device_id == 0:
            monitor_loss.add(
                i * n_devices, (l_acc / args.accum_grad).data.copy())
            monitor_err.add(
                i * n_devices, (e_acc / args.accum_grad).data.copy())
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter --> changed to poly learning rate decay policy
        solver.set_learning_rate(base_lr * ((1 - i / args.max_iter)**0.1))

    if device_id == 0:
        nn.save_parameters(os.path.join(args.model_save_path,
                                        'param_%06d.h5' % args.max_iter))

    contents = save_nnp({'x': v_model.image}, {
                        'y': v_model.pred}, args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Deeplabv3plus_result.nnp'),
              contents, variable_batch_size=False)