예제 #1
0
def trainSamples(primaryDatasets, primaryClasses, primaryInstances,
                 helperInstances):
    """Train one model per dataset / primary class / instance-count pair.

    For each dataset in ``primaryDatasets`` and each class in
    ``primaryClasses`` a helper class is obtained from ``getHelperClass``;
    a return value of -1 skips that (dataset, class) pair. Then, for every
    primary sample count in ``primaryInstances`` and helper sample count in
    ``helperInstances`` with primary <= helper, both subsets are loaded,
    wrapped in shuffled ``DataLoader``s (incomplete last batch dropped) and
    passed to ``train``.

    Args:
        primaryDatasets: iterable of dataset names (strings; they are
            concatenated into the output file names).
        primaryClasses: class indices used as the primary class.
        primaryInstances: candidate numbers of primary samples.
        helperInstances: candidate numbers of helper samples.
    """
    # Iterate the argument itself (the loop used to reference the undefined
    # name ``primaryDataSets``, so the parameter was never used).
    for dataSet in primaryDatasets:
        for cls in primaryClasses:
            # get a fixed helper class for a particular dataset and primary
            # class -- it does not depend on the instance counts, so look it
            # up once here instead of once per instance pair
            helperClass = getHelperClass(dataSet, cls)
            if helperClass == -1:
                continue

            for instance in primaryInstances:
                for helperInstance in helperInstances:
                    # if the number of primary instances is larger than the
                    # number of helper instances, no need to calculate MMD
                    if instance > helperInstance:
                        continue

                    print(
                        'Primary Class: {} Helper Class: {} Primary Instances: {} Helper Instance {}'
                        .format(cls, helperClass, instance, helperInstance))

                    primaryFileName = dataSet + '_' + str(cls) + '_' + str(
                        instance)
                    helperFileName = dataSet + '_' + str(
                        helperClass) + '_' + str(helperInstance)

                    x = loadDataset(dataSet, cls, instance)
                    y = loadDataset(dataSet, helperClass, helperInstance)
                    primaryTrainLoader = torch.utils.data.DataLoader(
                        x,
                        batch_size=batchSize,
                        shuffle=True,
                        num_workers=4,
                        drop_last=True)
                    helperTrainLoader = torch.utils.data.DataLoader(
                        y,
                        batch_size=batchSize,
                        shuffle=True,
                        num_workers=4,
                        drop_last=True)
                    numOutputChannels = getChannels(dataSet)
                    epochs = getEpochs(dataSet, instance)

                    # single-argument print() behaves the same on Python 2 and 3
                    print(epochs)
                    train(primaryFileName,
                          helperFileName,
                          primaryTrainLoader,
                          helperTrainLoader,
                          instance,
                          numOutputChannels,
                          epochs=epochs)
예제 #2
0
def trainSamples(primaryDataSet, helperDataSet, primaryClasses,
                 primaryInstances, helperInstances):
    """Train a model for each class and each (primary, helper) sample count.

    Helper samples are drawn from ``helperDataSet`` using the same class
    index as the primary samples from ``primaryDataSet``. Count pairs where
    the primary count exceeds the helper count are skipped. The primary
    loader keeps its last partial batch; the helper loader drops it.
    """
    for primaryClass in primaryClasses:
        # Cross-dataset setting: the helper class is the primary class itself.
        helperClass = primaryClass

        for numPrimary in primaryInstances:
            for numHelper in helperInstances:
                if numPrimary > numHelper:
                    # More primary than helper samples: no MMD needed.
                    continue

                print(
                    'Primary Class: {} Helper Class: {} Primary Instances: {} Helper Instance {}'
                    .format(primaryClass, helperClass, numPrimary, numHelper))

                primaryFileName = '_'.join(
                    [primaryDataSet, str(primaryClass), str(numPrimary)])
                helperFileName = '_'.join(
                    [helperDataSet, str(helperClass), str(numHelper)])

                primarySubset = loadDataset(primaryDataSet, primaryClass,
                                            numPrimary)
                helperSubset = loadDataset(helperDataSet, helperClass,
                                           numHelper)

                # Settings shared by both loaders; only drop_last differs.
                loaderOptions = dict(batch_size=batchSize,
                                     shuffle=True,
                                     num_workers=4)
                primaryTrainLoader = torch.utils.data.DataLoader(
                    primarySubset, drop_last=False, **loaderOptions)
                helperTrainLoader = torch.utils.data.DataLoader(
                    helperSubset, drop_last=True, **loaderOptions)

                train(primaryFileName,
                      helperFileName,
                      primaryTrainLoader,
                      helperTrainLoader,
                      numPrimary,
                      getChannels(primaryDataSet),
                      epochs=getEpochs(primaryDataSet, numPrimary))
