示例#1
0
def saveImageSamples(dataSet, classes, numOfInstances):
    '''
    Plot a sample image built from real training data and save it as a PNG.

    dataSet : string : dataset name, also used in the output folder and file name
    classes : list : classes passed to getRealData
    numOfInstances : passed through to getRealData unchanged

    The figure is written to samples/<dataSet>/ and then displayed.
    '''
    # labels come back from getRealData as well but are not needed for the plot
    images, _ = getRealData(dataSet, classes, numOfInstances, mode='train')
    grid = getImageSamples(images, numOfChannels=getChannels(dataSet))

    # draw the sample image on a fresh figure without axes
    plt.figure()
    plt.imshow(grid, cmap='gray')
    plt.axis('off')

    outputFolder = 'samples' + '/' + dataSet + '/'
    checkAndCreateFolder(outputFolder)

    # NOTE(review): exactly 9 classes is treated as "all classes" -- confirm
    # that 9 (and not 10) is the intended count for the datasets in use.
    suffix = 'all' if len(classes) == 9 else str(classes[0])
    outputPath = outputFolder + dataSet + '_' + suffix + '.png'

    plt.savefig(outputPath, bbox_inches='tight')
    plt.show()
示例#2
0
    def __init__(self,  variables, domainA_loader, domainB_loader):
        '''
        Store the data loaders, copy the run configuration from `variables`
        onto the instance, and build the model.

        variables : object exposing the configuration attributes read below
        domainA_loader : loader kept as self.domainA_loader
        domainB_loader : loader kept as self.domainB_loader
        '''
        self.domainA_loader = domainA_loader
        self.domainB_loader = domainB_loader

        # channel count of each domain is looked up from the domain name
        self.domainA_channels = getChannels(variables.domainA)
        self.domainB_channels = getChannels(variables.domainB)

        # networks and optimizers start unset; see build_model() below
        self.g12 = self.g21 = None
        self.d1 = self.d2 = None
        self.g_optimizer = self.d_optimizer = None

        # the learning rate is the only setting stored under a different name
        self.lr = variables.learningRate

        # every other setting is copied under the same attribute name
        copiedSettings = (
            # optimizer parameters
            'beta1', 'beta2',
            # cycleGAN
            'use_reconst_loss',
            # semi-supervised GAN
            'use_labels',
            # network sizes and training schedule
            'numGenFilter', 'numDiscFilter',
            'epochs', 'batchSize', 'num_classes',
            # bookkeeping
            'log_step', 'sample_step',
            'sample_path', 'model_path', 'name',
        )
        for setting in copiedSettings:
            setattr(self, setting, getattr(variables, setting))

        self.build_model()
示例#3
0
def showImageMatrix(dataSet, primaryClass, helperClass, primaryInstances, helperInstances, showImage):
    '''
    Plot 25 randomly chosen images from a saved .npy file as a 5x5 matrix
    and save the figure.

    Inputs :

    dataSet : dataset name; selects the result sub-folder and the channel count
    primaryClass, helperClass, primaryInstances, helperInstances :
        only used to build the input (.npy) and output (figure) file names
    showImage : the figure is also displayed when this equals 1

    Outputs :

    None. The 5x5 image matrix is written under
    <resultDir>results/MMD/samples/<dataSet>/
    '''
    runTag = dataSet + '_' + str(primaryClass) + '_' + str(helperClass) + '_' \
            + str(primaryInstances) + '_' + str(helperInstances)

    fileName = resultDir+'results/MMD'+'/'+'compressed'+'/'+dataSet+'/'+ runTag + '.npy'
    images = np_load(fileName)

    # pick 25 random images (with replacement); the upper bound follows the
    # number of images actually stored instead of assuming exactly 1000
    randomList = np.random.randint(0, len(images), 25)
    imageList = images[randomList]

    # need to generalise this snippet
    fig, axes = plt.subplots(5,5)
    fig.tight_layout()
    fig.subplots_adjust(wspace=-0.7, hspace=-0.1)
    plt.axis('off')

    # 3-channel images are moved from channel-first to channel-last for imshow
    numOutputChannels = getChannels(dataSet)
    if numOutputChannels==3:
        imageList = np.transpose(imageList,(0,2,3,1))

    # axes.flat walks the grid row by row, matching the i*5+j indexing
    for axis, image in zip(axes.flat, imageList):
        axis.imshow(image.astype('uint8'), cmap='gray')
        axis.axis('off')
        axis.set_xticklabels([])
        axis.set_yticklabels([])
        axis.set_aspect("equal")

    plotFileName = resultDir+'results'+'/'+'MMD/samples'+'/'+dataSet+'/'+ runTag
    plt.savefig(plotFileName, bbox_inches='tight')
    if showImage==1:
        plt.show()
