Beispiel #1
0
def test_func(args):
    """
     Main function for testing a CGNet model.

     Seeds torch with a freshly drawn random seed, loads the cached dataset
     statistics (or builds them with ``CityscapesTrainInform`` when the cache
     file is missing), constructs the network, optionally restores weights
     from ``args.resume`` and finally calls ``test``.

     param args: global arguments. Attributes read here: ``cuda``, ``gpus``,
         ``inform_data_file``, ``data_dir``, ``classes``, ``dataset_list``,
         ``M``, ``N``, ``save_seg_dir``, ``test_data_list``, ``num_workers``
         and ``resume``. ``args.seed`` is overwritten with the drawn seed.
     return: None
     raises Exception: if ``args.cuda`` is set but torch reports no GPU
     raises SystemExit: with code -1 if the dataset statistics are missing
         and could not be collected
    """
    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "no GPU found or wrong gpu id, please run without --cuda")

    # A new seed is drawn on every call, printed and stored on args.
    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('=====> checking if processed cached_data_file exists')
    if not os.path.isfile(args.inform_data_file):
        # collect mean, std and class-weight information and cache it
        dataCollect = CityscapesTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )
        data = dataCollect.collectDataAndSave()
        if data is None:
            print("error while pickling data, please check")
            # Same effect as exit(-1), without depending on the builtin
            # that the ``site`` module injects.
            raise SystemExit(-1)
    else:
        # NOTE: pickle executes arbitrary code on load; only point
        # inform_data_file at a trusted, locally generated cache.
        with open(args.inform_data_file, "rb") as cache_file:
            data = pickle.load(cache_file)

    model = CGNet.Context_Guided_Network(classes=args.classes,
                                         M=args.M,
                                         N=args.N)
    network_type = "CGNet"
    print("Arch:  CGNet")

    # No loss criterion is built here: ``test`` below is only given the
    # arguments, the loader and the model.
    if args.cuda:
        model = model.cuda()  # using GPU for inference
        cudnn.benchmark = True

    print('Dataset statistics')
    print('mean and std: ', data['mean'], data['std'])
    print('classWeights: ', data['classWeights'])

    if args.save_seg_dir:
        # exist_ok avoids the check-then-create race of exists()+makedirs()
        os.makedirs(args.save_seg_dir, exist_ok=True)

    # test set, one image per batch and in file-list order
    test_dataset = CityscapesTestDataSet(args.data_dir,
                                         args.test_data_list,
                                         mean=data['mean'])
    testLoader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(convert_state_dict(checkpoint['model']))
        else:
            # Only reported; testing continues with the model as constructed.
            print("=====> no checkpoint found at '{}'".format(args.resume))

    print("=====> beginning testing")
    print("test set length: ", len(testLoader))
    test(args, testLoader, model)
