def createDatasetSingleThread(self, opts):
    """Generate every batch of the dataset serially on the calling thread.

    Used for debugging: it is hard to debug a multi-threaded application,
    so this is the single-process twin of ``createDataset``.

    Args:
        opts: parsed options; reads ``opts.batchSize`` (lines per batch),
            ``opts.output`` (cache location) and ``opts.count`` (limit on
            the number of source lines).
    """
    dataset = TextDataset(self.WORD_LIST_FILE, batchSize=opts.batchSize,
                          cache=opts.output, limit=opts.count,
                          overwriteCache=True)
    # Touch every batch so it gets generated (and written to the cache).
    # Indexing is the idiomatic spelling of __getitem__, and we drop the
    # returned batches instead of hoarding them all in an unused list.
    for batchIndex in range(len(dataset)):
        dataset[batchIndex]
    print('Total lines count=%d' % (len(dataset) * opts.batchSize))
def threadInitializer(fname, batchSize, cache, limit):
    """Pool-worker initializer: build one ``TextDataset`` per process.

    The dataset is stashed in the module-level global
    ``datasetInOtherthread`` so the worker function invoked via
    ``apply_async`` can reach it without pickling the dataset itself.
    """
    global datasetInOtherthread
    datasetInOtherthread = TextDataset(
        fname,
        batchSize=batchSize,
        cache=cache,
        limit=limit,
        overwriteCache=True,
    )
class DataGenerator(Sequence):
    """Keras ``Sequence`` adapter that feeds batches from a ``TextDataset``.

    Each item is the ``(inputs, outputs)`` pair expected by a CTC-trained
    model: the real targets travel inside ``inputs`` and ``outputs`` only
    carries a dummy tensor for the dummy CTC loss.
    """

    def __init__(self, txtFile, **kwargs):
        """Wrap *txtFile* in a ``TextDataset``; extra kwargs pass through."""
        self.ds = TextDataset(txtFile, **kwargs)

    def __len__(self):
        # Number of batches per epoch (idiomatic len() instead of __len__()).
        return len(self.ds)

    def __getitem__(self, batchIndex):
        """Return one ``(inputs, outputs)`` batch for ``model.fit``."""
        unNormalized = self.ds.getUnNormalized(batchIndex)
        images, labels = normalizeBatch(unNormalized, channel_axis=2)
        # NOTE(review): `converter` and `labelWidth` look like module-level
        # globals — confirm they are defined where this class is used.
        labels, label_lengths = converter.encodeStrListRaw(labels, labelWidth)
        inputs = {
            'the_images': images,
            'the_labels': np.array(labels),
            'label_lengths': np.array(label_lengths),
        }
        # Dummy data for the dummy loss function.  Sized from the actual
        # batch instead of a `batchSize` name that was never defined in this
        # scope (NameError unless a global happened to exist) and would be
        # wrong for a final partial batch anyway.
        outputs = {'ctc': np.zeros([len(images)])}
        return (inputs, outputs)
def createDataset(self, opts):
    """Generate and cache the whole dataset using a process pool.

    A local ``TextDataset`` is built only to learn the batch count; each
    worker process builds its own copy via ``threadInitializer`` so that
    no unpicklable dataset state crosses the process boundary.

    Args:
        opts: parsed options; reads ``opts.batchSize``, ``opts.output``
            (cache location) and ``opts.count`` (line limit).
    """
    dataset = TextDataset(self.WORD_LIST_FILE, batchSize=opts.batchSize,
                          cache=opts.output, limit=opts.count,
                          overwriteCache=True)
    pool = multiprocessing.Pool(
        os.cpu_count(),
        initializer=threadInitializer,
        initargs=(self.WORD_LIST_FILE, opts.batchSize, opts.output,
                  opts.count))
    try:
        results = [pool.apply_async(processInThread, (i, ))
                   for i in range(len(dataset))]
        print('Total lines count=%d' % (len(dataset) * opts.batchSize))
        # Block until every batch is done; .get() re-raises any worker
        # exception here on the parent process.
        for result in results:
            result.get()
    finally:
        # The original leaked the pool — make sure workers are reaped.
        pool.close()
        pool.join()
# --- Training setup: output dir, seeding, data loaders, loss ------------
opt.outdir = 'expr'
# Create the experiment output directory (shell mkdir -p: no-op if present).
os.system('mkdir -p {0}'.format(opt.outdir))
opt.manualSeed = random.randint(1, 10000)  # fix seed
print("Random Seed: ", opt.manualSeed)
# Seed every RNG in play so runs are reproducible given the printed seed.
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if torch.cuda.is_available() and not opt.cuda:
    print(
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )
# Train/validation loaders share the batch size but have their own
# source file, line limit and cache location.
train_loader = TextDataset(opt.traindata,
                           batchSize=opt.batchSize,
                           limit=opt.traindata_limit,
                           cache=opt.traindata_cache)
test_loader = TextDataset(opt.valdata,
                          batchSize=opt.batchSize,
                          limit=opt.valdata_limit,
                          cache=opt.valdata_cache)
nclass = converter.totalGlyphs
print('Number of char class = %d' % nclass)
# CTC loss with the last class index reserved as the blank symbol.
criterion = CTCLoss(blank=nclass - 1)


# custom weights initialization called on crnn
# NOTE(review): only the first statement of weights_init is visible in this
# chunk; the rest of its body continues beyond this excerpt.
def weights_init(m):
    classname = m.__class__.__name__
def __init__(self, txtFile, **kwargs):
    """Build the backing ``TextDataset`` from *txtFile*.

    All extra keyword arguments are forwarded to ``TextDataset``
    unchanged (batch size, cache location, limit, ...).
    """
    dataset = TextDataset(txtFile, **kwargs)
    self.ds = dataset