示例#4
0
def trainSamples(primaryDatasets, primaryClasses, primaryInstances,
                 helperInstances):
    '''
    Run train() for every combination of dataset, primary class, primary
    instance count and helper instance count.

    Inputs :

    primaryDatasets : List : dataset names to train on
    primaryClasses : List : primary classes
    primaryInstances : List : number of primary instances to load
    helperInstances : List : number of helper instances to load

    Combinations where the primary count exceeds the helper count, or where
    getHelperClass gives -1, are skipped.
    '''
    # iterate the argument itself (the original looped over a differently
    # spelled name, so the parameter was never used)
    for dataSet in primaryDatasets:
        for cls in primaryClasses:
            for instance in primaryInstances:
                for helperInstance in helperInstances:

                    # if the number of primary instances are larger than the number of helper
                    # instances, no need to calculate MMD
                    if instance > helperInstance:
                        continue

                    # get a fixed helper class for a particular dataset and primary class
                    helperClass = getHelperClass(dataSet, cls)

                    if helperClass == -1:
                        continue
                    print(
                        'Primary Class: {} Helper Class: {} Primary Instances: {} Helper Instance {}'
                        .format(cls, helperClass, instance, helperInstance))

                    primaryFileName = dataSet + '_' + str(cls) + '_' + str(
                        instance)
                    helperFileName = dataSet + '_' + str(
                        helperClass) + '_' + str(helperInstance)

                    x = loadDataset(dataSet, cls, instance)
                    y = loadDataset(dataSet, helperClass, helperInstance)
                    primaryTrainLoader = torch.utils.data.DataLoader(
                        x,
                        batch_size=batchSize,
                        shuffle=True,
                        num_workers=4,
                        drop_last=True)
                    helperTrainLoader = torch.utils.data.DataLoader(
                        y,
                        batch_size=batchSize,
                        shuffle=True,
                        num_workers=4,
                        drop_last=True)
                    numOutputChannels = getChannels(dataSet)
                    epochs = getEpochs(dataSet, instance)

                    # call form works on both Python 2 and Python 3
                    print(epochs)
                    train(primaryFileName,
                          helperFileName,
                          primaryTrainLoader,
                          helperTrainLoader,
                          instance,
                          numOutputChannels,
                          epochs=epochs)