예제 #3
0
def MMDhist_avg(dataSet, primaryClass, primaryInstances, helperInstances,
                batchSize):
    """Average batch-wise MMD between one class and each of classes 0-9.

    Every batch of ``primaryClass`` (``primaryInstances`` samples) is compared
    with ``kernelTwoSampleTest`` against every batch of helper class ``m``
    (``helperInstances`` samples) for ``m`` in 0..9 of the same dataset. The
    mean over all batch pairs is printed per helper class, and the ten means
    are handed to ``MMDBar`` for plotting.

    Args:
        dataSet: dataset name (a string; it is concatenated into messages).
        primaryClass: index of the primary class.
        primaryInstances: number of primary samples to load.
        helperInstances: number of samples to load per helper class.
        batchSize: batch size of both loaders; incomplete batches are dropped.
    """
    # load the datasets for which the bar graph needs to be plotted
    x = loadDataset(dataSet, [primaryClass], [primaryInstances])
    classNames = getClasses(dataSet)

    # The primary loader does not depend on the helper class, so build it
    # once; it is re-shuffled each time it is iterated.
    primaryTrainLoader = torch.utils.data.DataLoader(x,
                                                     batch_size=batchSize,
                                                     shuffle=True,
                                                     num_workers=4,
                                                     drop_last=True)
    avgValues = []

    # check the average discrepancy between two datasets
    for m in range(10):

        y = loadDataset(dataSet, [m], [helperInstances])
        helperTrainLoader = torch.utils.data.DataLoader(y,
                                                        batch_size=batchSize,
                                                        shuffle=True,
                                                        num_workers=4,
                                                        drop_last=True)

        # MMD for every (primary batch, helper batch) pair
        mmdValues = []
        for primaryDataInstance, _ in primaryTrainLoader:
            for helperDataInstance, _ in helperTrainLoader:
                mmdValues.append(
                    kernelTwoSampleTest(primaryDataInstance,
                                        helperDataInstance))

        avgValue = np.mean(np.asarray(mmdValues))

        # Same space-separated output as the Python 2 ``print a, b, c, d``
        # statement, but valid on Python 3 as well.
        print(' '.join(
            str(v) for v in (dataSet, classNames[primaryClass],
                             classNames[m], avgValue)))
        print('Average Discrepancy for ' + dataSet + ' Primary Class: ' +
              str(classNames[primaryClass]) + ' Helper Class: ' +
              str(classNames[m]) + ' is :' + str(avgValue))

        avgValues.append(avgValue)

    # defining plot attributes
    MMDBar(dataSet, avgValues, primaryClass, primaryInstances, helperInstances,
           batchSize)
