コード例 #1
0
def main():
    """Command-line entry point: train the SqueezeDet model.

    Reads ``--data`` (dataset root) and ``--batch_size`` from the docopt
    usage string in the module docstring, builds readers for
    ``train.brick`` and ``val.brick`` under the dataset root, and fits the
    module for 50 epochs on four GPUs, checkpointing after every epoch.

    On ``KeyboardInterrupt`` the current parameters are written to a
    ``squeezeDet-<stamp>-9999.params`` file, where ``<stamp>`` is the last
    five characters of the current ``time.time()`` string.
    """
    setup_logger()
    options = docopt.docopt(__doc__)
    data_root = options['--data']
    batch_size = int(options['--batch_size'])

    # Training and validation readers share the same batch size.
    train_iter = Reader(os.path.join(data_root, 'train.brick'),
                        batch_size=batch_size)
    val_iter = Reader(os.path.join(data_root, 'val.brick'),
                      batch_size=batch_size)
    # The prefetching wrapper is what is fed to fit(); the raw training
    # reader is still the one handed to build_module below.
    pre_iter = mx.io.PrefetchingIter([train_iter])

    model = SqueezeDet()
    module = build_module(
        model.error,
        'squeezeDetMX',
        train_iter,
        ctx=[mx.gpu(device_id) for device_id in range(4)])

    try:
        # Built inside the try block so an interrupt raised while the
        # callbacks/metrics are constructed is handled like the original.
        fit_options = dict(
            train_data=pre_iter,
            eval_data=val_iter,
            num_epoch=50,
            batch_end_callback=mx.callback.Speedometer(batch_size, 10),
            eval_metric=metric.CompositeEvalMetric(
                metrics=[BboxError(), ClassError(), IOUError()]),
            epoch_end_callback=mx.callback.do_checkpoint('squeezeDetMX', 1))
        module.fit(**fit_options)
    except KeyboardInterrupt:
        # Keep whatever was learned so far under a pseudo-unique name.
        stamp = str(time.time())[-5:]
        module.save_params('squeezeDet-{}-9999.params'.format(stamp))