示例#5
0
def trainSamples(primaryDomain,
                 helperDomain,
                 primaryInstances,
                 helperInstances,
                 fewShotInstances,
                 primaryClasses,
                 helperClasses,
                 fewShotClasses,
                 epochs=100):
    '''
    Load the primary-domain training data, display a sample image of it and
    run train() on it.

    Inputs :

    primaryDomain : dataset name used for loading, channel lookup and file name
    helperDomain : not used by this function
    primaryInstances, helperInstances, fewShotInstances,
    primaryClasses, helperClasses, fewShotClasses :
        forwarded to getAttributes; some are also used in the output file name
    epochs : number of epochs passed to train() (default 100)
    '''
    # only the primary class/instance lists are needed below; the helper
    # lists returned by getAttributes are not used here
    primaryClassList, primaryInstanceList, helperClassList, helperInstanceList = getAttributes(
        primaryInstances, helperInstances, fewShotInstances, primaryClasses,
        helperClasses, fewShotClasses)

    x = loadDataset(primaryDomain,
                    primaryClassList,
                    primaryInstanceList,
                    mode='train')
    numOutputChannels = getChannels(primaryDomain)

    # flatten every image to one row so it can be handed to getImageSamples
    tempImageArray = np.asarray([data[0].numpy().squeeze() for data in x])
    tempImageArray = tempImageArray.reshape(tempImageArray.shape[0], -1)

    sampleImage = getImageSamples(primaryDomain,
                                  tempImageArray,
                                  numOfChannels=numOutputChannels)

    # show the sample image of the loaded training data
    plt.figure(figsize=(20, 10))
    plt.imshow(sampleImage)
    plt.axis('off')
    plt.show()

    # load the dataset in parallel
    primaryTrainLoader = torch.utils.data.DataLoader(x,
                                                     batch_size=batchSize,
                                                     shuffle=True,
                                                     num_workers=4,
                                                     drop_last=True)

    fileName = primaryDomain + '_' + str(fewShotClasses) + '_' + str(
        primaryInstances) + '_' + str(fewShotInstances)
    train(fileName,
          primaryTrainLoader,
          primaryInstanceList,
          primaryClasses,
          numOutputChannels=numOutputChannels,
          epochs=epochs)
def trainSamples(primaryDataSet, helperDataSet, primaryClasses,
                 primaryInstances, helperInstances):
    '''
    Run train() across two datasets for every combination of class, primary
    instance count and helper instance count.

    Inputs :

    primaryDataSet : name of the dataset the primary samples are loaded from
    helperDataSet : name of the dataset the helper samples are loaded from
    primaryClasses : List : classes to train; the helper class is the same class
    primaryInstances : List : number of primary instances to load
    helperInstances : List : number of helper instances to load
    '''

    def makeLoader(data, dropLast):
        # shared loader settings; only drop_last differs between the two loaders
        return torch.utils.data.DataLoader(data,
                                           batch_size=batchSize,
                                           shuffle=True,
                                           num_workers=4,
                                           drop_last=dropLast)

    for cls in primaryClasses:
        for instance in primaryInstances:
            for helperInstance in helperInstances:

                # more primary than helper instances: nothing to compute
                if instance > helperInstance:
                    continue

                # the helper class is always the primary class here
                helperClass = cls

                message = 'Primary Class: {} Helper Class: {} Primary Instances: {} Helper Instance {}'.format(
                    cls, helperClass, instance, helperInstance)
                print(message)

                primaryFileName = primaryDataSet + '_' + str(cls) + '_' + str(instance)
                helperFileName = helperDataSet + '_' + str(helperClass) + '_' + str(helperInstance)

                primaryData = loadDataset(primaryDataSet, cls, instance)
                helperData = loadDataset(helperDataSet, helperClass, helperInstance)

                # the primary loader keeps its last partial batch, the helper loader drops it
                primaryTrainLoader = makeLoader(primaryData, False)
                helperTrainLoader = makeLoader(helperData, True)

                numOutputChannels = getChannels(primaryDataSet)
                epochs = getEpochs(primaryDataSet, instance)

                train(primaryFileName,
                      helperFileName,
                      primaryTrainLoader,
                      helperTrainLoader,
                      instance,
                      numOutputChannels,
                      epochs=epochs)
