コード例 #1
0
def validate(args):
    """Evaluate a trained segmentation model and report its mean IOU.

    Loads parameters from ``args.model_load_path``, builds the model in test
    mode, runs ``args.val_samples // args.batch_size`` validation batches and
    computes a per-batch mean IOU with ``train.compute_miou``.

    Args:
        args: Parsed options. This function reads ``model_load_path``,
            ``val_samples``, ``batch_size``, ``val_dir``, ``val_label_dir``,
            ``num_class`` and ``monitor_path``.

    Returns:
        The per-batch mean IOU averaged over the evaluated batches
        (``0.0`` when no full batch is available).
    """

    # load trained param file
    _ = nn.load_parameters(args.model_load_path)

    # get data iterator
    vdata = data_iterator_segmentation(args.val_samples, args.batch_size,
                                       args.val_dir, args.val_label_dir)

    # get deeplabv3plus model
    v_model = train.get_model(args, test=True)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward

    # Create monitor
    monitor = M.Monitor(args.monitor_path)
    monitor_miou = M.MonitorSeries("mean IOU", monitor, interval=1)

    # Number of batches actually evaluated; used for both the loop and the
    # normalisation so the logged and the returned value always agree.
    n_iter = args.val_samples // args.batch_size

    vmiou = 0.
    # Evaluation loop
    for j in range(n_iter):
        images, labels, masks = vdata.next()
        v_model.image.d = images
        v_model.label.d = labels
        v_model.mask.d = masks
        v_model.pred.forward(clear_buffer=True)
        # Class prediction is the argmax over axis 1 of the network output.
        miou = train.compute_miou(args.num_class, labels,
                                  np.argmax(v_model.pred.d, axis=1), masks)
        vmiou += miou
        print(j, miou)

    # `vmiou` is a sum of per-batch values, so it must be divided by the
    # number of batches (not by the number of samples).
    mean_miou = vmiou / n_iter if n_iter > 0 else 0.0
    monitor_miou.add(0, mean_miou)
    return mean_miou
