Example 1
    def loadData(self):
        # First check if unzipped directory exists
        DatasetDir = os.path.join(ptUtils.expandTilde(self.DataDir),
                                  os.path.splitext(self.FileName)[0])
        if not os.path.exists(DatasetDir):
            DataPath = os.path.join(ptUtils.expandTilde(self.DataDir),
                                    self.FileName)
            if not os.path.exists(DataPath):
                if self.isDownload:
                    print('[ INFO ]: Downloading', DataPath)
                    ptUtils.downloadFile(self.DataURL, DataPath)

                if not os.path.exists(DataPath):  # Not downloaded
                    raise RuntimeError('Specified data path does not exist: ' +
                                       DataPath)
            # Unzip
            with zipfile.ZipFile(DataPath, 'r') as File2Unzip:
                print('[ INFO ]: Unzipping.')
                File2Unzip.extractall(ptUtils.expandTilde(self.DataDir))

        FilesPath = os.path.join(DatasetDir, 'val/')
        if self.isTrainData:
            FilesPath = os.path.join(DatasetDir, 'train/')

        CameraIdxStr = '*'
        if self.CameraIdx >= 0 and self.CameraIdx <= 4:
            CameraIdxStr = str(self.CameraIdx).zfill(2)

        print('[ INFO ]: Loading data for camera {}.'.format(CameraIdxStr))

        # NOTE: without recursive=True, glob treats '**' like '*', i.e. it
        # matches exactly one directory level (here: the per-scene folders)
        self.RGBList = glob.glob(FilesPath + '/**/frame_*_cam_' +
                                 CameraIdxStr + '_color.*')
        self.RGBList.sort()
        self.InstMaskList = glob.glob(FilesPath + '/**/frame_*_cam_' +
                                      CameraIdxStr + '_binmask.*')
        self.InstMaskList.sort()
        self.NOCSList = glob.glob(FilesPath + '/**/frame_*_cam_' +
                                  CameraIdxStr + '_nocs.*')
        self.NOCSList.sort()

        # glob.glob() returns a (possibly empty) list, never None, so test
        # the lists for emptiness instead of comparing against None
        if not self.RGBList or not self.InstMaskList or not self.NOCSList:
            raise RuntimeError('[ ERR ]: No files found during data loading.')

        if (len(self.RGBList) != len(self.InstMaskList) or
                len(self.InstMaskList) != len(self.NOCSList)):
            raise RuntimeError('[ ERR ]: Data corrupted. Sizes do not match')

        print('[ INFO ]: Found {} items in dataset.'.format(len(self)))
        # A None DataLimit keeps everything: list[:None] is the whole list
        DatasetLength = self.DataLimit
        self.RGBList = self.RGBList[:DatasetLength]
        self.InstMaskList = self.InstMaskList[:DatasetLength]
        self.NOCSList = self.NOCSList[:DatasetLength]
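
The check-download-unzip-glob pattern above can be reproduced with the standard library alone. A minimal standalone sketch, assuming a plain zip archive; fetch_dataset, the URL, and the file names below are hypothetical, not part of the original class:

import glob
import os
import urllib.request
import zipfile

def fetch_dataset(data_dir, file_name, url):
    # Download and unzip an archive if its extracted directory is missing
    dataset_dir = os.path.join(os.path.expanduser(data_dir),
                               os.path.splitext(file_name)[0])
    if not os.path.exists(dataset_dir):
        archive = os.path.join(os.path.expanduser(data_dir), file_name)
        if not os.path.exists(archive):
            urllib.request.urlretrieve(url, archive)  # fetch the zip once
        with zipfile.ZipFile(archive, 'r') as zf:
            zf.extractall(os.path.expanduser(data_dir))
    return dataset_dir

# Hypothetical usage:
#   root = fetch_dataset('~/data', 'hands.zip', 'https://example.com/hands.zip')
#   rgb = sorted(glob.glob(os.path.join(root, 'train', '*_color.png')))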
Example 2
    def loadData(self):
        # First check if unzipped directory exists
        DatasetDir = os.path.join(ptUtils.expandTilde(self.DataDir), os.path.splitext(self.FileName)[0])
        if not os.path.exists(DatasetDir):
            DataPath = os.path.join(ptUtils.expandTilde(self.DataDir), self.FileName)
            if not os.path.exists(DataPath):
                if self.isDownload:
                    print('[ INFO ]: Downloading', DataPath)
                    ptUtils.downloadFile(self.DataURL, DataPath)

                if not os.path.exists(DataPath): # Not downloaded
                    raise RuntimeError('Specified data path does not exist: ' + DataPath)
            # Unzip
            with zipfile.ZipFile(DataPath, 'r') as File2Unzip:
                print('[ INFO ]: Unzipping.')
                File2Unzip.extractall(ptUtils.expandTilde(self.DataDir))

        FilesPath = os.path.join(DatasetDir, 'val/')
        if self.isTrainData:
            FilesPath = os.path.join(DatasetDir, 'train/')

        self.RGBList = glob.glob(FilesPath + '/*_VertexColors.png')
        self.RGBList.sort()
        self.InstMaskList = glob.glob(FilesPath + '/*_InstanceMask.png')
        self.InstMaskList.sort()
        self.NOCSList = glob.glob(FilesPath + '/*_NOCS.png')
        self.NOCSList.sort()

        # glob.glob() returns a list, never None; test for emptiness instead
        if not self.RGBList or not self.InstMaskList or not self.NOCSList:
            raise RuntimeError('[ ERR ]: No files found during data loading.')

        if len(self.RGBList) != len(self.InstMaskList) or len(self.InstMaskList) != len(self.NOCSList):
            raise RuntimeError('[ ERR ]: Data corrupted. Sizes do not match')

        DatasetLength = self.DataLimit
        self.RGBList = self.RGBList[:DatasetLength]
        self.InstMaskList = self.InstMaskList[:DatasetLength]
        self.NOCSList = self.NOCSList[:DatasetLength]

        self.RGBs = []
        self.InstMasks = []
        self.NOCSs = []
        if self.LoadMemory:
            print('[ INFO ]: Loading all images to memory.')
            for RGBFile, InstMaskFile, NOCSFile in zip(self.RGBList, self.InstMaskList, self.NOCSList):
                RGB, InstMask, NOCS = self.loadImages(RGBFile, InstMaskFile, NOCSFile)
                self.RGBs.append(RGB)
                self.InstMasks.append(InstMask)
                self.NOCSs.append(NOCS)
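
The LoadMemory branch above trades startup time and RAM for faster access later. A minimal sketch of how the eager/lazy split would typically surface in __getitem__; this method is hypothetical and assumes the same LoadMemory flag, lists, and loadImages helper as above:

    def __getitem__(self, idx):
        if self.LoadMemory:
            # Eager path: images were already decoded in loadData()
            return self.RGBs[idx], self.InstMasks[idx], self.NOCSs[idx]
        # Lazy path: decode from disk on every access
        return self.loadImages(self.RGBList[idx], self.InstMaskList[idx],
                               self.NOCSList[idx])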
Example 3
    def loadCheckpoint(self, Path=None, Device='cpu'):
        if Path is None:
            self.ExptDirPath = os.path.join(ptUtils.expandTilde(self.Config.Args.output_dir), self.Config.Args.expt_name)
            print('[ INFO ]: Loading from latest checkpoint.')
            CheckpointDict = ptUtils.loadLatestPyTorchCheckpoint(self.ExptDirPath, map_location=Device)
        else:  # Load from the specified path (the 'latest' case is handled above)
            print('[ INFO ]: Loading from checkpoint {}'.format(Path))
            # NOTE: unlike the branch above, Device is not forwarded here
            CheckpointDict = ptUtils.loadPyTorchCheckpoint(Path)

        self.load_state_dict(CheckpointDict['ModelStateDict'])
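
A sketch of what the ptUtils checkpoint helpers presumably delegate to, using only PyTorch's own torch.save/torch.load API; save_checkpoint and load_checkpoint are illustrative names, and only the 'ModelStateDict' key is taken from the example:

import torch

def save_checkpoint(model, path):
    # Store the weights under the key loadCheckpoint() reads back
    torch.save({'ModelStateDict': model.state_dict()}, path)

def load_checkpoint(model, path, device='cpu'):
    # map_location remaps the saved tensors onto the requested device
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['ModelStateDict'])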
Example 4
    def __init__(self, InputArgs=None, isPrint=True):
        self.Parser = argparse.ArgumentParser(description='Parse arguments for a PyTorch neural network.', fromfile_prefix_chars='@')

        # Search params
        self.Parser.add_argument('--learning-rate', help='Choose the learning rate.', required=False, default=0.001,
                            type=RestrictedFloat_N10_100)
        self.Parser.add_argument('--batch-size', help='Choose mini-batch size.', choices=range(1, 4097), metavar='1..4096',
                            required=False, default=128, type=int)

        # Machine-specific params
        self.Parser.add_argument('--expt-name', help='Provide a name for this experiment.')
        self.Parser.add_argument('--input-dir', help='Provide the input directory where datasets are stored.')
        # -----
        self.Parser.add_argument('--output-dir',
                            help='Provide the *absolute* output directory where checkpoints, logs, and other output will be stored (under expt_name).')
        self.Parser.add_argument('--rel-output-dir',
                            help='Provide the *relative* (pwd or config file) output directory where checkpoints, logs, and other output will be stored (under expt_name).')
        # -----
        self.Parser.add_argument('--epochs', help='Choose number of epochs.', choices=range(1, 10001), metavar='1..10000',
                            required=False, default=10, type=int)
        self.Parser.add_argument('--save-freq', help='Choose epoch frequency to save checkpoints. Zero (0) will save only at the end of training [not recommended].', choices=range(0, 10001), metavar='0..10000',
                            required=False, default=5, type=int)

        self.Args, _ = self.Parser.parse_known_args(InputArgs)

        if self.Args.expt_name is None:
            raise RuntimeError('No experiment name (--expt-name) provided.')

        if self.Args.rel_output_dir is None and self.Args.output_dir is None:
            raise RuntimeError('At least one of --output-dir or --rel-output-dir is required.')

        if self.Args.rel_output_dir is not None: # Relative path takes precedence
            if self.Args.output_dir is not None:
                print('[ INFO ]: Relative path takes precedence over the absolute path.')
            DirPath = os.getcwd()  # Fallback: resolve relative to the working directory
            # InputArgs may be None (argparse then reads sys.argv), so guard the iteration
            for Arg in (InputArgs if InputArgs is not None else sys.argv[1:]):
                if Arg.startswith('@'):  # Config file passed; path should be relative to it
                    DirPath = os.path.abspath(os.path.dirname(ptUtils.expandTilde(Arg[1:])))
                    break
            self.Args.output_dir = os.path.join(DirPath, self.Args.rel_output_dir)
            print('[ INFO ]: Converted relative path {} to absolute path {}'.format(self.Args.rel_output_dir, self.Args.output_dir))

        # Logging directory and file
        self.ExptDirPath = os.path.join(ptUtils.expandTilde(self.Args.output_dir), self.Args.expt_name)
        if not os.path.exists(self.ExptDirPath):
            os.makedirs(self.ExptDirPath)

        self.ExptLogFile = os.path.join(self.ExptDirPath, self.Args.expt_name + '_' + ptUtils.getTimeString('humanlocal') + '.log')
        # Create (or truncate) the log file so the loggers below can write to it
        with open(self.ExptLogFile, 'w+', newline=''):
            pass

        sys.stdout = ptUtils.ptLogger(sys.stdout, self.ExptLogFile)
        sys.stderr = ptUtils.ptLogger(sys.stderr, self.ExptLogFile)

        if isPrint:
            print('-'*60)
            ArgsDict = vars(self.Args)
            for Arg in ArgsDict:
                if ArgsDict[Arg] is not None:
                    print('{:<15}:   {:<50}'.format(Arg, ArgsDict[Arg]))
                else:
                    print('{:<15}:   {:<50}'.format(Arg, 'NOT DEFINED'))
            print('-'*60)
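
Because the parser is built with fromfile_prefix_chars='@', arguments can also come from a config file, which is what the relative-path logic above keys on. A sketch of the convention; the file name and values are hypothetical, and by default argparse expects one argument per line:

# expt.config -- one argument per line, as argparse's fromfile default expects
--expt-name
MyRun
--rel-output-dir
runs

# Hypothetical invocation:
#   python train.py @expt.config --batch-size 64
# The parser splices the file's contents in place of '@expt.config', and the
# class then resolves --rel-output-dir against the config file's directory.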
Example 5
    def loadData(self):
        self.FrameFiles = {}
        # First check if unzipped directory exists
        DatasetDir = os.path.join(ptUtils.expandTilde(self.DataDir),
                                  os.path.splitext(self.FileName)[0])
        if not os.path.exists(DatasetDir):
            DataPath = os.path.join(ptUtils.expandTilde(self.DataDir),
                                    self.FileName)
            if not os.path.exists(DataPath):
                if self.isDownload:
                    print('[ INFO ]: Downloading', DataPath)
                    ptUtils.downloadFile(self.DataURL, DataPath)

                if not os.path.exists(DataPath):  # Not downloaded
                    raise RuntimeError('Specified data path does not exist: ' +
                                       DataPath)
            # Unzip
            with zipfile.ZipFile(DataPath, 'r') as File2Unzip:
                print('[ INFO ]: Unzipping.')
                File2Unzip.extractall(ptUtils.expandTilde(self.DataDir))

        FilesPath = os.path.join(DatasetDir, 'val/')
        if self.isTrainData:
            FilesPath = os.path.join(DatasetDir, 'train/')

        GlobPrepend = '_'.join(str(i) for i in self.FrameLoadStr)
        GlobCache = os.path.join(DatasetDir, 'glob_' + GlobPrepend + '.cache')

        if os.path.exists(GlobCache):
            print('[ INFO ]: Loading from glob cache:', GlobCache)
            with open(GlobCache, 'rb') as fp:
                for Str in self.FrameLoadStr:
                    self.FrameFiles[Str] = pickle.load(fp)
        else:
            print('[ INFO ]: Saving to glob cache:', GlobCache)

            for Str in self.FrameLoadStr:
                print(os.path.join(FilesPath, '*' + Str + '*.*'))
                self.FrameFiles[Str] = glob.glob(
                    os.path.join(FilesPath, '*' + Str + '*.*'))
                self.FrameFiles[Str].sort()

            with open(GlobCache, 'wb') as fp:
                for Str in self.FrameLoadStr:
                    pickle.dump(self.FrameFiles[Str], fp)

        FrameFilesLengths = []
        for K, CurrFrameFiles in self.FrameFiles.items():
            if not CurrFrameFiles:  # Covers both None and an empty list
                raise RuntimeError('No files found during data loading for {}'.format(K))
            FrameFilesLengths.append(len(CurrFrameFiles))

        if len(set(FrameFilesLengths)) != 1:
            raise RuntimeError('Data corrupted. Sizes do not match: {}'.format(FrameFilesLengths))

        TotSize = len(self)
        print('[ INFO ]: Found {} items in dataset.'.format(TotSize))
        # Here DataLimit is interpreted as a percentage of the dataset
        DatasetLength = math.ceil((self.DataLimit / 100) * TotSize)

        if DatasetLength >= TotSize:  # math.ceil() never returns None, so compare sizes
            print('[ INFO ]: Loading all items.')
        else:
            print('[ INFO ]: Loading only {} items.'.format(DatasetLength))
        for K in self.FrameFiles:
            self.FrameFiles[K] = self.FrameFiles[K][:DatasetLength]
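
The glob cache above pickles one list per FrameLoadStr entry in sequence, so reads only stay correct while the keys and their order match what was saved. A sketch of a variant that memoizes the whole mapping in a single pickle; cached_glob and its arguments are hypothetical:

import glob
import os
import pickle

def cached_glob(cache_path, patterns):
    # patterns: dict mapping a key to a glob pattern
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as fp:
            return pickle.load(fp)  # one dict, independent of key order
    files = {key: sorted(glob.glob(pat)) for key, pat in patterns.items()}
    with open(cache_path, 'wb') as fp:
        pickle.dump(files, fp)
    return files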