示例#7
0
def trainSamples(primaryDatasets, primaryClasses, primaryInstances,
                 helperInstances):
    '''
    Run train() for every combination of dataset, primary class, primary
    instance count and helper instance count, using two loadDataset calls
    that differ only in their except_ flag.

    Inputs :

    primaryDatasets : List : dataset names to train on
    primaryClasses : List : primary classes
    primaryInstances : List : number of primary instances to load
    helperInstances : List : helper instance counts (used for the skip test
        and the helper file name)
    '''
    # iterate the argument itself (the original looped over a differently
    # spelled name, so the parameter was never used)
    for dataSet in primaryDatasets:
        for cls in primaryClasses:
            for instance in primaryInstances:
                for helperInstance in helperInstances:

                    # if the number of primary instances are larger than the number of helper
                    # instances, no need to calculate MMD
                    if instance > helperInstance:
                        continue

                    primaryFileName = dataSet + '_' + str(cls) + '_' + str(
                        instance)
                    helperFileName = dataSet + '_' + 'all' + '_' + str(
                        helperInstance)

                    x = loadDataset(dataSet, cls, instance, except_=0)
                    # NOTE(review): the helper set is loaded with `instance`,
                    # while its file name uses `helperInstance` -- confirm
                    # which count is intended.
                    y = loadDataset(dataSet, cls, instance, except_=1)

                    primaryTrainLoader = torch.utils.data.DataLoader(
                        x,
                        batch_size=batchSize,
                        shuffle=True,
                        num_workers=4,
                        drop_last=True)
                    helperTrainLoader = torch.utils.data.DataLoader(
                        y,
                        batch_size=batchSize,
                        shuffle=True,
                        num_workers=4,
                        drop_last=True)
                    numOutputChannels = getChannels(dataSet)
                    epochs = getEpochs(dataSet, instance)

                    train(primaryFileName,
                          helperFileName,
                          primaryTrainLoader,
                          helperTrainLoader,
                          instance,
                          numOutputChannels,
                          epochs=epochs)
示例#8
0
def getRealData(dataSet, realClasses, instances, mode='train'):
    '''
    Collect flattened images and labels for the given classes and return them
    in a random order.

    dataSet : dataset name
    realClasses : classes to load; they are processed in sorted order
    instances : number of rows reserved per class (also passed to loadDataset)
    mode : passed through to loadDataset

    Returns (imageArray, labelArray): one flattened image per row and the
    matching labels, shuffled with the same permutation.
    '''
    orderedClasses = sorted(realClasses)

    channels = getChannels(dataSet)
    side = getImageSize(dataSet)

    # reserve room for `instances` rows per class; unused rows are trimmed later
    capacity = len(orderedClasses) * instances
    imageArray = np.zeros((capacity, channels * side * side))
    labelArray = np.zeros(capacity)

    filled = 0
    for cls in orderedClasses:
        samples = loadDataset(dataSet, cls, instances, mode=mode)

        classImages = np.asarray([sample[0].numpy().squeeze() for sample in samples])
        classLabels = np.asarray([np.asarray(sample[1]) for sample in samples])

        # one row per image
        classImages = classImages.reshape(classImages.shape[0], -1)

        end = filled + classImages.shape[0]
        imageArray[filled:end] = classImages
        labelArray[filled:end] = classLabels
        filled = end

    # keep only the rows that were actually written
    imageArray = imageArray[:filled, :]
    labelArray = labelArray[:filled]

    # shuffle images and labels together
    order = np.random.permutation(imageArray.shape[0])
    return imageArray[order], labelArray[order]
