示例#1
0
def test_model(args):
    """
    Evaluate a CGNet model on the CamVid test list and print the scores.

    The function loads (or builds and caches) the dataset statistics,
    constructs the network and the weighted loss, optionally restores the
    weights from ``args.resume``, runs ``test`` once and prints the mean IoU
    and the per-class IoU it returns.

    args:
       args: global arguments. Read here: inform_data_file, data_dir,
             classes, dataset_list, cuda, gpus, M, N, ignore_label,
             test_data_list, batch_size, num_workers, resume.
             ``args.seed`` is overwritten with a freshly drawn random seed.
    """
    print("=====> Check if the cached file exists ")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        # Collect mean, std and class-weight information and cache it.
        dataCollect = CamVidTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print('Error while pickling data. Please check.')
            # Same effect as exit(-1), without depending on the `site` builtin.
            raise SystemExit(-1)
    else:
        print("%s exists" % (args.inform_data_file))
        # NOTE(review): pickle executes arbitrary code on load; only point
        # inform_data_file at cache files produced by this project.
        with open(args.inform_data_file, "rb") as inform_file:
            datas = pickle.load(inform_file)

    print(args)
    global network_type

    if args.cuda:
        print("=====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    cudnn.enabled = True

    M = args.M
    N = args.N
    model = CGNet.Context_Guided_Network(classes=args.classes, M=M, N=N)
    network_type = "CGNet"
    print("=====> current architeture:  CGNet_M%sN%s" % (M, N))
    total_paramters = netParams(model)
    print("the number of parameters: " + str(total_paramters))
    print("data['classWeights']: ", datas['classWeights'])
    weight = torch.from_numpy(datas['classWeights'])
    print("=====> Dataset statistics")
    print("mean and std: ", datas['mean'], datas['std'])

    # define optimization criteria
    criteria = CrossEntropyLoss2d(weight, args.ignore_label)
    if args.cuda:
        model = model.cuda()
        criteria = criteria.cuda()

    # Load the test set. Evaluation must see every sample exactly once, so
    # the loader neither shuffles nor drops the last (possibly smaller) batch.
    testLoader = data.DataLoader(CamVidValDataSet(args.data_dir,
                                                  args.test_data_list,
                                                  f_scale=1,
                                                  mean=datas['mean']),
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 pin_memory=True,
                                 drop_last=False)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            # Without --cuda, remap GPU-saved tensors onto the CPU so the
            # checkpoint can still be loaded on a machine without a GPU.
            checkpoint = torch.load(
                args.resume, map_location=None if args.cuda else 'cpu')
            model.load_state_dict(checkpoint['model'])
        else:
            # Best effort: evaluation continues with the freshly built model.
            print("=====> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    print("=====> beginning test")
    print("length of test set:", len(testLoader))
    mIOU_val, per_class_iu = test(args, testLoader, model, criteria)
    print(mIOU_val)
    print(per_class_iu)
示例#2
0
def train_model(args):
    """
    Train a CGNet model on CamVid, logging each epoch and saving checkpoints.

    The function loads (or builds and caches) the dataset statistics, builds
    the network, the weighted loss and the data loaders, optionally resumes
    from ``args.resume``, then runs ``train`` for every epoch in
    ``[start_epoch, args.max_epochs)``. ``val`` is run every 50th epoch.
    Checkpoints are written every 20th epoch and for the final epochs.

    args:
       args: global arguments. Read here: input_size, inform_data_file,
             data_dir, classes, dataset_list, cuda, gpus, M, N, ignore_label,
             savedir, dataset, batch_size, train_type, train_data_list,
             random_scale, random_mirror, val_data_list, num_workers, resume,
             logFile, lr, max_epochs.
             Written here: ``args.seed`` (fresh random seed),
             ``args.gpu_nums`` and ``args.savedir`` (extended with a
             run-specific sub-directory, which is created if missing).
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("=====> checking if inform_data_file exists")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        # Collect mean, std and class-weight information and cache it.
        dataCollect = CamVidTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            # Same effect as exit(-1), without depending on the `site` builtin.
            raise SystemExit(-1)
    else:
        print("find file: ", str(args.inform_data_file))
        # NOTE(review): pickle executes arbitrary code on load; only point
        # inform_data_file at cache files produced by this project.
        with open(args.inform_data_file, "rb") as inform_file:
            datas = pickle.load(inform_file)

    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    args.seed = random.randint(1, 10000)
    print("====> Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    cudnn.enabled = True
    M = args.M
    N = args.N
    print("=====> building network")
    model = CGNet.Context_Guided_Network(classes=args.classes, M=M, N=N)
    network_type = "CGNet"
    print("=====> current architeture:  CGNet")

    print("=====> computing network parameters")
    total_paramters = netParams(model)
    print("the number of parameters: " + str(total_paramters))

    print("data['classWeights']: ", datas['classWeights'])
    print('=====> Dataset statistics')
    print('mean and std: ', datas['mean'], datas['std'])

    # define optimization criteria
    weight = torch.from_numpy(datas['classWeights'])
    criteria = CrossEntropyLoss2d(weight, args.ignore_label)

    if args.cuda:
        criteria = criteria.cuda()
        args.gpu_nums = 1
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = torch.nn.DataParallel(model).cuda()
        else:
            print("single GPU for training")
            model = model.cuda()
    else:
        # args.gpu_nums is used in the save-directory name below; on a CPU
        # run keep a caller-supplied value if present, otherwise record 0
        # instead of failing with AttributeError.
        args.gpu_nums = getattr(args, 'gpu_nums', 0)

    args.savedir = (args.savedir + args.dataset + '/' + network_type + "_M" +
                    str(M) + 'N' + str(N) + 'bs' + str(args.batch_size) +
                    'gpu' + str(args.gpu_nums) + "_" + str(args.train_type) +
                    '/')
    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    # Training loader: crop size, scaling and mirroring are passed to the
    # dataset itself.
    trainLoader = data.DataLoader(CamVidDataSet(args.data_dir,
                                                args.train_data_list,
                                                crop_size=input_size,
                                                scale=args.random_scale,
                                                mirror=args.random_mirror,
                                                mean=datas['mean']),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
    # Validation loader: one image per batch.
    valLoader = data.DataLoader(CamVidValDataSet(args.data_dir,
                                                 args.val_data_list,
                                                 f_scale=1,
                                                 mean=datas['mean']),
                                batch_size=1,
                                shuffle=True,
                                num_workers=args.num_workers,
                                pin_memory=True,
                                drop_last=True)

    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            # Without --cuda, remap GPU-saved tensors onto the CPU so the
            # checkpoint can still be loaded on a machine without a GPU.
            checkpoint = torch.load(
                args.resume, map_location=None if args.cuda else 'cpu')
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            print("=====> loading checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            # Best effort: training starts from scratch at epoch 0.
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr, (0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=5e-4)

    logFileLoc = args.savedir + args.logFile
    # Append to an existing log; otherwise create it and write the header.
    is_new_log = not os.path.isfile(logFileLoc)
    # The context manager closes the log even if training raises.
    with open(logFileLoc, 'w' if is_new_log else 'a') as logger:
        if is_new_log:
            logger.write("Parameters: %s" % (str(total_paramters)))
            # Header names match the values written on validation epochs.
            # Non-validation epochs write four fields (no mIOU (val)).
            logger.write(
                "\n%s\t\t%s\t\t%s\t\t%s\t\t%s\t\t" %
                ('Epoch', 'Loss(Tr)', 'mIOU (tr)', 'mIOU (val)', 'lr'))
        logger.flush()

        print('=====> beginning training')
        for epoch in range(start_epoch, args.max_epochs):
            # training
            lossTr, per_class_iu_tr, mIOU_tr, lr = train(
                args, trainLoader, model, criteria, optimizer, epoch)

            # validation (every 50th epoch only)
            if epoch % 50 == 0:
                mIOU_val, per_class_iu = val(args, valLoader, model, criteria)
                logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" %
                             (epoch, lossTr, mIOU_tr, mIOU_val, lr))
                logger.flush()
                print("epoch: " + str(epoch) + ' Details')
                print(
                    "\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\t lr= %.6f"
                    % (epoch, lossTr, mIOU_tr, mIOU_val, lr))
            else:
                logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.7f" %
                             (epoch, lossTr, mIOU_tr, lr))
                logger.flush()
                print("Epoch : " + str(epoch) + ' Details')
                print(
                    "\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t lr= %.6f"
                    % (epoch, lossTr, mIOU_tr, lr))

            # Save the model: every 20th epoch, and every epoch once
            # epoch > max_epochs - 10.
            if epoch > args.max_epochs - 10 or epoch % 20 == 0:
                model_file_name = (args.savedir + '/model_' + str(epoch + 1) +
                                   '.pth')
                state = {"epoch": epoch + 1, "model": model.state_dict()}
                torch.save(state, model_file_name)