예제 #4
0
def trainSamples(primaryDatasets, primaryClasses, primaryInstances,
                 helperInstances):
    """Train one model per dataset / class / instance count against the
    dataset's remaining data.

    For each dataset, class, primary count and helper count (skipping pairs
    where the primary count exceeds the helper count), a primary subset is
    loaded with ``loadDataset(..., except_=0)`` and a helper subset with
    ``except_=1``. Both are wrapped in shuffled ``DataLoader``s (incomplete
    last batch dropped) and passed to ``train``. The helper file name uses
    the class tag ``'all'``.

    Args:
        primaryDatasets: iterable of dataset names (strings; they are
            concatenated into the output file names).
        primaryClasses: class indices used as the primary class.
        primaryInstances: candidate numbers of primary samples.
        helperInstances: candidate numbers of helper samples.
    """
    # Iterate the argument itself (the loop used to reference the undefined
    # name ``primaryDataSets``, so the parameter was never used).
    for dataSet in primaryDatasets:
        for cls in primaryClasses:
            for instance in primaryInstances:
                for helperInstance in helperInstances:

                    # if the number of primary instances is larger than the
                    # number of helper instances, no need to calculate MMD
                    if instance > helperInstance:
                        continue

                    primaryFileName = dataSet + '_' + str(cls) + '_' + str(
                        instance)
                    helperFileName = dataSet + '_' + 'all' + '_' + str(
                        helperInstance)

                    # except_=0 / except_=1 presumably select the class itself
                    # vs. every other class -- confirm against loadDataset.
                    x = loadDataset(dataSet, cls, instance, except_=0)
                    # NOTE(review): the helper subset is loaded with
                    # ``instance``, not ``helperInstance``; helperInstance only
                    # affects helperFileName. Confirm this is intended.
                    y = loadDataset(dataSet, cls, instance, except_=1)

                    primaryTrainLoader = torch.utils.data.DataLoader(
                        x,
                        batch_size=batchSize,
                        shuffle=True,
                        num_workers=4,
                        drop_last=True)
                    helperTrainLoader = torch.utils.data.DataLoader(
                        y,
                        batch_size=batchSize,
                        shuffle=True,
                        num_workers=4,
                        drop_last=True)
                    numOutputChannels = getChannels(dataSet)
                    epochs = getEpochs(dataSet, instance)

                    train(primaryFileName,
                          helperFileName,
                          primaryTrainLoader,
                          helperTrainLoader,
                          instance,
                          numOutputChannels,
                          epochs=epochs)
예제 #5
0
def trainSamples(primaryDomain,
                 helperDomain,
                 primaryInstances,
                 helperInstances,
                 fewShotInstances,
                 primaryClasses,
                 helperClasses,
                 fewShotClasses,
                 epochs=100):
    """Preview and then train on the primary domain's training split.

    ``getAttributes`` expands the class/instance arguments into per-class
    lists; only the primary lists are used here. The loaded images are
    flattened, turned into a sample grid by ``getImageSamples`` and displayed.
    A shuffled ``DataLoader`` (incomplete last batch dropped) is then passed
    to ``train``.

    Args:
        primaryDomain: name of the dataset to train on (a string; it is
            concatenated into the output file name).
        helperDomain: not used by this function.
        primaryInstances, helperInstances, fewShotInstances: instance-count
            settings forwarded to ``getAttributes``.
        primaryClasses, helperClasses, fewShotClasses: class settings
            forwarded to ``getAttributes``.
        epochs: number of training epochs forwarded to ``train``.
    """
    # The helper lists are returned by getAttributes but not needed here.
    primaryClassList, primaryInstanceList, _helperClassList, _helperInstanceList = getAttributes(
        primaryInstances, helperInstances, fewShotInstances, primaryClasses,
        helperClasses, fewShotClasses)

    # load the primary domain's training split for the selected classes
    x = loadDataset(primaryDomain,
                    primaryClassList,
                    primaryInstanceList,
                    mode='train')
    numOutputChannels = getChannels(primaryDomain)

    # flatten every image into a single row for getImageSamples
    tempImageArray = np.asarray([data[0].numpy().squeeze() for data in x])
    tempImageArray = tempImageArray.reshape(tempImageArray.shape[0], -1)

    sampleImage = getImageSamples(primaryDomain,
                                  tempImageArray,
                                  numOfChannels=numOutputChannels)

    # display the grid of loaded training images (it is not saved to disk)
    plt.figure(figsize=(20, 10))
    plt.imshow(sampleImage)
    plt.axis('off')
    plt.show()

    # load the dataset with 4 parallel worker processes
    primaryTrainLoader = torch.utils.data.DataLoader(x,
                                                     batch_size=batchSize,
                                                     shuffle=True,
                                                     num_workers=4,
                                                     drop_last=True)

    fileName = primaryDomain + '_' + str(fewShotClasses) + '_' + str(
        primaryInstances) + '_' + str(fewShotInstances)
    train(fileName,
          primaryTrainLoader,
          primaryInstanceList,
          primaryClasses,
          numOutputChannels=numOutputChannels,
          epochs=epochs)
