Ejemplo n.º 1
0
def trainSamples(primaryDatasets, primaryClasses, primaryInstances,
                 helperInstances):
    '''
    Train one model for every (dataset, primary class, primary-instance
    count, helper-instance count) combination.

    For each dataset/primary-class pair the helper class is looked up once
    with ``getHelperClass``; pairs for which it returns -1 are skipped.
    Combinations whose primary-instance count exceeds the helper-instance
    count are skipped as well. For every remaining combination the primary
    and helper samples are loaded with ``loadDataset``, wrapped in shuffling
    DataLoaders and passed to ``train``.

    Parameters
    ----------
    primaryDatasets : iterable
        Dataset names to train on.
    primaryClasses : iterable
        Primary class labels.
    primaryInstances : iterable of int
        Primary-instance counts to try.
    helperInstances : iterable of int
        Helper-instance counts to try.

    Relies on module-level names defined outside this function:
    ``getHelperClass``, ``loadDataset``, ``getChannels``, ``getEpochs``,
    ``train``, ``batchSize`` and ``torch``.
    '''

    def _makeLoader(samples):
        # Shuffled loader; drop_last discards a trailing incomplete batch so
        # every batch has exactly batchSize samples.
        return torch.utils.data.DataLoader(
            samples,
            batch_size=batchSize,
            shuffle=True,
            num_workers=4,
            drop_last=True)

    # Iterate the actual parameter (the loop previously referenced the
    # undefined name ``primaryDataSets``).
    for dataSet in primaryDatasets:
        for cls in primaryClasses:
            # The helper class is fixed for a given dataset and primary
            # class, so look it up once rather than per instance pair.
            helperClass = getHelperClass(dataSet, cls)
            if helperClass == -1:
                continue

            for instance in primaryInstances:
                for helperInstance in helperInstances:
                    # If the number of primary instances is larger than the
                    # number of helper instances, no need to calculate MMD.
                    if instance > helperInstance:
                        continue

                    print(
                        'Primary Class: {} Helper Class: {} Primary Instances: {} Helper Instance {}'
                        .format(cls, helperClass, instance, helperInstance))

                    primaryFileName = '{}_{}_{}'.format(dataSet, cls, instance)
                    helperFileName = '{}_{}_{}'.format(
                        dataSet, helperClass, helperInstance)

                    x = loadDataset(dataSet, cls, instance)
                    y = loadDataset(dataSet, helperClass, helperInstance)
                    primaryTrainLoader = _makeLoader(x)
                    helperTrainLoader = _makeLoader(y)
                    numOutputChannels = getChannels(dataSet)
                    epochs = getEpochs(dataSet, instance)

                    # Function-call form works on Python 2 and 3; the old
                    # ``print epochs`` statement is a SyntaxError on Python 3.
                    print(epochs)
                    train(primaryFileName,
                          helperFileName,
                          primaryTrainLoader,
                          helperTrainLoader,
                          instance,
                          numOutputChannels,
                          epochs=epochs)
Ejemplo n.º 2
0
def getFakeData(dataSet, fakeClasses, instances, mmdFlag = 0, numHelperInstances=1000):
    '''
    Load generated ("fake") samples for the requested classes of ``dataSet``
    and return them as flattened, shuffled image/label arrays.

    Per the original author's note, output image pixels are expected to lie
    between (-1, 1); this function does not rescale them.

    Parameters
    ----------
    dataSet : str
        Dataset name; used as a directory name and file-name prefix.
    fakeClasses : iterable
        Class labels to load; they are processed in sorted order.
    instances : int
        Primary-instance count encoded in the sample file names.
    mmdFlag : int
        0 loads from ``results/nonMMD``; 1 loads from ``results/MMDall``,
        whose file names also carry the helper class and
        ``numHelperInstances``.
    numHelperInstances : int
        Helper-instance count used in MMD file names (unused when
        ``mmdFlag == 0``).

    Returns
    -------
    (imageArray, labelArray)
        ``imageArray`` holds one flattened sample per row and ``labelArray``
        the float class label of each row; both are shuffled with the same
        permutation.

    Raises
    ------
    ValueError
        If ``mmdFlag`` is not 0 or 1.
    '''
    # Validate first: any other value used to leave tempImageArray
    # unassigned inside the loop (UnboundLocalError).
    if mmdFlag not in (0, 1):
        raise ValueError('mmdFlag must be 0 or 1, got {!r}'.format(mmdFlag))

    subDir = 'MMDall' if mmdFlag == 1 else 'nonMMD'
    dataFolder = resultDir + 'results/' + subDir + '/compressed/' + dataSet + '/' + dataSet

    imageArray, labelArray = getInitialArray(dataSet)
    # Collect the pieces and concatenate once, instead of re-copying the
    # growing arrays for every class.
    imageParts = [imageArray]
    labelParts = [labelArray]

    for cls in sorted(fakeClasses):
        if mmdFlag == 1:
            fileName = (dataFolder + '_' + str(cls) + '_' +
                        str(getHelperClass(dataSet, cls)) + '_' +
                        str(instances) + '_' + str(numHelperInstances) + '.npy')
        else:
            fileName = dataFolder + '_' + str(cls) + '_' + str(instances) + '.npy'

        tempImageArray = np_load(fileName)
        # Flatten every sample to a single row.
        tempImageArray = tempImageArray.reshape(tempImageArray.shape[0], -1)

        imageParts.append(tempImageArray)
        labelParts.append(np.full(tempImageArray.shape[0], cls, dtype=float))

    # Drop the first row of the getInitialArray() result -- presumably a
    # placeholder seed row; TODO confirm against getInitialArray.
    imageArray = np.concatenate(imageParts)[1:]
    labelArray = np.concatenate(labelParts)[1:]

    # Random shuffling of images and labels with a shared permutation.
    p = np.random.permutation(imageArray.shape[0])
    imageArray = imageArray[p]
    labelArray = labelArray[p]

    return imageArray, labelArray