示例#9
0
def test(primaryDataSet, helperDataSet, primaryClass, helperClass, primaryInstances, helperInstances):
    '''
    Generate images with a trained Generator, save a 5x5 preview figure and
    store all generated images in a .npy file.

    Inputs :

    primaryDataSet : primary dataset name (folders, file names, channel and epoch lookup)
    helperDataSet : helper dataset name (file names only)
    primaryClass : primary class (file names only)
    helperClass : ignored; it is overwritten with primaryClass
    primaryInstances : primary instance count (file names and epoch lookup)
    helperInstances : helper instance count (file names only)

    Output :

    None. numOfSamples // batchSize batches of generated images are written to
    <resultDir>results/crossDataSetMMDall/compressed/<primaryDataSet>/
    '''
    helperClass = primaryClass

    # looked up once; the checkpoint name carries the last epoch index
    epochs = getEpochs(primaryDataSet, primaryInstances)

    modelFolder = resultDir + 'models/crossDataSetMMDall'+'/'+primaryDataSet+'/'
    print (primaryDataSet, helperDataSet, primaryClass, helperClass, primaryInstances, helperInstances, epochs-1)
    modelFile = modelFolder + primaryDataSet + '_' + helperDataSet + '_' + \
                str(primaryClass) + '_' + str(helperClass) + '_' + \
                str(primaryInstances) + '_' + str(helperInstances)+'_'+ \
                str(epochs-1)+'.pt'

    print ('Generating examples for Dataset: '+primaryDataSet+
           ' Primary Class: '+str(primaryClass)+
           ' Helper Class: '+str(helperClass)+
           ' Primary Instances: '+str(primaryInstances)+
           ' Helper Instances: '+str(helperInstances)+
           ' Epochs: '+str(epochs))

    numOutputChannels = getChannels(primaryDataSet)

    # load the model learnt during training
    G = Generator(numInputChannels, numGenFilter, numOutputChannels)
    G.load_state_dict(torch.load(modelFile))
    if cuda:
        # move the generator once instead of on every iteration
        G = G.cuda()

    # floor division keeps this an int on Python 3 as well, as range() requires
    iterations = numOfSamples // batchSize

    batches = []
    for iteration in range(iterations):
        noise = torch.FloatTensor(batchSize,
                                  numInputChannels,
                                  1,
                                  1)
        noise.normal_(0,1)
        if cuda:
            noise = noise.cuda()

        genImage = G(Variable(noise)).data.cpu().numpy()
        batches.append(genImage)

    if batches:
        # preview: first 25 images of the last batch as a 5x5 grid;
        # normalize=True maps the pixels into 0..1 as required by imshow
        preview = torchvision.utils.make_grid(torch.from_numpy(batches[-1][:25]), nrow=5, normalize=True)
        preview = preview.permute(1,2,0).numpy()

        plt.imshow(preview, cmap='gray')

        plotFolder = resultDir+'results'+'/'+'crossDataSetMMDall/samples'+'/'+primaryDataSet+'/'
        checkAndCreateFolder(plotFolder)
        plotFile = primaryDataSet + '_' + helperDataSet + '_' + \
            str(primaryClass) + '_' + 'all' + '_' + \
            str(primaryInstances) + '_' + str(helperInstances)
        plotPath = plotFolder + plotFile

        plt.axis('off')
        plt.savefig(plotPath, bbox_inches='tight')
        plt.show()

        # join all batches in one step instead of re-copying inside the loop
        genImageConcat = np.concatenate(batches, axis=0)
    else:
        # no batch was generated; keep the original placeholder array
        genImageConcat = np.empty(1)

    resultFolder = resultDir+'results/crossDataSetMMDall'+'/'+'compressed'+'/'+primaryDataSet+'/'
    checkAndCreateFolder(resultFolder)
    resultFile = primaryDataSet + '_' + helperDataSet + '_' + \
                 str(primaryClass) + '_' + 'all' + '_' + \
                 str(primaryInstances) + '_' + str(helperInstances)  + '.npy'

    resultPath = resultFolder + resultFile

    # save the generated images with size-1 axes removed
    with open(resultPath,'wb+') as fh:
        genImageConcat = np.squeeze(genImageConcat)
        np_save(fh, genImageConcat, allow_pickle=False)
        sync(fh)
