Example #1
def build_dataset_train(dataset, input_size, batch_size, train_type, random_scale, random_mirror, num_workers):
    data_dir = os.path.join('./dataset/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    train_data_list = os.path.join(data_dir, dataset + '_' + train_type + '_list.txt')
    val_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file collects the mean, std and class weight information
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        datas = pickle.load(open(inform_data_file, "rb"))

    if dataset == "cityscapes":

        trainLoader = data.DataLoader(
            CityscapesDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                              mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CityscapesValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True,
            drop_last=True)

        return datas, trainLoader, valLoader

    elif dataset == "camvid":

        trainLoader = data.DataLoader(
            CamVidDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                          mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)
        if train_type == 'trainval':
            val_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt')

        valLoader = data.DataLoader(
            CamVidValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True)

        return datas, trainLoader, valLoader
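
A minimal call sketch for the builder above; the dataset name, crop size and batch size are illustrative assumptions, and it presumes the list files under ./dataset/cityscapes/ already exist:

# hedged usage sketch (illustrative values only)
datas, trainLoader, valLoader = build_dataset_train(
    dataset='cityscapes',        # or 'camvid'
    input_size=(512, 1024),      # crop size handed to the training dataset
    batch_size=8,
    train_type='train',          # selects cityscapes_train_list.txt
    random_scale=True,
    random_mirror=True,
    num_workers=4)

print(datas['mean'], datas['std'], datas['classWeights'])
for batch in trainLoader:        # batch layout depends on CityscapesDataSet.__getitem__
    break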
Example #2
def build_dataset_test(dataset, num_workers, none_gt=False):
    data_dir = os.path.join('./dataset/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    test_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file collects the mean, std and class weight information
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset)
        
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        datas = pickle.load(open(inform_data_file, "rb"))

    if dataset == "cityscapes":
        # for cityscapes, if test on validation set, set none_gt to False
        # if test on the test set, set none_gt to True
        if none_gt:
            testLoader = data.DataLoader(
                CityscapesTestDataSet(data_dir, test_data_list, mean=datas['mean']),
                batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
        else:
            test_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
            testLoader = data.DataLoader(
                CityscapesValDataSet(data_dir, test_data_list, mean=datas['mean']),
                batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)

        return datas, testLoader

    elif dataset == "camvid":

        testLoader = data.DataLoader(
            CamVidValDataSet(data_dir, test_data_list, mean=datas['mean']),
            batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)

        return datas, testLoader
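
# A hedged sketch of how the none_gt switch above is meant to be used (both calls are illustrative):
# evaluate on the Cityscapes val split, where ground truth is available
datas, valLoader = build_dataset_test('cityscapes', num_workers=4, none_gt=False)
# run on the Cityscapes test split, which ships without ground truth
datas, testLoader = build_dataset_test('cityscapes', num_workers=4, none_gt=True)
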
def build_dataset_predict(dataset_path, dataset, num_workers, none_gt=False):
    data_dir = os.path.join('./dataset/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    inform_data_file = os.path.join('./dataset/inform/',
                                    dataset + '_inform.pkl')

    # inform_data_file collects the mean, std and class weight information
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(
                data_dir,
                19,
                train_set_file=dataset_list,
                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir,
                                            11,
                                            train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        elif dataset == 'custom_dataset':
            dataCollect = CustomTrainInform(data_dir,
                                            2,
                                            train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository now supports two datasets: cityscapes and camvid, %s is not included"
                % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        datas = pickle.load(open(inform_data_file, "rb"))

    if dataset == "custom_dataset":
        testLoader = CustomPredictDataSet(dataset_path, mean=datas['mean'])
        return datas, testLoader
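
Unlike the builders above, build_dataset_predict returns the dataset object itself rather than a DataLoader, and only the 'custom_dataset' branch returns anything; a hedged sketch, with an assumed image-folder path:

# './samples' is a placeholder path to a folder of images for inference
datas, predictSet = build_dataset_predict('./samples', 'custom_dataset', num_workers=0)
for item in predictSet:          # iterate CustomPredictDataSet directly
    pass                         # run the model on each preprocessed image here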
Example #4
def test_model(args):
    """
    main function for testing 
    args:
       args: global arguments
    """
    print("=====> Check if the cached file exists ")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        dataCollect = CamVidTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )  # collect mean, std and class weight information
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print('Error while pickling data. Please check.')
            exit(-1)
    else:
        print("%s exists" % (args.inform_data_file))
        datas = pickle.load(open(args.inform_data_file, "rb"))

    print(args)
    global network_type

    if args.cuda:
        print("=====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    cudnn.enabled = True

    M = args.M
    N = args.N
    model = CGNet.Context_Guided_Network(classes=args.classes, M=M, N=N)
    network_type = "CGNet"
    print("=====> current architeture:  CGNet_M%sN%s" % (M, N))
    total_parameters = netParams(model)
    print("the number of parameters: " + str(total_parameters))
    print("datas['classWeights']: ", datas['classWeights'])
    weight = torch.from_numpy(datas['classWeights'])
    print("=====> Dataset statistics")
    print("mean and std: ", datas['mean'], datas['std'])

    # define optimization criteria
    criteria = CrossEntropyLoss2d(weight, args.ignore_label)
    if args.cuda:
        model = model.cuda()
        criteria = criteria.cuda()

    # load the test set (this ToTensor transform is defined but never passed to the dataset)
    train_transform = transforms.Compose([transforms.ToTensor()])
    testLoader = data.DataLoader(CamVidValDataSet(args.data_dir,
                                                  args.test_data_list,
                                                  f_scale=1,
                                                  mean=datas['mean']),
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True,
                                 drop_last=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #model.load_state_dict(convert_state_dict(checkpoint['model']))
            model.load_state_dict(checkpoint['model'])
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    print("=====> beginning test")
    print("length of test set:", len(testLoader))
    mIOU_val, per_class_iu = test(args, testLoader, model, criteria)
    print(mIOU_val)
    print(per_class_iu)
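
test_model expects an argparse-style namespace; the sketch below lists the fields it reads, with every value an illustrative assumption rather than a repository default:

from argparse import Namespace

args = Namespace(
    inform_data_file='./dataset/inform/camvid_inform.pkl',
    data_dir='./dataset/camvid',
    classes=11,
    dataset_list='camvid_trainval_list.txt',
    test_data_list='./dataset/camvid/camvid_test_list.txt',
    cuda=True, gpus='0',
    M=3, N=21,                             # CGNet block counts (assumed)
    ignore_label=11,                       # label id ignored by the loss (assumed)
    batch_size=1, num_workers=4,
    resume='./checkpoint/model_300.pth')   # trained checkpoint path (assumed)
test_model(args)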
Example #5
            test_data_list = os.path.join(data_dir,
                                          dataset + '_val' + '_list.txt')
            testLoader = data.DataLoader(CityscapesValDataSet(
                data_dir, test_data_list, mean=datas['mean']),
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=num_workers,
                                         pin_memory=True)

        return datas, testLoader

    elif dataset == "camvid":

        testLoader = data.DataLoader(CamVidValDataSet(data_dir,
                                                      test_data_list,
                                                      mean=datas['mean']),
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=num_workers,
                                     pin_memory=True)

        return datas, testLoader


if __name__ == "__main__":
    dataCollection = CamVidTrainInform(
        "/home/mohamed/RINet/dataset/camvid",
        classes=11,
        train_set_file="camvid_trainval_list.txt",
        inform_data_file="inform/camvid_inform.pkl")
    data = dataCollection.collectDataAndSave()
Example #6
def build_dataset_train(dataset, input_size, batch_size, train_type, random_scale, random_mirror, num_workers, args):
    data_dir = os.path.join('/media/sdb/datasets/segment/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    train_data_list = os.path.join(data_dir, dataset + '_' + train_type + '_list.txt')
    val_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    if dataset == "cityscapes":
        # inform_data_file collects the mean, std and class weight information
        if not os.path.isfile(inform_data_file):
            print("%s is not found" % (inform_data_file))
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                    inform_data_file=inform_data_file)
            datas = dataCollect.collectDataAndSave()
            if datas is None:
                print("error while pickling data. Please check.")
                exit(-1)
        else:
            print("find file: ", str(inform_data_file))
            datas = pickle.load(open(inform_data_file, "rb"))

        trainLoader = data.DataLoader(
            CityscapesDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                              mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CityscapesValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True,
            drop_last=False)
    elif dataset == "camvid":
        # inform_data_file collects the mean, std and class weight information
        if not os.path.isfile(inform_data_file):
            print("%s is not found" % (inform_data_file))
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list, inform_data_file=inform_data_file)
            datas = dataCollect.collectDataAndSave()
            if datas is None:
                print("error while pickling data. Please check.")
                exit(-1)
        else:
            print("find file: ", str(inform_data_file))
            datas = pickle.load(open(inform_data_file, "rb"))

        trainLoader = data.DataLoader(
            CamVidDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                          mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CamVidValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True)
    elif dataset == "ade20k":
        inform_data_file = os.path.join('./dataset/inform/', 'cityscapes_inform.pkl')
        datas = pickle.load(open(inform_data_file, "rb"))

        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        data_kwargs = {'transform': input_transform, 'base_size': args.input_size[0]+40,
                       'crop_size': args.input_size[0], 'encode': False}
        train_dataset = ADE20KSegmentation(split='train', mode='train', **data_kwargs)
        val_dataset = ADE20KSegmentation(split='val', mode='val', **data_kwargs)

        train_sampler = make_data_sampler(train_dataset, shuffle=True, distributed=False)
        train_batch_sampler = make_batch_data_sampler(train_sampler, batch_size)
        val_sampler = make_data_sampler(val_dataset, shuffle=False, distributed=False)
        val_batch_sampler = make_batch_data_sampler(val_sampler, batch_size)

        trainLoader = data.DataLoader(dataset=train_dataset,
                                 batch_sampler=train_batch_sampler,
                                 num_workers=args.num_workers,
                                 pin_memory=True)
        valLoader = data.DataLoader(dataset=val_dataset,
                                     batch_sampler=val_batch_sampler,
                                     num_workers=args.num_workers,
                                     pin_memory=True)
    else:
        raise NotImplementedError(
            "This repository now supports datasets: cityscapes, camvid and ade20k, %s is not included" % dataset)
    return datas, trainLoader, valLoader
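
This variant also consults an args object on the ade20k branch (args.input_size and args.num_workers), and as written it reuses the cityscapes inform pickle for datas there; a brief, hedged call sketch with assumed values:

from argparse import Namespace

# only the fields the ade20k branch reads are filled in here
args = Namespace(input_size=(480, 480), num_workers=4)
datas, trainLoader, valLoader = build_dataset_train(
    'ade20k', input_size=(480, 480), batch_size=16, train_type='train',
    random_scale=True, random_mirror=True, num_workers=4, args=args)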
Example #7
def train_model(args):
    """
    args:
       args: global arguments
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("=====> checking if inform_data_file exists")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        dataCollect = CamVidTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )  # collect mean, std and class weight information
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(args.inform_data_file))
        datas = pickle.load(open(args.inform_data_file, "rb"))

    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    args.seed = random.randint(1, 10000)
    print("====> Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    cudnn.enabled = True
    M = args.M
    N = args.N
    print("=====> building network")
    model = CGNet.Context_Guided_Network(classes=args.classes, M=M, N=N)
    network_type = "CGNet"
    print("=====> current architeture:  CGNet")

    print("=====> computing network parameters")
    total_parameters = netParams(model)
    print("the number of parameters: " + str(total_parameters))

    print("datas['classWeights']: ", datas['classWeights'])
    print('=====> Dataset statistics')
    print('mean and std: ', datas['mean'], datas['std'])

    # define optimization criteria
    weight = torch.from_numpy(datas['classWeights'])
    criteria = CrossEntropyLoss2d(weight, args.ignore_label)

    if args.cuda:
        criteria = criteria.cuda()
        args.gpu_nums = 1
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = torch.nn.DataParallel(model).cuda()
        else:
            print("single GPU for training")
            model = model.cuda()

    args.savedir = (args.savedir + args.dataset + '/' + network_type + "_M" +
                    str(M) + 'N' + str(N) + 'bs' + str(args.batch_size) +
                    'gpu' + str(args.gpu_nums) + "_" + str(args.train_type) +
                    '/')
    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    # Data augmentation (scale/mirror) is handled inside CamVidDataSet below;
    # this ToTensor transform is defined here but never passed to it.
    train_transform = transforms.Compose([transforms.ToTensor()])
    trainLoader = data.DataLoader(CamVidDataSet(args.data_dir,
                                                args.train_data_list,
                                                crop_size=input_size,
                                                scale=args.random_scale,
                                                mirror=args.random_mirror,
                                                mean=datas['mean']),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
    valLoader = data.DataLoader(CamVidValDataSet(args.data_dir,
                                                 args.val_data_list,
                                                 f_scale=1,
                                                 mean=datas['mean']),
                                batch_size=1,
                                shuffle=True,
                                num_workers=args.num_workers,
                                pin_memory=True,
                                drop_last=True)

    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            #model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("=====> loading checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s" % (str(total_paramters)))
        logger.write(
            "\n%s\t\t%s\t\t%s\t\t%s\t\t%s\t\t" %
            ('Epoch', 'Loss(Tr)', 'mIOU (tr)', 'mIOU (val)', 'lr'))
    logger.flush()

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr, (0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=5e-4)

    print('=====> beginning training')
    for epoch in range(start_epoch, args.max_epochs):
        #training
        lossTr, per_class_iu_tr, mIOU_tr, lr = train(args, trainLoader, model,
                                                     criteria, optimizer,
                                                     epoch)

        #validation
        if epoch % 50 == 0:
            mIOU_val, per_class_iu = val(args, valLoader, model, criteria)
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" %
                         (epoch, lossTr, mIOU_tr, mIOU_val, lr))
            logger.flush()
            print("epoch: " + str(epoch) + ' Details')
            print(
                "\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\t lr= %.6f"
                % (epoch, lossTr, mIOU_tr, mIOU_val, lr))
        else:
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.7f" %
                         (epoch, lossTr, mIOU_tr, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print(
                "\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t lr= %.6f"
                % (epoch, lossTr, mIOU_tr, lr))

        #save the model
        model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
        state = {"epoch": epoch + 1, "model": model.state_dict()}
        if epoch > args.max_epochs - 10:
            torch.save(state, model_file_name)
        elif not epoch % 20:
            torch.save(state, model_file_name)

    logger.close()
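
train_model reads a similar namespace; note that input_size is an 'h,w' string here (it is parsed with split(',')), unlike the tuple taken by the dataset builders. A compact, hedged sketch with illustrative values:

from argparse import Namespace

args = Namespace(
    input_size='360,480',                  # parsed with split(','), so a string
    inform_data_file='./dataset/inform/camvid_inform.pkl',
    data_dir='./dataset/camvid', classes=11,
    dataset_list='camvid_trainval_list.txt',
    train_data_list='./dataset/camvid/camvid_train_list.txt',
    val_data_list='./dataset/camvid/camvid_val_list.txt',
    cuda=True, gpus='0', M=3, N=21, ignore_label=11,
    batch_size=8, num_workers=4, random_scale=True, random_mirror=True,
    lr=1e-3, max_epochs=300, resume='', savedir='./checkpoint/',
    dataset='camvid', train_type='trainval', logFile='log.txt')
train_model(args)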