예제 #6
0
def getRealData(dataSet, realClasses, instances, mode='train'):
    '''
    Load real images of the given classes as flat, shuffled arrays.

    For each class in ``sorted(realClasses)``, ``loadDataset(dataSet, cls,
    instances, mode=mode)`` is read. Every image is flattened to one row of
    length channels * imageSize * imageSize. Classes that yield no samples
    are skipped, and classes that yield more than ``instances`` samples are
    kept in full. Rows and labels are shuffled together with
    ``np.random.permutation``.

    The original note says output pixels lie in (-1 to 1); that depends on
    loadDataset's normalization.

    Returns:
        (imageArray, labelArray): float64 arrays of shape (N, C*H*W) and (N,).
    '''
    numOfChannels = getChannels(dataSet)
    imageSize = getImageSize(dataSet)
    featureDim = numOfChannels * imageSize * imageSize

    imageChunks = []
    labelChunks = []
    for cls in sorted(realClasses):
        # read each sample once so image and label come from the same access
        samples = [data for data in loadDataset(dataSet, cls, instances,
                                                mode=mode)]
        if not samples:
            # nothing loaded for this class (reshape(0, -1) would fail)
            continue

        images = np.asarray([data[0].numpy().squeeze() for data in samples])
        labels = np.asarray([np.asarray(data[1]) for data in samples])

        # one flattened image per row; raises if the pixel count does not
        # match the dataset's channels * height * width
        imageChunks.append(images.reshape(len(samples), featureDim))
        labelChunks.append(labels)

    if imageChunks:
        imageArray = np.concatenate(imageChunks).astype(np.float64)
        labelArray = np.concatenate(labelChunks).astype(np.float64)
    else:
        imageArray = np.zeros((0, featureDim))
        labelArray = np.zeros(0)

    # random shuffling of images and labels (same permutation for both)
    p = np.random.permutation(imageArray.shape[0])
    return imageArray[p], labelArray[p]
예제 #7
0
def MMDhist(primaryDomain,
            helperDomain,
            primaryClass,
            helperClassList,
            primaryInstance,
            helperInstance,
            batchSize):
    """Bar-plot the average batch-wise MMD of one class against helper classes.

    Every batch of ``primaryClass`` from ``primaryDomain`` is compared with
    ``kernelTwoSampleTest`` against every batch of each class in
    ``helperClassList`` from ``helperDomain``. The per-class means (negative
    values clipped to 0) are drawn as a bar chart. The chart is saved under
    ``resultDir + 'mmdValues/<primaryDomain>/'`` and shown.

    Args:
        primaryDomain: name of the primary dataset (a string).
        helperDomain: name of the helper dataset.
        primaryClass: index of the primary class.
        helperClassList: helper class indices; one bar is drawn per entry,
            in this order.
        primaryInstance: number of primary samples to load.
        helperInstance: number of samples to load per helper class.
        batchSize: batch size of all loaders; incomplete batches are dropped.
    """
    # load the datasets for which the bar graph needs to be plotted
    x = loadDataset(primaryDomain, [primaryClass], [primaryInstance], mode='train')
    primaryTrainLoader = torch.utils.data.DataLoader(x,
                                                     batch_size=batchSize,
                                                     shuffle=True,
                                                     num_workers=4,
                                                     drop_last=True)

    avgMMDValues = []
    for helperClass in helperClassList:

        y = loadDataset(helperDomain, [helperClass], [helperInstance], mode='train')
        helperTrainLoader = torch.utils.data.DataLoader(y,
                                                        batch_size=batchSize,
                                                        shuffle=True,
                                                        num_workers=4,
                                                        drop_last=True)

        # MMD for every (primary batch, helper batch) pair
        mmdValues = []
        for primaryDataInstance, _ in primaryTrainLoader:
            for helperDataInstance, _ in helperTrainLoader:
                mmdValues.append(
                    kernelTwoSampleTest(primaryDataInstance, helperDataInstance))

        avgMMDValues.append(np.mean(np.asarray(mmdValues)))

    # plot the average MMD Values
    plt.figure()
    ax = plt.subplot(111)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

    plt.ylabel('Avg. MMD Value')
    plt.xlabel('Class')

    avgMMDValues = np.asarray(avgMMDValues)
    avgMMDValues[avgMMDValues < 0] = 0

    # y ticks: steps of max/10 up to a third above the tallest bar. Skip them
    # when the maximum is 0, NaN (a loader produced no full batch) or there are
    # no bars at all, because np.arange cannot use a zero/NaN step.
    yMax = np.max(avgMMDValues) if avgMMDValues.size else 0.0
    if np.isfinite(yMax) and yMax > 0:
        plt.yticks(np.arange(0.0, yMax + yMax / 3, yMax / 10))

    classNames = getClasses(primaryDomain)
    plt.title(primaryDomain + ' - ' + str(classNames[primaryClass]))

    # x-axis: one tick per plotted bar, labelled with that helper class's
    # name, so ticks and bars line up even when helperClassList is a subset
    # or reordering of the helper domain's classes
    helperClassNames = getClasses(helperDomain)
    ind = np.arange(len(helperClassList))
    ax.set_xticks(ind)
    ax.set_xticklabels([helperClassNames[c] for c in helperClassList])

    plt.xticks(rotation=45)
    plt.bar(ind, avgMMDValues, 0.50)

    plotFolderName = resultDir + 'mmdValues' + '/' + primaryDomain + '/'
    checkAndCreateFolder(plotFolderName)

    plotFileName = plotFolderName+primaryDomain+'_'+str(primaryClass)+'_'+str(helperDomain)+'_'+str(batchSize)+'.png'
    plt.savefig(plotFileName, bbox_inches='tight')
    plt.show()