Ejemplo n.º 3
0
        sync(fh)

if __name__=='__main__':
    # Script entry point: runs test() over a small grid of experiment settings.
    # All names bound here are module-level globals.

    # Model classes for 28x28 inputs; presumably used by test() -- TODO confirm.
    from model_28 import Generator, Discriminator
    # Filter counts and image side length; presumably read as globals by
    # test() or the model constructors -- verify.
    numGenFilter=64
    numDiscFilter=32
    imageSize = 28
    
    # Experiment grid: primary/helper datasets, primary class, sample counts
    # and batch sizes.
    primaryDataSet = ['MNIST']
    helperDataSet = ['SVHN-BW']
    
    primaryClass = [1]
    primaryInstances = [5000]
    helperInstances = [5000]
    batchSizes = [50]
    
    for d in primaryDataSet:
        for hd in helperDataSet:
            for pc in primaryClass:
                for pi in primaryInstances:
                    for hi in helperInstances:
                        # Skip settings with more primary than helper instances.
                        if pi > hi:
                            continue
                        # NOTE(review): b is never passed to test(), so each
                        # batch size repeats the same call unless test() reads
                        # the global b -- confirm intent.
                        for b in batchSizes:
                            hc = getHelperClass(d,pc)
                            # A helper class of -1 means: skip this setting.
                            if hc==-1:
                                continue
                            test(d, hd, pc, hc, pi, hi)
Ejemplo n.º 4
0
def getFakeData(dataSet, fakeClasses, instances, mmdFlag = 0, numHelperInstances=1000):
    '''
    Load generated ("fake") samples for the requested classes of ``dataSet``
    and return them as flattened, shuffled image/label arrays.

    Per the original author's note, output image pixels are expected to lie
    between (-1, 1); this function does not rescale them. No resizing is
    done either: per the original comment, the model already produces images
    of the required size.

    Parameters
    ----------
    dataSet : str
        Dataset name; used as a directory name and file-name prefix.
    fakeClasses : iterable
        Class labels to load; they are processed in sorted order.
    instances : int
        Primary-instance count encoded in the sample file names.
    mmdFlag : int
        0 loads from ``results/nonMMD``; 1 loads from ``results/MMD``, whose
        file names also carry the helper class and ``numHelperInstances``.
    numHelperInstances : int
        Helper-instance count used in MMD file names (unused when
        ``mmdFlag == 0``).

    Returns
    -------
    (imageArray, labelArray)
        ``imageArray`` holds one flattened sample per row and ``labelArray``
        the float class label of each row; both are shuffled with the same
        permutation.

    Raises
    ------
    ValueError
        If ``mmdFlag`` is not 0 or 1.
    '''
    # Validate first: any other value used to leave tempImageArray
    # unassigned inside the loop (UnboundLocalError).
    if mmdFlag not in (0, 1):
        raise ValueError('mmdFlag must be 0 or 1, got {!r}'.format(mmdFlag))

    subDir = 'MMD' if mmdFlag == 1 else 'nonMMD'
    dataFolder = resultDir + 'results/' + subDir + '/compressed/' + dataSet + '/' + dataSet

    imageArray, labelArray = getInitialArray(dataSet)
    # Collect the pieces and concatenate once, instead of re-copying the
    # growing arrays for every class.
    imageParts = [imageArray]
    labelParts = [labelArray]

    for cls in sorted(fakeClasses):
        if mmdFlag == 1:
            fileName = (dataFolder + '_' + str(cls) + '_' +
                        str(getHelperClass(dataSet, cls)) + '_' +
                        str(instances) + '_' + str(numHelperInstances) + '.npy')
        else:
            fileName = dataFolder + '_' + str(cls) + '_' + str(instances) + '.npy'

        tempImageArray = np_load(fileName)
        # Flatten every sample to a single row.
        tempImageArray = tempImageArray.reshape(tempImageArray.shape[0], -1)

        imageParts.append(tempImageArray)
        labelParts.append(np.full(tempImageArray.shape[0], cls, dtype=float))

    # Drop the first row of the getInitialArray() result -- presumably a
    # placeholder seed row; TODO confirm against getInitialArray.
    imageArray = np.concatenate(imageParts)[1:]
    labelArray = np.concatenate(labelParts)[1:]

    # Random shuffling of images and labels with a shared permutation.
    p = np.random.permutation(imageArray.shape[0])
    imageArray = imageArray[p]
    labelArray = labelArray[p]

    return imageArray, labelArray