Beispiel #2
0
def test_model(args):
    """
     Main function for testing: evaluate one checkpoint, or the best of the
     last ten, and log the result next to the checkpoint.

     With ``args.best`` unset, the model is evaluated once (after loading
     ``args.checkpoint`` if one is given). With ``args.best`` set, the epoch
     number is parsed from the checkpoint file name (the text between the
     first and second ``_`` of the stem), the ten files ``model_<i>.pth`` for
     ``i`` in ``epoch-9 .. epoch`` in the same directory are evaluated, and
     the one with the highest mean IoU is reported.

     The mean IoU and per-class IoU are appended to
     ``test_<checkpoint stem>.txt`` (or ``test_best<epoch>.txt``) in the
     checkpoint's directory; the file name is also stored in
     ``args.logFile``.

     param args: global arguments. Attributes read here: ``cuda``, ``gpus``,
         ``model``, ``classes``, ``save``, ``save_seg_dir``, ``dataset``,
         ``num_workers``, ``best`` and ``checkpoint``.
     return: None
     raises Exception: if ``args.cuda`` is set but torch reports no GPU
     raises ValueError: if ``args.best`` is set without ``args.checkpoint``
     raises FileNotFoundError: if ``args.checkpoint`` is given but is not a
         file
    """
    print(args)

    # "Best of the last ten" derives every file name from the checkpoint
    # path, so fail before any expensive setup if it is missing.
    if args.best and not args.checkpoint:
        raise ValueError(
            "args.best requires args.checkpoint to be set")

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "no GPU found or wrong gpu id, please run without --cuda")

    # build the model
    model = build_model(args.model, num_classes=args.classes)

    if args.cuda:
        model = model.cuda()  # using GPU for inference
        cudnn.benchmark = True

    if args.save:
        # exist_ok avoids the check-then-create race of exists()+makedirs()
        os.makedirs(args.save_seg_dir, exist_ok=True)

    # load the test set (the first returned value is not needed here)
    _datas, testLoader = build_dataset_test(args.dataset, args.num_workers)

    # Both modes reject a checkpoint path that does not point to a file.
    if args.checkpoint and not os.path.isfile(args.checkpoint):
        print("=====> no checkpoint found at '{}'".format(args.checkpoint))
        raise FileNotFoundError("no checkpoint found at '{}'".format(
            args.checkpoint))

    # An unset checkpoint behaves like an empty path for the log location.
    checkpoint_path = args.checkpoint or ''

    if not args.best:
        if args.checkpoint:
            print("=====> loading checkpoint '{}'".format(args.checkpoint))
            checkpoint = torch.load(args.checkpoint)
            model.load_state_dict(convert_state_dict(checkpoint['model']))

        print("=====> beginning validation")
        print("validation set length: ", len(testLoader))
        mIOU_val, per_class_iu = test(args, testLoader, model)
        print(mIOU_val)
        print(per_class_iu)

        checkpoint_stem = os.path.splitext(
            os.path.basename(checkpoint_path))[0]
        args.logFile = 'test_' + checkpoint_stem + '.txt'

    # Get the best test result among the last 10 model records.
    else:
        dirname, basename = os.path.split(args.checkpoint)
        epoch = int(os.path.splitext(basename)[0].split('_')[1])
        epochs = list(range(epoch - 9, epoch + 1))
        mIOU_vals = []
        per_class_ius = []
        for i in epochs:
            basename = 'model_' + str(i) + '.pth'
            resume = os.path.join(dirname, basename)
            checkpoint = torch.load(resume)
            # NOTE(review): unlike the single-checkpoint path, the state
            # dict is loaded without convert_state_dict here - confirm
            # this is intended.
            model.load_state_dict(checkpoint['model'])
            print("=====> beginning test the" + basename)
            print("validation set length: ", len(testLoader))
            mIOU_val_0, per_class_iu_0 = test(args, testLoader, model)
            mIOU_vals.append(mIOU_val_0)
            per_class_ius.append(per_class_iu_0)

        best_pos = int(np.argmax(mIOU_vals))
        index = epochs[best_pos]
        print("The best mIoU among the last 10 models is", index)
        print(mIOU_vals)
        per_class_iu = per_class_ius[best_pos]
        mIOU_val = np.max(mIOU_vals)
        print(mIOU_val)
        print(per_class_iu)

        args.logFile = 'test_' + 'best' + str(index) + '.txt'

    # Save the result
    logFileLoc = os.path.join(os.path.dirname(checkpoint_path), args.logFile)

    # Always record the result: a fresh file gets exactly one entry, an
    # existing non-empty file gets the new entry on its own line.
    needs_separator = (os.path.isfile(logFileLoc)
                       and os.path.getsize(logFileLoc) > 0)
    with open(logFileLoc, 'a') as logger:
        if needs_separator:
            logger.write("\n")
        logger.write("Mean IoU: %.4f" % mIOU_val)
        logger.write("\nPer class IoU: ")
        logger.write("".join("%.4f\t" % class_iu
                             for class_iu in per_class_iu))