コード例 #2
0
ファイル: det_solver.py プロジェクト: xiaohedu/dspnet
    def fit(self, train_data, eval_data=None,
            eval_metric='acc',
            grad_req='write',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            logger=None):
        """Train the detection network and optionally validate every epoch.

        For each epoch in ``range(self.begin_epoch, self.num_epoch)`` the
        method iterates ``train_data``, re-binds ``self.symbol`` into a new
        executor for every batch, runs forward/backward, applies the
        optimizer updater to every argument that has a gradient (except
        ``"bigscore_weight"``), and feeds the outputs to ``MultiBoxMetric``.
        If ``eval_data`` is truthy, a validation pass follows that updates
        ``self.valid_metric`` and a ``DistanceAccuracyMetric``.

        Parameters
        ----------
        train_data : iterator
            Yields ``(batch, _)`` pairs; must expose ``provide_data``,
            ``provide_label`` (two entries: detection, segmentation),
            ``batch_size`` and ``reset()``.
        eval_data : iterator, optional
            Yields ``(batch, fnames)`` pairs; must expose ``batch_size``,
            ``num_samples`` and ``reset()``. ``fnames`` are used to derive
            the disparity image paths read with ``cv2.imread``.
        eval_metric : str
            Accepted for interface compatibility only; the value is
            overwritten and not added to the composite metric.
        grad_req : str
            Passed to ``symbol.bind``; ``'null'`` disables gradient buffers.
        epoch_end_callback : callable, optional
            Called as ``cb(epoch, symbol, arg_params, aux_params)`` after
            each epoch unless ``self.evaluation_only`` is set.
        batch_end_callback : callable, optional
            Called with a ``BatchEndParam`` after every training batch.
            ``None`` (the default) disables the callback.
        kvstore : str
            Unused by this implementation.
        logger : logging-like object, optional
            Defaults to the ``logging`` module.

        Notes
        -----
        When ``self.evaluation_only`` is set the training loop is skipped
        and the process exits via ``exit(0)`` after the first validation
        pass.
        """
        global outimgiter
        if logger is None:
            logger = logging
        # Use the resolved logger so a caller-supplied logger is honoured.
        logger.info('Start training with %s', str(self.ctx))
        logger.info(str(self.kwargs))
        batch_size = train_data.provide_data[0][1][0]
        # The detection label shape is fixed to (batch, 200, 6) for shape
        # inference.
        arg_shapes, out_shapes, aux_shapes = self.symbol.infer_shape(
            data=tuple(train_data.provide_data[0][1]), label_det=(batch_size,200,6))
        arg_names = self.symbol.list_arguments()
        out_names = self.symbol.list_outputs()
        aux_names = self.symbol.list_auxiliary_states()

        if grad_req != 'null':
            # Allocate gradient buffers for everything except arguments
            # whose name ends in 'data' or 'label'.
            self.grad_params = {}
            for name, shape in zip(arg_names, arg_shapes):
                if not name.endswith(('data', 'label')):
                    self.grad_params[name] = mx.nd.zeros(shape, self.ctx)
        else:
            self.grad_params = None
        self.aux_params = {k : mx.nd.zeros(s, self.ctx) for k, s in zip(aux_names, aux_shapes)}
        data_name = train_data.provide_data[0][0]
        label_name_det = train_data.provide_label[0][0]
        label_name_seg = train_data.provide_label[1][0]
        input_names = [data_name, label_name_det, label_name_seg]

        print(train_data.provide_label)
        # .get() so an unset environment variable does not abort training
        # from what is only a diagnostic print.
        print(os.environ.get("MXNET_CUDNN_AUTOTUNE_DEFAULT"))

        self.optimizer = opt.create(self.optimizer, rescale_grad=(1.0/train_data.batch_size), **(self.kwargs))
        self.updater = get_updater(self.optimizer)
        eval_metric = CustomAccuracyMetric() # metric.create(eval_metric)
        multibox_metric = MultiBoxMetric()

        # Only the multibox metric is registered; eval_metric is created
        # but intentionally left out of the composite.
        eval_metrics = metric.CompositeEvalMetric()
        eval_metrics.add(multibox_metric)

        # begin training
        for epoch in range(self.begin_epoch, self.num_epoch):
            nbatch = 0
            train_data.reset()
            eval_metrics.reset()
            logger.info('learning rate: '+str(self.optimizer.learning_rate))
            for data,_ in train_data:
                if self.evaluation_only:
                    break
                nbatch += 1
                label_shape_det = data.label[0].shape
                label_shape_seg = data.label[1].shape
                self.arg_params[data_name] = mx.nd.array(data.data[0], self.ctx)
                self.arg_params[label_name_det] = mx.nd.array(data.label[0], self.ctx)
                self.arg_params[label_name_seg] = mx.nd.array(data.label[1], self.ctx)
                output_names = self.symbol.list_outputs()

                # A new executor is bound for every batch with the freshly
                # loaded inputs.
                self.executor = self.symbol.bind(self.ctx, self.arg_params,
                    args_grad=self.grad_params, grad_req=grad_req, aux_states=self.aux_params)
                assert len(self.symbol.list_arguments()) == len(self.executor.grad_arrays)
                update_dict = {name: nd for name, nd in zip(self.symbol.list_arguments(),
                    self.executor.grad_arrays) if nd is not None}
                output_dict = {}
                output_buff = {}
                # CPU-side buffers receive a copy of every output.
                for key, arr in zip(self.symbol.list_outputs(), self.executor.outputs):
                    output_dict[key] = arr
                    output_buff[key] = mx.nd.empty(arr.shape, ctx=mx.cpu())

                def stat_helper(name, array):
                    """wrapper for executor callback"""
                    import ctypes
                    from mxnet.ndarray import NDArray
                    from mxnet.base import NDArrayHandle, py_str
                    array = ctypes.cast(array, NDArrayHandle)
                    if 0:
                        array = NDArray(array, writable=False).asnumpy()
                        print (name, array.shape, np.mean(array), np.std(array),
                               ('%.1fms' % (float(time.time()-stat_helper.start_time)*1000)))
                    else:
                        array = NDArray(array, writable=False)
                        array.wait_to_read()
                        elapsed = float(time.time()-stat_helper.start_time)*1000.
                        if elapsed>5:
                            print (name, array.shape, ('%.1fms' % (elapsed,)))
                    stat_helper.start_time=time.time()
                stat_helper.start_time=float(time.time())
                # Per-layer timing monitor; enable for profiling:
                # self.executor.set_monitor_callback(stat_helper)

                tic = time.time()

                self.executor.forward(is_train=True)
                for key in output_dict:
                    output_dict[key].copyto(output_buff[key])

                self.executor.backward()
                for key, arr in update_dict.items():
                    # "bigscore_weight" is deliberately never updated.
                    if key != "bigscore_weight":
                        self.updater(key, arr, self.arg_params[key])

                for output in self.executor.outputs:
                    output.wait_to_read()
                if TIMING:
                    print("%.0fms" % ((time.time()-tic)*1000.,))

                output_dict = dict(zip(output_names, self.executor.outputs))
                pred_det_shape = output_dict["det_out_output"].shape
                # Flatten the two trailing label axes into one.
                label_det = mx.nd.array(data.label[0].reshape((label_shape_det[0],
                                                               label_shape_det[1]*label_shape_det[2])))
                pred_det = mx.nd.array(output_buff["det_out_output"].reshape((pred_det_shape[0],
                    pred_det_shape[1], pred_det_shape[2])))
                if DEBUG:
                    print(data.label[0].asnumpy()[0,:2,:])

                if TIMING:
                    print("%.0fms" % ((time.time()-tic)*1000.,))

                # Zero placeholders stand in for the first two label slots;
                # only label_det carries real ground truth here.
                eval_metrics.get_metric(0).update([mx.nd.zeros(output_buff["cls_prob_output"].shape),
                                        mx.nd.zeros(output_buff["loc_loss_output"].shape),label_det],
                                       [output_buff["cls_prob_output"], output_buff["loc_loss_output"],
                                        output_buff["cls_label_output"]])

                self.executor.outputs[0].wait_to_read()

                # The callback is optional (default None), mirroring the
                # guard used for epoch_end_callback below.
                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch, eval_metric=eval_metrics)
                    batch_end_callback(batch_end_params)

                if TIMING:
                    print("%.0fms" % ((time.time()-tic)*1000.,))

            ##### save snapshot
            if (not self.evaluation_only) and (epoch_end_callback is not None):
                epoch_end_callback(epoch, self.symbol, self.arg_params, self.aux_params)

            names, values = eval_metrics.get()
            for name, value in zip(names,values):
                logger.info("                     --->Epoch[%d] Train-%s=%f", epoch, name, value)

            # evaluation
            if eval_data:
                logger.info(" in eval process...")
                nbatch = 0
                depth_metric = DistanceAccuracyMetric(class_names=self.class_names)
                eval_data.reset()
                eval_metrics.reset()
                self.valid_metric.reset()
                depth_metric.reset()
                timing_results = []
                for data, fnames in eval_data:
                    nbatch += 1
                    label_shape_det = data.label[0].shape
                    self.arg_params[data_name] = mx.nd.array(data.data[0], self.ctx)
                    self.arg_params[label_name_det] = mx.nd.array(data.label[0], self.ctx)
                    self.executor = self.symbol.bind(self.ctx, self.arg_params,
                        args_grad=self.grad_params, grad_req=grad_req, aux_states=self.aux_params)

                    output_names = self.symbol.list_outputs()
                    output_dict = dict(zip(output_names, self.executor.outputs))

                    ############## forward
                    # NOTE(review): validation also runs with is_train=True;
                    # confirm this is required by the loss/target outputs.
                    tic = time.time()
                    self.executor.forward(is_train=True)
                    timing_results.append((time.time()-tic)*1000.)

                    pred_det_shape = output_dict["det_out_output"].shape
                    label_det = mx.nd.array(data.label[0].reshape((label_shape_det[0], label_shape_det[1]*label_shape_det[2])))
                    pred_det = mx.nd.array(output_dict["det_out_output"].reshape((pred_det_shape[0], pred_det_shape[1], pred_det_shape[2])))

                    #### remove invalid boxes
                    out_dets = output_dict["det_out_output"].asnumpy()
                    assert len(out_dets.shape)==3
                    # Fixed-size result buffer filled with -1; kept rows
                    # are written to the front of each image's slot.
                    pred_det = np.zeros((batch_size, 200, 7), np.float32)-1.
                    for idx, out_det in enumerate(out_dets):
                        assert len(out_det.shape)==2
                        out_det = np.expand_dims(out_det, axis=0)
                        # keep rows whose first column is non-negative
                        indices = np.where(out_det[:,:,0]>=0)
                        out_det = np.expand_dims(out_det[indices[0],indices[1],:],axis=0)
                        # keep rows whose second column exceeds 0.25
                        indices = np.where(out_det[:,:,1]>.25)
                        out_det = np.expand_dims(out_det[indices[0],indices[1],:],axis=0)
                        pred_det[idx, :out_det.shape[1], :] = out_det
                        del out_det
                    pred_det = mx.nd.array(pred_det)

                    ##### display results (disabled)
                    if False: # self.evaluation_only:
                        for imgidx in range(eval_data.batch_size):
                            img = np.squeeze(data.data[0].asnumpy()[imgidx,:,:,:])
                            det = pred_det.asnumpy()[imgidx,:,:]
                            ### ground-truth
                            gt = label_det.asnumpy()[imgidx,:].reshape((-1,6))
                            # display result
                            display_img = display_results(img, det, gt, self.class_names)
                            res_fname = fnames[imgidx].replace("SegmentationClass","Results").replace("labelIds","results")
                            if cv2.imwrite(res_fname, display_img):
                                print(res_fname,'saved.')
                            [exit(0) if (cv2.waitKey()&0xff)==27 else None]
                        outimgiter += 1

                    if self.evaluation_only:
                        continue

                    eval_metrics.get_metric(0).update(None,
                                           [output_dict["cls_prob_output"], output_dict["loc_loss_output"],
                                            output_dict["cls_label_output"]])
                    self.valid_metric.update([mx.nd.slice_axis(data.label[0],axis=2,begin=0,end=5)],
                                             [mx.nd.slice_axis(pred_det,axis=2,begin=0,end=6)])
                    disparities = []
                    for imgidx in range(batch_size):
                        dispname = fnames[imgidx].replace("SegmentationClass","Disparity").replace("gtFine_labelTrainIds","disparity")
                        disparities.append(cv2.imread(dispname,-1))
                        # Check the image just read (cv2.imread returns
                        # None on failure), not only the first of the batch.
                        assert disparities[-1] is not None, dispname + " not found."
                    depth_metric.update(mx.nd.array(disparities),[pred_det])

                    det_metric = self.valid_metric
                    det_names, det_values = det_metric.get()
                    depth_names, depth_values = depth_metric.get()
                    print("\r %d/%d speed=%.1fms %.1f%% %s=%.1f %s=%.1f" % \
                          (nbatch*eval_data.batch_size,eval_data.num_samples,
                           math.fsum(timing_results)/float(nbatch),
                           float(nbatch*eval_data.batch_size)*100./float(eval_data.num_samples),
                           det_names[-1],det_values[-1]*100.,
                           depth_names[-1],depth_values[-1]*100.,),end='\r')

                names, values = eval_metrics.get()
                for name, value in zip(names,values):
                    logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)
                logger.info('----------------------------------------------')
                print(' & '.join(names))
                print(' & '.join(map(lambda v:'%.1f'%(v*100.,),values)))
                logger.info('----------------------------------------------')
                names, values = self.valid_metric.get()
                for name, value in zip(names,values):
                    logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)
                logger.info('----------------------------------------------')
                print(' & '.join(names))
                print(' & '.join(map(lambda v:'%.1f'%(v*100.,),values)))
                logger.info('----------------------------------------------')
                names, values = depth_metric.get()
                for name, value in zip(names,values):
                    logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)
                logger.info('----------------------------------------------')
                print(' & '.join(names))
                print(' & '.join(map(lambda v:'%.1f'%(v*100.,),values)))
                logger.info('----------------------------------------------')

                # Evaluation-only runs stop after a single validation pass.
                if self.evaluation_only:
                    exit(0) ## for debugging only