示例#10
0
def MMDhist(dataSet, primaryClass, primaryInstance, helperInstances,
            batchSize):
    '''
    Compute kernelTwoSampleTest for every pair of (primary batch, helper batch),
    display the pairs with the largest and smallest value, and save a
    histogram of all values.

    dataSet : dataset name
    primaryClass : class loaded as the primary set; it is left out of the helper set
    primaryInstance : instance count used for both loadDataset calls
    helperInstances : only used in the plot title, file name and displayImage calls
    batchSize : batch size of both data loaders
    '''
    primarySet = loadDataset(dataSet, [primaryClass], [primaryInstance])

    # helper set: the classes 0..9 with the primary class sliced out
    allClasses = list(range(10))
    primaryClasses = allClasses[:primaryClass] + allClasses[primaryClass + 1:]
    # NOTE(review): this list has 10 entries while primaryClasses normally has
    # 9 -- confirm loadDataset tolerates the length mismatch.
    primaryInstances = [primaryInstance] * 10

    helperSet = loadDataset(dataSet, primaryClasses, primaryInstances)

    loaderOptions = dict(batch_size=batchSize,
                         shuffle=True,
                         num_workers=4,
                         drop_last=True)
    primaryTrainLoader = torch.utils.data.DataLoader(primarySet, **loaderOptions)
    helperTrainLoader = torch.utils.data.DataLoader(helperSet, **loaderOptions)

    mmdValues = []
    # NOTE(review): starting bounds; a pair is only recorded when its value is
    # below 1.5 (min) or above -0.1 (max) -- confirm the value range.
    minMMDValue = 1.5
    maxMMDValue = -0.1

    # all-zero placeholder batch, used if no pair ever updates a bound
    placeholder = torch.FloatTensor(batchSize, getChannels(dataSet),
                                    getImageSize(dataSet),
                                    getImageSize(dataSet)).zero_()
    minPrimaryDataInstance = minHelperDataInstance = placeholder
    maxPrimaryDataInstance = maxHelperDataInstance = placeholder

    for primaryDataInstance, primaryDataLabel in primaryTrainLoader:
        for helperDataInstance, helperDataLabel in helperTrainLoader:

            mmdValue = kernelTwoSampleTest(primaryDataInstance,
                                           helperDataInstance)

            # remember the pair with the smallest value so far
            if minMMDValue > mmdValue:
                minMMDValue = mmdValue
                minPrimaryDataInstance = primaryDataInstance
                minHelperDataInstance = helperDataInstance

            # remember the pair with the largest value so far
            if maxMMDValue < mmdValue:
                maxMMDValue = mmdValue
                maxPrimaryDataInstance = primaryDataInstance
                maxHelperDataInstance = helperDataInstance

            mmdValues.append(mmdValue)

    displayImage(maxPrimaryDataInstance, maxHelperDataInstance, maxMMDValue,
                 dataSet, primaryClass, primaryInstance, helperInstances,
                 'max')
    displayImage(minPrimaryDataInstance, minHelperDataInstance, minMMDValue,
                 dataSet, primaryClass, primaryInstance, helperInstances,
                 'min')

    # histogram of all collected values
    mmdValues = np.asarray(mmdValues)
    plt.figure()
    plt.hist(mmdValues, ec='k')

    classes = getClasses(dataSet)

    plt.plot()
    plt.xlabel('$MMD^2$ between batch of Primary Class and Helper Class')
    plt.ylabel('Number of Batch Pairs have $MMD^2$ value in that range')
    title = ' Dataset: {} \n Primary Class: {}  \n   Primary Instances:{} \n  Helper Instances:{} \n Batch Size:{}'.format(
        dataSet, classes[primaryClass], primaryInstance, helperInstances,
        batchSize)
    plt.title(title)

    saveFolder = resultDir + 'mmdValues' + '/' + 'hist' + '/' + dataSet + '/'
    checkAndCreateFolder(saveFolder)
    nameParts = [str(primaryClass), str(primaryInstance),
                 str(helperInstances), str(batchSize)]
    saveFile = saveFolder + dataSet + '_' + '_'.join(nameParts) + '.png'

    plt.savefig(saveFile, bbox_inches='tight')
    plt.show()