def test_func(args):
    """
     Main function for testing a MobileNetV3 ("SMALL" mode) model.

     Seeds torch with a freshly drawn random seed, loads the cached dataset
     statistics (or builds them with ``CityscapesTrainInform`` when the cache
     file is missing), constructs the network, optionally restores weights
     from ``args.resume`` and finally calls ``test`` with the chosen device
     and the dataset statistics.

     param args: global arguments. Attributes read here: ``cuda``, ``gpus``,
         ``inform_data_file``, ``data_dir``, ``classes``, ``dataset_list``,
         ``test_data_list``, ``num_workers`` and ``resume``. ``args.seed``
         is overwritten with the drawn seed.
     return: None
     raises Exception: if ``args.cuda`` is set but torch reports no GPU
     raises SystemExit: with code -1 if the dataset statistics are missing
         and could not be collected
    """
    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "no GPU found or wrong gpu id, please run without --cuda")

    # ``device`` is handed to test() at the end, so it must be bound on the
    # CPU path as well, not only when --cuda is given.
    device = 'cuda' if args.cuda else 'cpu'

    # A new seed is drawn on every call, printed and stored on args.
    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('=====> checking if processed cached_data_file exists')
    if not os.path.isfile(args.inform_data_file):
        # collect mean, std and class-weight information and cache it
        dataCollect = CityscapesTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )
        data = dataCollect.collectDataAndSave()
        if data is None:
            print("error while pickling data, please check")
            # Same effect as exit(-1), without depending on the builtin
            # that the ``site`` module injects.
            raise SystemExit(-1)
    else:
        # NOTE: pickle executes arbitrary code on load; only point
        # inform_data_file at a trusted, locally generated cache.
        with open(args.inform_data_file, "rb") as cache_file:
            data = pickle.load(cache_file)

    model = MobileNetV3(model_mode="SMALL", num_classes=args.classes)

    network_type = "MobileNetV3"
    print("Arch:  MobileNetV3")

    if args.cuda:
        model = model.to(device)  # using GPU for inference
        cudnn.benchmark = True

    print('Dataset statistics')
    print('mean and std: ', data['mean'], data['std'])
    print('classWeights: ', data['classWeights'])

    # test set, one image per batch and in file-list order
    test_dataset = CityscapesTestDataSet(args.data_dir,
                                         args.test_data_list,
                                         mean=data['mean'])
    testLoader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(convert_state_dict(checkpoint['model']))
        else:
            # Only reported; testing continues with the model as constructed.
            print("=====> no checkpoint found at '{}'".format(args.resume))

    print("=====> beginning testing")
    print("test set length: ", len(testLoader))
    test(args, testLoader, model, device, data)
Beispiel #4
0
def test_func(args):
    """
     Main function for testing a CGNet model.

     Seeds torch with a freshly drawn random seed, loads the cached dataset
     statistics (or builds them with ``StatisticalInformDataset`` when the
     cache file is missing), constructs the network, optionally restores
     weights from ``args.resume`` and finally calls ``test``.

     param args: global arguments. Attributes read here: ``cuda``, ``gpus``,
         ``cached_data_file``, ``data_dir``, ``classes``, ``M``, ``N``,
         ``save_seg_dir``, ``test_data_list``, ``num_workers`` and
         ``resume``. ``args.seed`` is overwritten with the drawn seed.
     return: None
     raises Exception: if ``args.cuda`` is set but torch reports no GPU
     raises SystemExit: with code -1 if the dataset statistics are missing
         and could not be collected
    """
    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # A new seed is drawn on every call, printed and stored on args.
    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('checking if processed cached_data_file exists or not')
    if not os.path.isfile(args.cached_data_file):
        dataCollect = StatisticalInformDataset(args.data_dir, args.classes,
                                               args.cached_data_file)
        data = dataCollect.collectDataAndSave()
        if data is None:
            print("error while pickling data, please check")
            # Same effect as exit(-1), without depending on the builtin
            # that the ``site`` module injects.
            raise SystemExit(-1)
    else:
        # NOTE: pickle executes arbitrary code on load; only point
        # cached_data_file at a trusted, locally generated cache.
        with open(args.cached_data_file, "rb") as cache_file:
            data = pickle.load(cache_file)

    # load the model
    print('====> Building network')
    model = net.Context_Guided_Network(classes=args.classes,
                                       M=args.M,
                                       N=args.N)
    network_type = "CGNet"
    print("Arch:  CGNet")

    if args.cuda:
        model = model.cuda()  # single-card testing
        cudnn.benchmark = True

    print('Dataset statistics')
    print('mean and std: ', data['mean'], data['std'])
    print('classWeights: ', data['classWeights'])

    if args.save_seg_dir:
        # exist_ok avoids the check-then-create race of exists()+makedirs()
        os.makedirs(args.save_seg_dir, exist_ok=True)

    # test set, one image per batch and in file-list order
    test_dataset = CityscapesTestDataSet(args.data_dir,
                                         args.test_data_list,
                                         mean=data['mean'])
    testLoader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    print("=====> load pretrained model")
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(convert_state_dict(checkpoint['model']))
        else:
            # Only reported; testing continues with the model as constructed.
            print("=> no checkpoint found at '{}'".format(args.resume))

    # evaluate on test set
    print("=====> beginning test")
    print("test set length: ", len(testLoader))
    test(args, testLoader, model)