コード例 #3
0
ファイル: multi_eval.py プロジェクト: xiaohedu/dspnet
def evaluate(netname,
             path_imgrec,
             num_classes,
             num_seg_classes,
             mean_pixels,
             data_shape,
             model_prefix,
             epoch,
             ctx=mx.cpu(),
             batch_size=1,
             path_imglist="",
             nms_thresh=0.45,
             force_nms=False,
             ovp_thresh=0.5,
             use_difficult=False,
             class_names=None,
             seg_class_names=None,
             voc07_metric=False):
    """
    Evaluate a network on a validation record file.

    Runs detection, segmentation and depth metrics over every batch of the
    record file, writes per-image result images next to the inputs (paths
    are derived from the file names yielded by the iterator), and logs the
    final metric values.

    Parameters:
    ----------
    netname : str or None
        Network name ending in "det", "seg" or "multi", or None to load
        the symbol from the checkpoint without modifying it
    path_imgrec : str
        path to the record validation file
    num_classes : int
        number of classes, not including background
    num_seg_classes : int
        number of segmentation classes (not used by this function)
    mean_pixels : tuple
        (mean_r, mean_g, mean_b) (not used by this function)
    data_shape : int, str or sequence of int
        height/width as a single int, a comma separated string such as
        "3,512,1024", or a (3, height, width) sequence
    model_prefix : str
        model prefix of saved checkpoint; "_<height>" is appended
    epoch : int
        load model epoch
    ctx : mx.Context or list/tuple of mx.Context
        mx.gpu() or mx.cpu(); when a list/tuple is given only the first
        context is used
    batch_size : int
        validation batch size
    path_imglist : str
        path to the list file to replace labels in record file, optional
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    seg_class_names : list or str
        segmentation class names, forwarded to the IoU metric
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition
        (not used by this function)

    Raises:
    ------
    NotImplementedError
        if netname is not None and has an unsupported suffix
    """
    global outimgiter

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    elif hasattr(data_shape, "split"):
        # Materialize the list: on Python 3 a bare map() object supports
        # neither len() nor indexing, both of which are needed below.
        data_shape = [int(dim) for dim in data_shape.split(",")]
    else:
        # Already a sequence such as (3, height, width).
        data_shape = tuple(int(dim) for dim in data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    model_prefix += '_' + str(data_shape[1])

    # iterator
    eval_iter = MultiTaskRecordIter(path_imgrec,
                                    batch_size,
                                    data_shape,
                                    path_imglist=path_imglist,
                                    enable_aug=False,
                                    **cfg.valid)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network
    if netname is None:
        net = load_net
    elif netname.endswith("det"):
        net = get_det_symbol(netname.split("_")[0],
                             data_shape[1],
                             num_classes=num_classes,
                             nms_thresh=nms_thresh,
                             force_suppress=force_nms)
    elif netname.endswith("seg"):
        net = get_seg_symbol(netname.split("_")[0],
                             data_shape[1],
                             num_classes=num_classes,
                             nms_thresh=nms_thresh,
                             force_suppress=force_nms)
    elif netname.endswith("multi"):
        net = get_multi_symbol(netname.split("_")[0],
                               data_shape[1],
                               num_classes=num_classes,
                               nms_thresh=nms_thresh,
                               force_suppress=force_nms)
    else:
        raise NotImplementedError(
            "unsupported network name: {!r}".format(netname))

    if not 'label_det' in net.list_arguments():
        label_det = mx.sym.Variable(name='label_det')
        net = mx.sym.Group([net, label_det])
    if not 'seg_out_label' in net.list_arguments():
        seg_out_label = mx.sym.Variable(name='seg_out_label')
        net = mx.sym.Group([net, seg_out_label])

    # Accept both a single context (the documented default) and the
    # list/tuple form; only the first device is used.
    if isinstance(ctx, (list, tuple)):
        ctx = ctx[0]
    eval_metric = CustomAccuracyMetric()
    multibox_metric = MultiBoxMetric()
    depth_metric = DistanceAccuracyMetric(class_names=class_names)
    det_metric = MApMetric(ovp_thresh, use_difficult, class_names)
    seg_metric = IoUMetric(class_names=seg_class_names, axis=1)
    eval_metrics = metric.CompositeEvalMetric()
    eval_metrics.add(multibox_metric)
    eval_metrics.add(eval_metric)
    arg_params = {key: val.as_in_context(ctx) for key, val in args.items()}
    aux_params = {key: val.as_in_context(ctx) for key, val in auxs.items()}
    data_name = eval_iter.provide_data[0][0]
    label_name_det = eval_iter.provide_label[0][0]
    label_name_seg = eval_iter.provide_label[1][0]
    # NOTE(review): the executor is bound to the checkpoint symbol, not to
    # the `net` built above -- confirm this is intended.
    symbol = load_net

    # Maps the 19 class indices to the label ids written to the result
    # images; built once because it does not change between images.
    lut = np.zeros(256)
    lut[:19] = np.array([
        7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
        31, 32, 33
    ])

    # evaluation
    logger.info(" in eval process...")
    logger.info(
        str({
            "ovp_thresh": ovp_thresh,
            "nms_thresh": nms_thresh,
            "batch_size": batch_size,
            "force_nms": force_nms,
        }))
    nbatch = 0
    eval_iter.reset()
    eval_metrics.reset()
    det_metric.reset()
    total_time = 0

    for data, fnames in eval_iter:
        nbatch += 1
        label_shape_det = data.label[0].shape
        label_shape_seg = data.label[1].shape
        arg_params[data_name] = mx.nd.array(data.data[0], ctx)
        arg_params[label_name_det] = mx.nd.array(data.label[0], ctx)
        arg_params[label_name_seg] = mx.nd.array(data.label[1], ctx)
        # A new executor is bound for every batch.
        executor = symbol.bind(ctx, arg_params, aux_states=aux_params)

        output_names = symbol.list_outputs()
        output_dict = dict(zip(output_names, executor.outputs))

        cpu_output_array = mx.nd.zeros(output_dict["seg_out_output"].shape)

        ############## monitor status
        def stat_helper(name, array):
            """wrapper for executor callback"""
            import ctypes
            from mxnet.ndarray import NDArray
            from mxnet.base import NDArrayHandle, py_str
            array = ctypes.cast(array, NDArrayHandle)
            if 1:
                array = NDArray(array, writable=False).asnumpy()
                print(name, array.shape, np.mean(array), np.std(array),
                      ('%.1fms' %
                       (float(time.time() - stat_helper.start_time) * 1000)))
            else:
                array = NDArray(array, writable=False)
                array.wait_to_read()
                elapsed = float(time.time() - stat_helper.start_time) * 1000.
                if elapsed > 5:
                    print(name, array.shape, ('%.1fms' % (elapsed, )))
            stat_helper.start_time = time.time()

        stat_helper.start_time = float(time.time())
        # Per-layer monitor; enable for profiling:
        # executor.set_monitor_callback(stat_helper)

        ############## forward
        tic = time.time()
        executor.forward(is_train=True)
        output_dict["seg_out_output"].copyto(cpu_output_array)
        pred_shape = output_dict["seg_out_output"].shape
        label = mx.nd.array(data.label[1].reshape(
            (label_shape_seg[0], label_shape_seg[1] * label_shape_seg[2])))
        # Block until the segmentation output is ready so that toc - tic
        # covers the forward pass.
        output_dict["seg_out_output"].wait_to_read()

        toc = time.time()

        seg_out_output = output_dict["seg_out_output"].asnumpy()

        pred_seg_shape = output_dict["seg_out_output"].shape
        label_det = mx.nd.array(data.label[0].reshape(
            (label_shape_det[0], label_shape_det[1] * label_shape_det[2])))
        label_seg = mx.nd.array(data.label[1].reshape(
            (label_shape_seg[0], label_shape_seg[1] * label_shape_seg[2])),
                                ctx=ctx)
        pred_seg = mx.nd.array(output_dict["seg_out_output"].reshape(
            (pred_seg_shape[0], pred_seg_shape[1],
             pred_seg_shape[2] * pred_seg_shape[3])),
                               ctx=ctx)
        #### remove invalid boxes
        out_det = output_dict["det_out_output"].asnumpy()
        # keep rows whose first column is non-negative
        indices = np.where(out_det[:, :, 0] >= 0)
        out_det = np.expand_dims(out_det[indices[0], indices[1], :], axis=0)
        # keep rows whose second column exceeds 0.1
        indices = np.where(out_det[:, :, 1] > .1)
        out_det = np.expand_dims(out_det[indices[0], indices[1], :], axis=0)
        pred_det = mx.nd.array(out_det)

        ################# display results ####################
        out_img = output_dict["seg_out_output"]
        out_img = mx.nd.split(out_img, axis=0, num_outputs=out_img.shape[0])
        # NOTE(review): after the filtering above out_det has a leading
        # axis of length 1, so out_det[imgidx] only works for imgidx == 0;
        # confirm batch_size > 1 is never used here.
        for imgidx in range(batch_size):
            seg_prob = out_img[imgidx]
            res_img = np.squeeze(seg_prob.asnumpy().argmax(axis=0).astype(
                np.uint8))
            label_img = data.label[1].asnumpy()[imgidx, :, :].astype(np.uint8)
            img = np.squeeze(data.data[0].asnumpy()[imgidx, :, :, :])
            det = out_det[imgidx, :, :]
            gt = label_det.asnumpy()[imgidx, :].reshape((-1, 6))
            # save to results folder for evaluation
            res_fname = fnames[imgidx].replace("SegmentationClass", "results")
            seg_resized = prob_upsampling(seg_prob, target_shape=(1024, 2048))
            seg_resized2 = cv2.LUT(seg_resized, lut)
            cv2.imwrite(res_fname, seg_resized2)
            # display result
            print(fnames[imgidx], np.average(img))
            display_img = display_results(res_img,
                                          np.expand_dims(label_img, axis=0),
                                          img, det, gt, class_names)
            res_fname = fnames[imgidx].replace("SegmentationClass",
                                               "output").replace(
                                                   "labelTrainIds", "output")
            cv2.imwrite(res_fname, display_img)
            # ESC (key code 27) aborts the whole evaluation.
            if (cv2.waitKey() & 0xff) == 27:
                exit(0)
        outimgiter += 1
        ################# display results ####################

        eval_metrics.get_metric(0).update(None, [
            output_dict["cls_prob_output"], output_dict["loc_loss_output"],
            output_dict["cls_label_output"]
        ])
        eval_metrics.get_metric(1).update([label_seg], [pred_seg])
        det_metric.update([mx.nd.slice_axis(data.label[0],axis=2,begin=0,end=5)],
                          [mx.nd.slice_axis(pred_det,axis=2,begin=0,end=6)])
        seg_metric.update([label_seg], [pred_seg])
        disparities = []
        for imgidx in range(batch_size):
            dispname = fnames[imgidx].replace("SegmentationClass",
                                              "Disparity").replace(
                                                  "gtFine_labelTrainIds",
                                                  "disparity")
            print(dispname)
            disparities.append(cv2.imread(dispname, -1))
        depth_metric.update(mx.nd.array(disparities), [pred_det])

        det_names, det_values = det_metric.get()
        seg_names, seg_values = seg_metric.get()
        depth_names, depth_values = depth_metric.get()
        total_time += toc - tic
        print("\r %d/%d %.1f%% speed=%.1fms %s=%.1f %s=%.1f %s=%.1f" % (
            nbatch * eval_iter.batch_size,
            eval_iter.num_samples,
            float(nbatch * eval_iter.batch_size) * 100. /
            float(eval_iter.num_samples),
            total_time * 1000. / nbatch,
            det_names[-1],
            det_values[-1] * 100.,
            seg_names[-1],
            seg_values[-1] * 100.,
            depth_names[-1],
            depth_values[-1] * 100.,
        ),
              end='\r')

    names, values = eval_metrics.get()
    for name, value in zip(names, values):
        logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)
    logger.info('----------------------------------------------')
    names, values = det_metric.get()
    for name, value in zip(names, values):
        logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)
    logger.info('----------------------------------------------')
    logger.info(' & '.join(names))
    logger.info(' & '.join(map(lambda v: '%.1f' % (v * 100., ), values)))
    logger.info('----------------------------------------------')
    names, values = depth_metric.get()
    for name, value in zip(names, values):
        logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)
    logger.info('----------------------------------------------')
    logger.info(' & '.join(names))
    logger.info(' & '.join(map(lambda v: '%.1f' % (v * 100., ), values)))
    logger.info('----------------------------------------------')
    names, values = seg_metric.get()
    for name, value in zip(names, values):
        logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)
    logger.info('----------------------------------------------')
    logger.info(' & '.join(names))
    logger.info(' & '.join(map(lambda v: '%.1f' % (v * 100., ), values)))
