Example #1
0
 def createDatasetSingleThread(self, opts):
     """Build the dataset sequentially in the current process.

     Used for debugging: a multi-threaded/multi-process application is hard
     to debug, so this walks every batch of the dataset in one process.

     Args:
         opts: options object; reads ``batchSize``, ``output`` (passed to
             TextDataset as ``cache``) and ``count`` (passed as ``limit``).
     """
     dataset = TextDataset(self.WORD_LIST_FILE,
                           batchSize=opts.batchSize,
                           cache=opts.output,
                           limit=opts.count,
                           overwriteCache=True)
     # Each batch is fetched only for its side effects (presumably populating
     # the cache -- TODO confirm against TextDataset). The batches are not
     # retained, so memory stays bounded by a single batch instead of growing
     # with the whole dataset.
     for batch_index in range(len(dataset)):
         _ = dataset[batch_index]
     # NOTE(review): this assumes every batch is full; a short final batch
     # would make the printed count an upper bound.
     print('Total lines count=%d' % (len(dataset) * opts.batchSize))
Example #2
0
def threadInitializer(fname, batchSize, cache, limit):
    """Worker initializer: build this process's own TextDataset.

    The dataset is bound to the module-level ``datasetInOtherthread``
    (presumably read later by tasks running in the same worker -- verify
    against processInThread). The cache is always opened with
    ``overwriteCache=True``.

    Args:
        fname: source file handed to TextDataset.
        batchSize: forwarded as ``batchSize``.
        cache: forwarded as ``cache``.
        limit: forwarded as ``limit``.
    """
    global datasetInOtherthread
    dataset_options = {
        'batchSize': batchSize,
        'cache': cache,
        'limit': limit,
        'overwriteCache': True,
    }
    datasetInOtherthread = TextDataset(fname, **dataset_options)
Example #3
0
class DataGenerator( Sequence ):
    """``Sequence`` of batches built on top of a TextDataset.

    Each item is an ``(inputs, outputs)`` pair. ``inputs`` holds the
    normalized images, the encoded labels and their lengths. ``outputs``
    holds an all-zeros ``'ctc'`` array used as dummy data for a dummy loss
    function.
    """

    def __init__(self, txtFile, **kwargs):
        """Create the wrapped TextDataset; keyword arguments are forwarded."""
        self.ds = TextDataset(txtFile, **kwargs)

    def __len__(self):
        """Return the number of batches reported by the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, batchIndex):
        """Return the ``(inputs, outputs)`` pair for batch *batchIndex*.

        Relies on the module-level ``converter``, ``labelWidth``,
        ``normalizeBatch`` and ``np`` names, which are defined elsewhere.
        """
        unNormalized = self.ds.getUnNormalized(batchIndex)
        images, labels = normalizeBatch(unNormalized, channel_axis=2)
        labels, label_lengths = converter.encodeStrListRaw(labels, labelWidth)
        label_lengths = np.array(label_lengths)
        inputs = {
            'the_images': images,
            'the_labels': np.array(labels),
            'label_lengths': label_lengths,
        }
        # The dummy target must have as many rows as the inputs. Size it from
        # this batch's 'label_lengths' input rather than a global batch size,
        # which is wrong for a short final batch and can disagree with the
        # batchSize forwarded to TextDataset.
        outputs = {'ctc': np.zeros([len(label_lengths)])}  # dummy data for dummy loss function
        return (inputs, outputs)
Example #4
0
 def createDataset(self, opts):
     """Build the dataset in parallel using a process pool.

     A TextDataset is created in this process and used for its batch count.
     A pool with ``os.cpu_count()`` workers is started; each worker builds
     its own TextDataset through ``threadInitializer``.
     ``processInThread(i)`` is then scheduled once per batch index, and the
     method waits for every task.

     Args:
         opts: options object; reads ``batchSize``, ``output`` (passed as
             ``cache``) and ``count`` (passed as ``limit``).

     Raises:
         Any exception raised by a worker task is re-raised here by
         ``AsyncResult.get()``.
     """
     dataset = TextDataset(self.WORD_LIST_FILE,
                           batchSize=opts.batchSize,
                           cache=opts.output,
                           limit=opts.count,
                           overwriteCache=True)
     batch_count = len(dataset)
     # Using the pool as a context manager guarantees the worker processes
     # are torn down (terminate) even if one of the tasks raises.
     with multiprocessing.Pool(os.cpu_count(),
                               initializer=threadInitializer,
                               initargs=(self.WORD_LIST_FILE,
                                         opts.batchSize, opts.output,
                                         opts.count)) as pool:
         results = [
             pool.apply_async(processInThread, (i, ))
             for i in range(batch_count)
         ]
         # NOTE(review): assumes every batch is full; a short final batch
         # makes this an upper bound.
         print('Total lines count=%d' % (batch_count * opts.batchSize))
         # .get() blocks until the task finishes and re-raises its error.
         for result in results:
             result.get()
         # Every task has completed: shut the workers down gracefully.
         pool.close()
         pool.join()
0
    opt.outdir = 'expr'
# Create the output directory (and any missing parents) without shelling
# out; unlike `mkdir -p` via os.system, this is safe for paths containing
# spaces or shell metacharacters and reports failures instead of ignoring them.
os.makedirs(opt.outdir, exist_ok=True)

# Pick a random seed, log it, and seed the Python, NumPy and torch RNGs with
# it. The printed value is what would have to be re-used to reproduce a run.
opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

if torch.cuda.is_available() and not opt.cuda:
    print(
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )

# Training and validation datasets (TextDataset instances, despite the
# "loader" names), each with its own line limit and cache.
train_loader = TextDataset(opt.traindata,
                           batchSize=opt.batchSize,
                           limit=opt.traindata_limit,
                           cache=opt.traindata_cache)
test_loader = TextDataset(opt.valdata,
                          batchSize=opt.batchSize,
                          limit=opt.valdata_limit,
                          cache=opt.valdata_cache)

nclass = converter.totalGlyphs
print('Number of char class = %d' % nclass)

# The last class index (nclass - 1) is reserved as the CTC blank label.
criterion = CTCLoss(blank=nclass - 1)


# custom weights initialization called on crnn
def weights_init(m):
    classname = m.__class__.__name__
Example #6
0
 def __init__(self, txtFile, **kwargs):
     """Wrap a TextDataset built from *txtFile*; keywords are forwarded as-is."""
     wrapped_dataset = TextDataset(txtFile, **kwargs)
     self.ds = wrapped_dataset