コード例 #2
0
def train():
    """
    Main script.

    Trains the segmentation model described by the command-line options
    returned from ``get_args()``:

    * loads pretrained parameters (optionally dropping the final logits
      layer when ``args.fine_tune`` is set),
    * sets up a single-device or multi-process (``args.distributed``)
      context,
    * builds a training graph and a test-mode validation graph,
    * runs the training loop with gradient accumulation, periodic
      checkpointing and periodic validation,
    * saves the final parameters and an ``.nnp`` file under
      ``args.model_save_path``.

    Takes no arguments and returns nothing; all results are written through
    monitors and files.
    """

    args = get_args()

    _ = nn.load_parameters(args.pretrained_model_path)
    if args.fine_tune:
        # Remove the final logits parameters from the global registry
        # (presumably so the classifier head is re-created for the new
        # number of classes - confirm against get_model).
        nn.parameter.pop_parameter('decoder/logits/affine/conv/W')
        nn.parameter.pop_parameter('decoder/logits/affine/conv/b')

    n_train_samples = args.train_samples
    n_val_samples = args.val_samples
    distributed = args.distributed
    compute_acc = args.compute_acc

    from nnabla.ext_utils import get_extension_context
    if distributed:
        # Communicator and Context
        extension_module = "cudnn"
        ctx = get_extension_context(
            extension_module, type_config=args.type_config)
        comm = C.MultiProcessDataParalellCommunicator(ctx)
        comm.init()
        n_devices = comm.size
        mpi_rank = comm.rank
        # One device per process: the rank doubles as the device id.
        device_id = mpi_rank
        ctx.device_id = str(device_id)
        nn.set_default_context(ctx)
    else:
        # Get context.
        extension_module = args.context
        if args.context is None:
            extension_module = 'cpu'
        logger.info("Running in %s" % extension_module)
        ctx = get_extension_context(
            extension_module, device_id=args.device_id, type_config=args.type_config)
        nn.set_default_context(ctx)
        n_devices = 1
        # Used below only as the "am I the logging/saving worker" flag.
        device_id = 0

    # training data
    data = data_iterator_segmentation(
        args.train_samples, args.batch_size, args.train_dir,
        args.train_label_dir, target_width=args.image_width,
        target_height=args.image_height)
    # validation data
    vdata = data_iterator_segmentation(
        args.val_samples, args.batch_size, args.val_dir,
        args.val_label_dir, target_width=args.image_width,
        target_height=args.image_height)

    if distributed:
        # Give every worker its own slice of both datasets.
        data = data.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
        vdata = vdata.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
    num_classes = args.num_class

    # Workaround to start with the same initialized weights for all workers.
    np.random.seed(313)
    t_model = get_model(
        args, test=False)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.unlinked()
    # Masked error rate: top-n error weighted by the mask, normalised by the
    # mask sum.
    t_e = F.sum(F.top_n_error(t_pred2, t_model.label, axis=1)
                * t_model.mask) / F.sum(t_model.mask)

    v_model = get_model(
        args, test=True)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.unlinked()
    # Normalise by the *validation* mask (the training mask was used here by
    # mistake, which tied the validation error to unrelated training data).
    v_e = F.sum(F.top_n_error(v_pred2, v_model.label, axis=1)
                * v_model.mask) / F.sum(v_model.mask)

    # Create Solver
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Load checkpoint
    start_point = 0
    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Setting warmup.
    base_lr = args.learning_rate / n_devices
    warmup_iter = int(1. * n_train_samples /
                      args.batch_size / args.accum_grad / n_devices) * args.warmup_epoch
    # Guard against warmup_epoch == 0 (or a tiny dataset), which would
    # otherwise divide by zero.
    if warmup_iter > 0:
        warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    else:
        warmup_slope = 0.
    solver.set_learning_rate(base_lr)

    # Create monitor
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_miou = M.MonitorSeries("mean IOU", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed(
        "Validation time", monitor, interval=1)

    # save_nnp
    contents = save_nnp({'x': v_model.image}, {
                        'y': v_model.pred}, args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Deeplabv3plus_result_epoch0.nnp'), contents, variable_batch_size=False)

    # Intervals are expressed in global iterations; each worker runs
    # 1/n_devices of them. Clamp to 1 so an interval smaller than n_devices
    # does not cause a modulo-by-zero.
    save_interval = max(1, args.model_save_interval // n_devices)
    val_interval = max(1, args.val_interval // n_devices)

    # Training loop
    for i in range(start_point, int(args.max_iter / n_devices)):
        # Save parameters
        if i % save_interval == 0 and device_id == 0:
            save_checkpoint(args.model_save_path, i, solver)
        # Validation
        if i % val_interval == 0 and i != 0:
            vmiou_local = 0.
            val_iter_local = n_val_samples // args.batch_size
            vl_local = nn.NdArray()
            vl_local.zero()
            ve_local = nn.NdArray()
            ve_local.zero()
            for j in range(val_iter_local):
                images, labels, masks = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.mask.d = masks
                v_model.image.data.cast(np.float32, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.data
                ve_local += v_e.data
                # Mean IOU computation
                if compute_acc:
                    vmiou_local += compute_miou(num_classes, labels,
                                                np.argmax(v_model.pred.d, axis=1), masks)

            vl_local /= val_iter_local
            ve_local /= val_iter_local
            if compute_acc:
                vmiou_local /= val_iter_local
            if distributed:
                comm.all_reduce(vl_local, division=True, inplace=True)
                comm.all_reduce(ve_local, division=True, inplace=True)
                if compute_acc:
                    vmiou_ndarray = nn.NdArray.from_numpy_array(
                        np.array(vmiou_local))
                    comm.all_reduce(vmiou_ndarray, division=True, inplace=True)
                    # Copy the cross-worker average back so the monitor
                    # logs the reduced value rather than this worker's own.
                    vmiou_local = float(vmiou_ndarray.data)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl_local.data.copy())
                monitor_verr.add(i * n_devices, ve_local.data.copy())
                if compute_acc:
                    monitor_miou.add(i * n_devices, vmiou_local)
                monitor_vtime.add(i * n_devices)

        # Training
        solver.zero_grad()

        e_acc = nn.NdArray(t_e.shape)
        e_acc.zero()
        l_acc = nn.NdArray(t_model.loss.shape)
        l_acc.zero()
        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels, masks = data.next()
            t_model.image.d = images
            t_model.label.d = labels
            t_model.mask.d = masks
            t_model.image.data.cast(np.float32, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            e_acc += t_e.data
            l_acc += t_model.loss.data

        # AllReduce
        if distributed:
            params = [x.grad for x in nn.get_parameters().values()]
            comm.all_reduce(params, division=False, inplace=False)
            comm.all_reduce(l_acc, division=True, inplace=True)
            comm.all_reduce(e_acc, division=True, inplace=True)
        # Gradients were summed over accum_grad mini-batches; rescale to a mean.
        solver.scale_grad(1./args.accum_grad)
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Linear Warmup
        # NOTE(review): the learning rate set here is overwritten by the
        # poly-decay call at the end of this iteration before the next
        # update, so the warm-up currently has no effect - confirm which
        # schedule is intended.
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if distributed:
            # Synchronize by averaging the weights over devices using allreduce
            if (i+1) % args.sync_weight_every_itr == 0:
                weights = [x.data for x in nn.get_parameters().values()]
                comm.all_reduce(weights, division=True, inplace=True)

        if device_id == 0:
            monitor_loss.add(
                i * n_devices, (l_acc / args.accum_grad).data.copy())
            monitor_err.add(
                i * n_devices, (e_acc / args.accum_grad).data.copy())
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter --> changed to poly learning rate decay policy
        solver.set_learning_rate(base_lr * ((1 - i / args.max_iter)**0.1))

    if device_id == 0:
        nn.save_parameters(os.path.join(args.model_save_path,
                                        'param_%06d.h5' % args.max_iter))

    contents = save_nnp({'x': v_model.image}, {
                        'y': v_model.pred}, args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Deeplabv3plus_result.nnp'), contents, variable_batch_size=False)