コード例 #4
0
ファイル: float16.py プロジェクト: ycwang812/HKUSTLectures
                                  batch_size=BATCH_SIZE,
                                  num_workers=NUM_WORKERS,
                                  last_batch='discard')
# Optionally hybridize the network with static allocation and static shapes.
if SYMBOLIC:
    net.hybridize(static_alloc=True, static_shape=True)

# Cast the network to the configured dtype, then (re)initialize its
# parameters on the target context; force_reinit=True re-initializes even
# if parameters were initialized earlier.
net.cast(TYPECAST)
net.initialize(ctx=ctx, force_reinit=True)

# SGD with momentum; multi_precision=True is requested from the optimizer
# (presumably to keep a higher-precision weight copy for reduced-precision
# training -- confirm against the MXNet optimizer docs).
optimizer = mx.optimizer.SGD(momentum=0.9,
                             learning_rate=.001,
                             multi_precision=True)
trainer = gluon.Trainer(params=net.collect_params(), optimizer=optimizer)
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

# Accuracy and RMSE are tracked together through one composite metric.
metrics = metric.CompositeEvalMetric([metric.Accuracy(), metric.RMSE()])
epochs = 3  # number of passes over train_data
# State flags/counters consumed by the training loop that follows; their
# exact use is not visible in this excerpt.
outputs = False
old_label = False
first_run = True
print_n_sync = 2

for e in range(epochs):
    metrics.reset()
    tick = time.time()
    etic = time.time()
    for i, minibatch in enumerate(train_data):
        if i == 0:
            tick_0 = time.time()

        data = gluon.utils.split_and_load(data=minibatch[0],