예제 #8
0
    
    # Name of the experiment run. `parser` itself (and the other options read
    # below) are created outside this fragment.
    parser.add_argument('--name', type=str, default='Experiment1')
        
    # dataloader
    
    
    # Parse the command line. The attributes used below (domainA, domainB,
    # numDomainA, numDomainB, batchSize) are presumably registered on `parser`
    # outside this fragment -- confirm.
    variables = parser.parse_args()
    
    
    # define dataloaders of datasets
    # Both domains use all ten classes 0-9.
    classes = [0,1,2,3,4,5,6,7,8,9]
    
    # Domain A: numDomainA samples requested for each of the 10 classes.
    instances = [variables.numDomainA for i in range(10)]

    
    # Shuffled training loader for domain A (2 worker processes; the final
    # incomplete batch is dropped).
    domainA_dataset = loadDataset(variables.domainA, classes, instances, 'train')
    domainA_dataloader = torch.utils.data.DataLoader(domainA_dataset, 
                                                   batch_size = variables.batchSize,
                                                   shuffle = True,
                                                   num_workers = 2,
                                                   drop_last=True)

    # Domain B: numDomainB samples requested for each of the 10 classes
    # (rebinds `instances` from domain A).
    instances = [variables.numDomainB for i in range(10)]


    # Shuffled training loader for domain B, configured like domain A's.
    domainB_dataset = loadDataset(variables.domainB, classes, instances, 'train')
    domainB_dataloader = torch.utils.data.DataLoader(domainB_dataset, 
                                                   batch_size = variables.batchSize,
                                                   shuffle = True,
                                                   num_workers = 2,
                                                   drop_last=True)
예제 #9
0
def MMDhist(dataSet, primaryClass, primaryInstance, helperInstances,
            batchSize):
    """Histogram the batch-wise MMD^2 between one class and the other classes.

    Every batch of ``primaryClass`` is compared with ``kernelTwoSampleTest``
    against every batch drawn from the other classes 0-9 of ``dataSet``. The
    batch pairs with the smallest and largest value are passed to
    ``displayImage`` (when at least one pair was compared). A histogram of all
    values is saved under ``resultDir + 'mmdValues/hist/<dataSet>/'`` and shown.

    Args:
        dataSet: dataset name (a string; it is concatenated into the path).
        primaryClass: index of the primary class.
        primaryInstance: number of primary samples to load. It is also used
            as the per-class count for the helper classes.
        helperInstances: used in the title, the displayImage calls and the
            output file name.
        batchSize: batch size of both loaders; incomplete batches are dropped.
    """
    # load the datasets for which the histogram needs to be plotted
    x = loadDataset(dataSet, [primaryClass], [primaryInstance])

    # helper data: every class 0-9 except the primary class
    helperClasses = [i for i in range(10)]
    helperClasses = helperClasses[:primaryClass] + helperClasses[
        primaryClass + 1:]
    # NOTE(review): 10 counts for 9 classes, and primaryInstance rather than
    # helperInstances is used per helper class -- confirm this matches what
    # loadDataset expects.
    helperInstanceCounts = [primaryInstance for i in range(10)]

    y = loadDataset(dataSet, helperClasses, helperInstanceCounts)

    primaryTrainLoader = torch.utils.data.DataLoader(x,
                                                     batch_size=batchSize,
                                                     shuffle=True,
                                                     num_workers=4,
                                                     drop_last=True)
    helperTrainLoader = torch.utils.data.DataLoader(y,
                                                    batch_size=batchSize,
                                                    shuffle=True,
                                                    num_workers=4,
                                                    drop_last=True)

    mmdValues = []
    # Start from +/- infinity so the first compared pair always initializes
    # both extremes. Fixed bounds would leave the placeholder batches
    # selected whenever every MMD value fell outside them.
    minMMDValue = float('inf')
    maxMMDValue = float('-inf')

    # all-zero placeholder batch, only kept if no value ever compares
    # below/above the bounds (e.g. every value is NaN)
    placeholder = torch.FloatTensor(batchSize, getChannels(dataSet),
                                    getImageSize(dataSet),
                                    getImageSize(dataSet)).zero_()
    minPrimaryDataInstance = minHelperDataInstance = placeholder
    maxPrimaryDataInstance = maxHelperDataInstance = placeholder

    for primaryDataInstance, _ in primaryTrainLoader:
        for helperDataInstance, _ in helperTrainLoader:

            mmdValue = kernelTwoSampleTest(primaryDataInstance,
                                           helperDataInstance)

            # choosing the pair with minimum MMD
            if mmdValue < minMMDValue:
                minMMDValue = mmdValue
                minPrimaryDataInstance = primaryDataInstance
                minHelperDataInstance = helperDataInstance

            # choosing the pair with maximum MMD
            if mmdValue > maxMMDValue:
                maxMMDValue = mmdValue
                maxPrimaryDataInstance = primaryDataInstance
                maxHelperDataInstance = helperDataInstance

            mmdValues.append(mmdValue)

    # Nothing meaningful to show if no batch pair was compared (e.g. a loader
    # had fewer samples than batchSize and drop_last removed every batch).
    if mmdValues:
        displayImage(maxPrimaryDataInstance, maxHelperDataInstance,
                     maxMMDValue, dataSet, primaryClass, primaryInstance,
                     helperInstances, 'max')
        displayImage(minPrimaryDataInstance, minHelperDataInstance,
                     minMMDValue, dataSet, primaryClass, primaryInstance,
                     helperInstances, 'min')

    mmdValues = np.asarray(mmdValues)
    plt.figure()
    plt.hist(mmdValues, ec='k')

    classes = getClasses(dataSet)

    plt.xlabel('$MMD^2$ between batch of Primary Class and Helper Class')
    plt.ylabel('Number of Batch Pairs have $MMD^2$ value in that range')
    plt.title(
        ' Dataset: {} \n Primary Class: {}  \n   Primary Instances:{} \n  Helper Instances:{} \n Batch Size:{}'
        .format(dataSet, classes[primaryClass], primaryInstance,
                helperInstances, batchSize))

    saveFolder = resultDir + 'mmdValues' + '/' + 'hist' + '/' + dataSet + '/'
    checkAndCreateFolder(saveFolder)
    saveFile = saveFolder + dataSet + '_' + str(primaryClass) + '_' + str(
        primaryInstance) + '_' + str(helperInstances) + '_' + str(
            batchSize) + '.png'

    plt.savefig(saveFile, bbox_inches='tight')
    plt.show()