Example #1
def displayImage(primaryArray, helperArray, mmdScore, dataSet, primaryClass, primaryInstances, helperInstances, mode):
    fig = plt.figure()
    
    a=fig.add_subplot(1,2,1)
    a.set_title('$X_{batch}$')

    genImage = torchvision.utils.make_grid(primaryArray[:49], nrow=7, normalize=True)
    genImage = genImage.permute(1,2,0)
    genImage = genImage.numpy()
    plt.axis('off')
    plt.imshow(genImage)
    
    a=fig.add_subplot(1,2,2)
    a.set_title('$Y_{batch}$')

    genImage = torchvision.utils.make_grid(helperArray[:49], nrow=7, normalize=True)
    genImage = genImage.permute(1,2,0)
    genImage = genImage.numpy()
    plt.axis('off')
    plt.imshow(genImage)
    
    plt.text(-100,280,'$MMD^{2}(X_{batch},Y_{batch})$: '+str(round(mmdScore,3)))
    
    plotFolderName = resultDir+'mmdValues'+'/'+mode+'/'+dataSet+'/'
    checkAndCreateFolder(plotFolderName)

    plotFileName = plotFolderName+dataSet+'_'+str(primaryClass)+'_'+str(primaryInstances)+'_'+str(helperInstances)+'_'+str(batchSize)+'.png'
    plt.savefig(plotFileName, bbox_inches='tight')
    plt.show()
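A minimal usage sketch for the plotting helper above, assuming the module-level names it relies on (resultDir, batchSize, checkAndCreateFolder) are already defined; the tensor shapes and the MMD score below are illustrative only.

import torch

# two hypothetical batches of 64 single-channel 32x32 images and an illustrative MMD^2 value
xBatch = torch.randn(64, 1, 32, 32)
yBatch = torch.randn(64, 1, 32, 32)
displayImage(xBatch, yBatch, mmdScore=0.215, dataSet='MNIST', primaryClass=0,
             primaryInstances=100, helperInstances=1000, mode='max')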
Example #2
def showTrainHist(trainHist, fileName, epoch):
    '''
    Plot Generator and Discriminator loss function
    '''
    x = range(len(trainHist['discLoss']))

    y1 = trainHist['discLoss']
    y2 = trainHist['genLoss']

    plt.plot(x, y1, label='D_loss')
    plt.plot(x, y2, label='G_loss')

    plt.xlabel('Iter')
    plt.ylabel('Loss')

    plt.legend(loc='upper right')
    plt.grid(True)
    plt.tight_layout()

    folder = fileName.split('_')[0]

    lossFolderName = resultDir + 'loss/nonMMD' + '/' + folder + '/'
    checkAndCreateFolder(lossFolderName)
    lossFileName = lossFolderName + fileName + '_' + str(epoch) + '.png'
    plt.savefig(lossFileName, bbox_inches='tight')

    plt.show()
Example #3
def saveImageSamples(dataSet, classes, numOfInstances):
    '''
    dataSet : string
    classes : list
    numOfInstances : list
    '''
        
    x,y = getRealData(dataSet, classes, numOfInstances, mode='train')
    numOfChannels = getChannels(dataSet)
    sampleImage = getImageSamples(x, numOfChannels=numOfChannels)

    # plot the figure of generated samples and save
    fig = plt.figure()

    plt.imshow(sampleImage, cmap='gray')
    plt.axis('off')
    
    plotFolderName = 'samples'+'/'+dataSet+'/'
    checkAndCreateFolder(plotFolderName)
    
    if len(classes)==9:
        plotFileName = dataSet+'_'+'all'+'.png'
    else :
        plotFileName = dataSet+'_'+str(classes[0])+'.png'
        
    plt.savefig(plotFolderName + plotFileName, bbox_inches='tight')
    plt.show()
Example #4
def MMDBar(dataSet, yReal, primaryClass, primaryInstances, helperInstances,
           batchSize):
    '''
    Plot the average MMD value per class within a dataset as a bar chart
    '''
    fig = plt.figure()
    ax = plt.subplot(111)
    plt.ylabel('Avg. MMD Value')
    plt.xlabel('Class')

    yReal = np.asarray(yReal)
    yReal[yReal < 0] = 0
    yAbove = np.max(yReal) + (np.max(yReal) / 3)
    yStep = (yAbove - np.min(yReal)) / 10
    plt.yticks(np.arange(0.0, yAbove, yStep))

    classNames = getClasses(dataSet)
    plt.title(dataSet + ' - ' + str(classNames[primaryClass]))

    # plot accuracy
    xReal = getClasses(dataSet)

    ind = np.arange(0, 10)
    ax.set_xticks(ind)
    ax.set_xticklabels(xReal)

    plt.xticks(rotation=45)
    plt.bar(ind, yReal, 0.50)

    plotFolderName = resultDir + 'mmdValues' + '/' + dataSet + '/'
    checkAndCreateFolder(plotFolderName)

    plotFileName = plotFolderName + dataSet + '_' + str(
        primaryClass) + '_' + str(primaryInstances) + '_' + str(
            helperInstances) + '_' + str(batchSize) + '.png'
    plt.savefig(plotFileName, bbox_inches='tight')
    plt.show()
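An illustrative call with assumed values; yReal carries one average MMD value per class, in the order returned by getClasses(dataSet), and the module-level helpers (getClasses, checkAndCreateFolder, resultDir) are expected to be in scope.

# hypothetical per-class average MMD values for a 10-class dataset
avgMMD = [0.00, 0.12, 0.34, 0.28, 0.41, 0.19, 0.22, 0.37, 0.30, 0.25]
MMDBar('MNIST', avgMMD, primaryClass=0, primaryInstances=100,
       helperInstances=1000, batchSize=64)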
Example #5
def train(fileName,
          trainLoader,
          primaryInstanceList,
          numClasses,
          numOutputChannels=1,
          learningRate=0.0002,
          optimBetas=(0.5, 0.999),
          epochs=5):
    '''
    Training for Deep Convolutional Generative Adversarial Network
    '''
    folder = fileName.split('_')[0]
    instances = sum(primaryInstanceList)

    # generator takes input channels, number of labels, number of generative filters, number of output channels
    G = Generator(numInputChannels, numClasses, numGenFilter,
                  numOutputChannels)
    D = Discriminator(numOutputChannels, numClasses, numDiscFilter)

    G.weight_init(mean=0.0, std=0.02)
    D.weight_init(mean=0.0, std=0.02)

    lossFunction = nn.BCELoss()

    genOptimiser = optim.Adam(G.parameters(),
                              lr=learningRate,
                              betas=optimBetas)
    disOptimiser = optim.Adam(D.parameters(),
                              lr=learningRate,
                              betas=optimBetas)

    numElementsNeededPerClass = 10

    fixedNoise = torch.randn(numElementsNeededPerClass * numClasses,
                             numInputChannels, 1, 1)

    # class from which the GAN should output a distribution
    fixedNoiseClass = torch.zeros(numElementsNeededPerClass * numClasses,
                                  numClasses, 1, 1)

    classIndex = torch.zeros(numElementsNeededPerClass, 1)
    for i in range(numClasses - 1):
        temp = torch.ones(numElementsNeededPerClass, 1) + i
        classIndex = torch.cat([classIndex, temp], 0)

    fixedNoiseClass = fixedNoiseClass.squeeze().scatter_(
        1, classIndex.type(torch.LongTensor), 1)
    fixedNoiseClass = fixedNoiseClass.view(-1, numClasses, 1, 1)

    # added imagesize
    discRealInput = torch.FloatTensor(batchSize, numOutputChannels, imageSize,
                                      imageSize)

    discRealInputClass = torch.zeros(batchSize, numClasses, imageSize,
                                     imageSize)

    discFakeInput = torch.FloatTensor(batchSize, numInputChannels, 1, 1)

    discFakeInputClass = torch.zeros(batchSize, numClasses, 1, 1)

    discRealLabel = torch.FloatTensor(batchSize)
    discRealLabel.fill_(1)

    discFakeLabel = torch.FloatTensor(batchSize)
    discFakeLabel.fill_(0)

    if instances < batchSize:

        discRealInput = torch.FloatTensor(instances, numOutputChannels,
                                          imageSize, imageSize)
        discFakeInput = torch.FloatTensor(instances, numInputChannels, 1, 1)

        discRealInputClass = torch.zeros(instances, numClasses, imageSize,
                                         imageSize)
        discFakeInputClass = torch.zeros(instances, numClasses, 1, 1)

        discRealLabel = torch.FloatTensor(instances)
        discRealLabel.fill_(1)

        discFakeLabel = torch.FloatTensor(instances)
        discFakeLabel.fill_(0)

    if cuda:
        G = G.cuda()
        D = D.cuda()

        lossFunction = lossFunction.cuda()

        discRealInput = discRealInput.cuda()
        discFakeInput = discFakeInput.cuda()

        discRealInputClass = discRealInputClass.cuda()
        discFakeInputClass = discFakeInputClass.cuda()

        discRealLabel = discRealLabel.cuda()
        discFakeLabel = discFakeLabel.cuda()

        fixedNoise = fixedNoise.cuda()
        fixedNoiseClass = fixedNoiseClass.cuda()

    fixedNoiseVariable = Variable(fixedNoise)
    fixedNoiseClassVariable = Variable(fixedNoiseClass)

    # can take the oneHot representation to feed into generator directly from here
    oneHotGen = torch.zeros(numClasses, numClasses)
    oneHotGen = oneHotGen.scatter_(
        1,
        torch.LongTensor([i for i in range(numClasses)]).view(numClasses, 1),
        1).view(numClasses, numClasses, 1, 1)

    # can take the oneHot representation to feed into discriminator directly from here
    oneHotDisc = torch.zeros([numClasses, numClasses, imageSize, imageSize])
    for i in range(numClasses):
        oneHotDisc[i, i, :, :] = 1

    trainHist = {}
    trainHist['discLoss'] = []
    trainHist['genLoss'] = []
    trainHist['perEpochTime'] = []
    trainHist['totalTime'] = []

    imageList = []

    for epoch in range(epochs):

        generatorLosses = []
        discriminatorLosses = []

        epochStartTime = time.time()

        for i, data in enumerate(trainLoader, 0):

            # train discriminator on real data,
            D.zero_grad()
            dataInstance, dataClass = data

            # one-hot encoding for discriminator class input
            dataClass = oneHotDisc[dataClass]

            if cuda:
                dataInstance = dataInstance.cuda()
                dataClass = dataClass.cuda()

            discRealInput.copy_(dataInstance)
            discRealInputClass.copy_(dataClass)

            discRealInputVariable = Variable(discRealInput)
            discRealInputClassVariable = Variable(discRealInputClass)
            discRealLabelVariable = Variable(discRealLabel)

            discRealOutput = D(discRealInputVariable,
                               discRealInputClassVariable)

            lossRealDisc = lossFunction(discRealOutput, discRealLabelVariable)
            lossRealDisc.backward()

            # train discriminator on fake data
            discFakeInput.normal_(0, 1)

            # change this as instances x numClasses x imagesize x imagesize
            if instances < batchSize:
                #dataFakeClass = (torch.rand(instances)*numClasses).type(torch.LongTensor)
                dataFakeClass = torch.from_numpy(
                    np.random.choice(numClasses,
                                     instances,
                                     p=getProbDist(primaryInstanceList)))
            else:
                #dataFakeClass = (torch.rand(batchSize)*numClasses).type(torch.LongTensor)
                dataFakeClass = torch.from_numpy(
                    np.random.choice(numClasses,
                                     batchSize,
                                     p=getProbDist(primaryInstanceList)))

            discFakeInputClass = oneHotDisc[dataFakeClass]
            genFakeInputClass = oneHotGen[dataFakeClass]

            if cuda:
                discFakeInputClass = discFakeInputClass.cuda()
                genFakeInputClass = genFakeInputClass.cuda()

            discFakeInputVariable = Variable(discFakeInput)
            discFakeInputClassVariable = Variable(discFakeInputClass)
            genFakeInputClassVariable = Variable(genFakeInputClass)
            discFakeLabelVariable = Variable(discFakeLabel)

            discFakeInputGen = G(discFakeInputVariable,
                                 genFakeInputClassVariable)

            # change the gradients of discriminator only
            discFakeOutput = D(discFakeInputGen.detach(),
                               discFakeInputClassVariable)

            lossFakeDisc = lossFunction(discFakeOutput, discFakeLabelVariable)
            lossFakeDisc.backward()

            disOptimiser.step()

            # log the loss for discriminator
            discriminatorLosses.append((lossRealDisc + lossFakeDisc).data[0])

            # train generator based on discriminator
            G.zero_grad()

            genInputVariable = discFakeInputGen

            # get the class function here
            genOutputDisc = D(genInputVariable, discFakeInputClassVariable)

            lossGen = lossFunction(genOutputDisc, discRealLabelVariable)

            lossGen.backward()
            genOptimiser.step()

            # log the loss for generator
            generatorLosses.append(lossGen.data[0])

        # create an image for every epoch
        # generate samples from trained generator
        genImage = G(fixedNoiseVariable, fixedNoiseClassVariable)
        genImage = genImage.data
        genImage = genImage.cpu()

        genImage = torchvision.utils.make_grid(genImage, nrow=10)
        genImage = (genImage / 2) + 0.5
        genImage = genImage.permute(1, 2, 0)
        genImage = genImage.numpy()

        fig = plt.figure(figsize=(20, 10))
        plt.imshow(genImage)
        plt.axis('off')

        txt = 'Epoch: ' + str(epoch + 1)
        fig.text(.45, .05, txt)

        plt.savefig('y.png', bbox_inches='tight')

        imageList.append(imageio.imread('y.png'))

        epochEndTime = time.time()
        perEpochTime = epochEndTime - epochStartTime
        discLoss = torch.mean(torch.FloatTensor(discriminatorLosses))
        genLoss = torch.mean(torch.FloatTensor(generatorLosses))
        print('Epoch : [%d/%d] time: %.2f, loss_d: %.3f, loss_g: %.3f' %
              (epoch + 1, epochs, perEpochTime, discLoss, genLoss))

        if epoch == (epochs - 1):
            print('Completed processing ' + str(instances) + ' instances for ' +
                  str(epochs) + ' epochs.')

            plotFolderName = resultDir + 'plots/nonMMD' + '/' + folder + '/'
            checkAndCreateFolder(plotFolderName)
            plotFileName = plotFolderName + fileName + '_' + str(
                epochs) + '.png'

            plt.imshow(genImage)
            plt.savefig(plotFileName, bbox_inches='tight')
            plt.close('all')

            # create gif animation
            animFolderName = resultDir + 'animation/nonMMD' + '/' + folder + '/'
            checkAndCreateFolder(animFolderName)
            animFileName = animFolderName + fileName + '_' + str(
                epochs) + '.gif'

            imageio.mimsave(animFileName, imageList, fps=5)

        trainHist['discLoss'].append(discLoss)
        trainHist['genLoss'].append(genLoss)

    # save the model parameters in a file
    modelFolderName = resultDir + 'models/nonMMD' + '/' + folder + '/'
    checkAndCreateFolder(modelFolderName)
    modelFileName = modelFolderName + fileName + '_' + str(epochs) + '.pt'
    torch.save(G.state_dict(), modelFileName)

    showTrainHist(trainHist, fileName, epoch)
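A hedged sketch of how this trainer might be invoked: trainLoader can be any DataLoader yielding (image, label) batches, and the globals referenced above (batchSize, imageSize, numInputChannels, numGenFilter, numDiscFilter, cuda, resultDir) are assumed to be configured at module level. The dataset object and file-name prefix are placeholders.

import torch.utils.data

primaryInstanceList = [100] * 10                             # illustrative per-class instance counts
trainLoader = torch.utils.data.DataLoader(dataset,           # 'dataset' is assumed to exist
                                          batch_size=batchSize,
                                          shuffle=True,
                                          drop_last=True)
train('mnist_all_1000', trainLoader, primaryInstanceList, numClasses=10,
      numOutputChannels=1, epochs=20)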
Example #6
def plotAccuracyBar(dataSet,
                 classifier,
                 primaryInstances,
                 helperInstances,
                 accuracyArray,
                 showImage = 1):
    

    fig = plt.figure()
    ax = fig.add_subplot(111)
    

    ind = np.arange(len(primaryInstances))
    width = 0.15
    delta = 0.02

    color = ['black', 'red', 'blue', 'green']

    histList = []
    legendList = []
    
    for i in range(accuracyArray.shape[0]):
        hist = ax.bar(ind+i*(width+delta), list(accuracyArray[:,i]), width, color=color[i])
        histList.append(hist)
        
        if i==0:
            legendList.append('Original')
        elif i==1:
            legendList.append('GAN - 0 Helper Instances')
        else:
            legendList.append('GAN - '+str(helperInstances[i-2])+' Helper Instances')
    
    print (accuracyArray)
    ax.set_xlim(-4*width,len(ind)+width)
    ax.set_ylim(0,1.5)

    ax.set_ylabel('Accuracy [out of 1]')
    ax.set_xlabel('Number of Primary Instances')
    ax.set_title(dataSet)
    

    ax.set_xticks(ind+width)
    xtickNames = ax.set_xticklabels(primaryInstances)
    plt.setp(xtickNames, rotation=45, fontsize=10)
    
    # add to tuple
    ax.legend( tuple(histList), tuple(legendList) , loc='upper left')    

    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    
    plotFolderName = 'plots/crossDataSetMMDall/accuracy'+'/'+dataSet+'/'
    checkAndCreateFolder(plotFolderName)
    
    plotFileName = plotFolderName+dataSet+'_'+classifier+'.png'
    plt.savefig(plotFileName, bbox_inches='tight')

    plt.show()
    plt.close()

    accuracyFolderName = 'plots/crossDataSetMMDall/accuracyValues'+'/'+dataSet+'/'
    checkAndCreateFolder(accuracyFolderName)
    accuracyFileName = accuracyFolderName+dataSet+'_'+classifier+'.npy'

    # save the image in some format
    with open(accuracyFileName,'wb+') as fh:

        np_save(fh, accuracyArray, allow_pickle=False)
        sync(fh)
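An illustrative call with assumed values. As written, the function draws accuracyArray[:, i] for i in range(accuracyArray.shape[0]), so a square array keeps the indexing consistent: rows correspond to the primary-instance counts, columns to the training settings (original data, GAN with 0 helper instances, then one column per entry of helperInstances).

import numpy as np

primaryInstances = [10, 50, 100]
helperInstances = [1000]
accuracyArray = np.array([[0.61, 0.58, 0.64],
                          [0.72, 0.70, 0.75],
                          [0.81, 0.79, 0.83]])   # shape (3, 3); values are made up
plotAccuracyBar('MNIST', 'CNN', primaryInstances, helperInstances, accuracyArray)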
Example #7
def plotConfusionMatrix(dataSet,
                        classifier,
                        classes,
                        trainSet,
                        numOfInstances,
                        cm, 
                        normalize=False,
                        numOfHelperInstances=-1,
                        title='Confusion matrix',
                        cmap='Reds'):
    """
    trainSet: can be 'Real','Fake' or 'FakeMMD'
    
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    
    # Plot normalized confusion matrix
    fig = plt.figure()
    
    if numOfHelperInstances==-1:
        fileName = dataSet+'_'+trainSet+'_'+classifier+'_'+str(numOfInstances)
    else:
        fileName = dataSet+'_'+trainSet+'_'+classifier+'_'+str(numOfInstances)+'_'+str(numOfHelperInstances)

    
    plotFolderName = 'plots/crossDataSetMMDall/cm'+'/'+dataSet+'/'
    checkAndCreateFolder(plotFolderName)
    plotFileName = plotFolderName+fileName+'.png'
        
        
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)


    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        w = ' '+str(format(cm[i, j], fmt))+' '
        plt.text(j, i, w,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    #plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    
    
    plt.savefig(plotFileName, bbox_inches='tight')
    
    # change this to get correct order !!
    if showImage==0:
        plt.show()
    plt.close()
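A hedged usage sketch with made-up labels: the confusion matrix handed to plotConfusionMatrix can be computed with scikit-learn. The dataset name, classifier name and instance count are placeholders, and the module-level showImage flag used near the end of the function is assumed to be defined.

import numpy as np
from sklearn.metrics import confusion_matrix

yTrue = np.array([0, 0, 1, 1, 2, 2])     # illustrative ground-truth labels
yPred = np.array([0, 1, 1, 1, 2, 0])     # illustrative classifier predictions
cm = confusion_matrix(yTrue, yPred)

plotConfusionMatrix('MNIST', 'CNN', classes=['0', '1', '2'], trainSet='Real',
                    numOfInstances=100, cm=cm, normalize=True)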
Example #8
def train(primaryFileName,
          helperFileName,
          primaryTrainLoader,
          helperTrainLoader,
          primaryInstances,
          numOutputChannels=1,
          learningRate=0.0002,
          optimBetas=(0.5, 0.999),
          epochs=5):
    '''
    Training for Deep Convolutional Generative Adversarial Network
    '''

    # define the model
    G = Generator(numInputChannels, numGenFilter, numOutputChannels)
    D = Discriminator(numOutputChannels, numDiscFilter)
    lossFunction = nn.BCELoss()
    genOptimiser = optim.Adam(G.parameters(),
                              lr=learningRate,
                              betas=optimBetas)
    disOptimiser = optim.Adam(D.parameters(),
                              lr=learningRate,
                              betas=optimBetas)
    discRealInput = torch.FloatTensor(batchSize, numOutputChannels, imageSize,
                                      imageSize)
    discFakeInput = torch.FloatTensor(batchSize, numInputChannels, 1, 1)
    fixedNoise = torch.FloatTensor(25, numInputChannels, 1, 1)
    fixedNoise.normal_(0, 1)

    discRealLabel = torch.FloatTensor(batchSize)
    discFakeLabel = torch.FloatTensor(batchSize)
    discRealLabel.fill_(1)
    discFakeLabel.fill_(0)

    # for processing on a GPU
    if cuda:
        G = G.cuda()
        D = D.cuda()

        lossFunction = lossFunction.cuda()

        discRealInput = discRealInput.cuda()
        discFakeInput = discFakeInput.cuda()

        discRealLabel = discRealLabel.cuda()
        discFakeLabel = discFakeLabel.cuda()

        fixedNoise = fixedNoise.cuda()

    fixedNoiseVariable = Variable(fixedNoise)

    # assume that there are always more helper instances than primary instances
    # define primary Epochs and helper Epochs
    primaryEpochs = epochs
    helperEpochs = 10

    print("Starting training with helper classes.")
    plt.figure()

    # training with helper class
    for epoch in range(helperEpochs):
        for i, primaryData in enumerate(primaryTrainLoader, 0):
            for j, helperData in enumerate(helperTrainLoader, 0):

                #print ('Epoch : {} Primary Class Batch : {}. Helper Class Batch : {}.'.format(epoch+1,i+1,j+1))

                primaryDataInstance, primaryDataLabel = primaryData
                helperDataInstance, helperDataLabel = helperData

                # calculate MMD between two batches of data
                mmd = (1 - kernelTwoSampleTest(primaryDataInstance,
                                               helperDataInstance))
                mmd = torch.from_numpy(np.asarray([mmd]))
                mmdVariable = Variable(mmd.float().cuda())

                # weight given to the term
                lambdaMMD = 1.0
                lambdaMMD = torch.from_numpy(np.asarray([lambdaMMD]))
                lambdaMMDVariable = Variable(lambdaMMD.float().cuda())

                D.zero_grad()

                # train GAN using helper data instance
                if cuda:
                    helperDataInstance = helperDataInstance.cuda()

                discRealInput.copy_(helperDataInstance)
                discRealInputVariable = Variable(discRealInput)

                # should we treat this as 1 ??
                discRealLabelVariable = Variable(discRealLabel)
                discRealOutput = D(discRealInputVariable)
                lossRealDisc = lambdaMMDVariable * mmdVariable * lossFunction(
                    discRealOutput, discRealLabelVariable)

                lossRealDisc.backward()

                # train discriminator on fake data
                discFakeInput.normal_(0, 1)
                discFakeInputVariable = Variable(discFakeInput)
                discFakeInputGen = G(discFakeInputVariable)
                discFakeLabelVariable = Variable(discFakeLabel)

                discFakeOutput = D(discFakeInputGen.detach())
                lossFakeDisc = lossFunction(discFakeOutput,
                                            discFakeLabelVariable)
                lossFakeDisc.backward()
                disOptimiser.step()

                # train generator based on discriminator
                # for every epoch the gradients are reset to 0
                # the discriminator should start to confuse fake instances
                # with true instances

                G.zero_grad()

                genInputVariable = discFakeInputGen
                genOutputDisc = D(genInputVariable)

                lossGen = lossFunction(genOutputDisc, discRealLabelVariable)

                lossGen.backward()
                genOptimiser.step()

            if (i == 0) and epoch == (helperEpochs - 1):
                #print ('Completed processing '+str(primaryInstances)+'for'+str(epoch)+'epochs.')

                # name for model and plot file
                folder, primaryClass, _ = primaryFileName.split('_')
                _, helperClass, helperInstances = helperFileName.split('_')

                fileName = folder + '_' + str(primaryClass) + '_' + str(helperClass) + '_' + \
                           str(primaryInstances) + '_' + str(helperInstances)

                # generate samples from trained generator
                genImage = G(fixedNoiseVariable)
                genImage = genImage.data
                genImage = genImage.cpu()
                genImage = torchvision.utils.make_grid(genImage, nrow=5)

                genImage = genImage / 2 + 0.5
                genImage = genImage.permute(1, 2, 0)
                genImage = genImage.numpy()

                #print genImage.shape
                # plot the figure of generated samples and save
                fig = plt.figure()
                plt.imshow(genImage, cmap='gray')
                plt.axis('off')

    if primaryInstances < batchSize:
        discRealInput = torch.FloatTensor(primaryInstances, numOutputChannels,
                                          imageSize, imageSize)
        # why only one as width and height ? Passing through generator.
        discFakeInput = torch.FloatTensor(primaryInstances, numInputChannels,
                                          1, 1)
        discRealLabel = torch.FloatTensor(primaryInstances)
        discFakeLabel = torch.FloatTensor(primaryInstances)
        discRealLabel.fill_(1)
        discFakeLabel.fill_(0)

    if cuda:

        discRealInput = discRealInput.cuda()
        discFakeInput = discFakeInput.cuda()

        discRealLabel = discRealLabel.cuda()
        discFakeLabel = discFakeLabel.cuda()

    print("Ending training with helper classes.")
    print("Starting training with primary class.")

    # training with primary class
    for epoch in range(primaryEpochs):
        for i, data in enumerate(primaryTrainLoader, 0):

            #print ('Epoch : {} Primary Class Batch : {}.'.format(epoch+1,i+1))

            if i > 10000:
                print("Done 2000 Iterations")
                break

            # train discriminator on real data
            D.zero_grad()
            dataInstance, dataLabel = data

            if cuda:
                dataInstance = dataInstance.cuda()

            discRealInput.copy_(dataInstance)
            discRealInputVariable = Variable(discRealInput)
            discRealLabelVariable = Variable(discRealLabel)

            discRealOutput = D(discRealInputVariable)
            lossRealDisc = lossFunction(discRealOutput, discRealLabelVariable)

            lossRealDisc.backward()

            # train discriminator on fake data
            discFakeInput.normal_(0, 1)
            discFakeInputVariable = Variable(discFakeInput)
            discFakeInputGen = G(discFakeInputVariable)

            discFakeLabelVariable = Variable(discFakeLabel)

            discFakeOutput = D(discFakeInputGen.detach())
            lossFakeDisc = lossFunction(discFakeOutput, discFakeLabelVariable)
            lossFakeDisc.backward()

            disOptimiser.step()

            # train generator based on discriminator
            # for every epoch the gradients are reset to 0
            # the discriminator should start to confuse fake instances
            # with true instances

            G.zero_grad()

            genInputVariable = discFakeInputGen
            genOutputDisc = D(genInputVariable)

            lossGen = lossFunction(genOutputDisc, discRealLabelVariable)

            lossGen.backward()
            genOptimiser.step()

            if (i == 0) and epoch == (primaryEpochs - 1):

                folder, primaryClass, _ = primaryFileName.split('_')
                _, helperClass, helperInstances = helperFileName.split('_')

                fileName = folder + '_' + str(primaryClass) + '_' + str(helperClass) + '_' + \
                           str(primaryInstances) + '_' + str(helperInstances)

                modelFolder = resultDir + 'models/MMDall' + '/' + folder + '/'
                plotFolder = resultDir + 'plots/MMDall' + '/' + folder + '/'

                checkAndCreateFolder(modelFolder)
                checkAndCreateFolder(plotFolder)

                modelFileName = modelFolder + fileName + '_' + str(
                    epoch) + '.pt'
                plotFileName = plotFolder + fileName + '_' + str(
                    epoch) + '.png'

                # save the model parameters in a file
                torch.save(G.state_dict(), modelFileName)

                # generate samples from trained generator
                genImage = G(fixedNoiseVariable)
                genImage = genImage.data
                genImage = genImage.cpu()
                genImage = torchvision.utils.make_grid(genImage, nrow=5)

                genImage = genImage / 2 + 0.5
                genImage = genImage.permute(1, 2, 0)
                genImage = genImage.numpy()

                # plot the figure of generated samples and save
                fig = plt.figure()

                plt.imshow(genImage, cmap='gray')
                plt.axis('off')

                txt = 'Epoch: ' + str(epoch)
                fig.text(.45, .05, txt)
                if showImage == 1:
                    plt.show()
                '''
                IPython.display.clear_output(wait=True)
                IPython.display.display(plt.gcf())
                '''
                plt.savefig(plotFileName, bbox_inches='tight')
                plt.close('all')

    print("Done trining with primary class primaryInstances.")
Example #9
def MMDhist(primaryDomain,
            helperDomain,
            primaryClass,
            helperClassList,
            primaryInstance,
            helperInstance,
            batchSize):
    
    # load the datasets for which the bar graph needs to be plotted
    x = loadDataset(primaryDomain, [primaryClass], [primaryInstance], mode='train')
    primaryTrainLoader = torch.utils.data.DataLoader(x,
                                               batch_size=batchSize, 
                                               shuffle=True,
                                               num_workers=4,
                                               drop_last = True)
    
    avgMMDValues = []
    for helperClass in helperClassList :
        
        y = loadDataset(helperDomain, [helperClass], [helperInstance], mode='train')    

        helperTrainLoader = torch.utils.data.DataLoader(y,
                                                  batch_size=batchSize, 
                                                  shuffle=True,
                                                  num_workers=4,
                                                  drop_last=True)
        
        mmdValues = []
        
        for i, primaryData in enumerate(primaryTrainLoader, 0):
            primaryDataInstance, primaryDataLabel = primaryData
            for j, helperData in enumerate(helperTrainLoader, 0):

                helperDataInstance, helperDataLabel = helperData
                mmdValue = kernelTwoSampleTest(primaryDataInstance, helperDataInstance)
                mmdValues.append(mmdValue)
        
        mmdValues = np.asarray(mmdValues)
        avgMMDValues.append(np.mean(mmdValues))
    

    # plot the average MMD Values
    fig = plt.figure()
    ax = plt.subplot(111)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

    plt.ylabel('Avg. MMD Value')
    plt.xlabel('Class')
    
    avgMMDValues = np.asarray(avgMMDValues)
    avgMMDValues[avgMMDValues<0] = 0
    
    # plot adjustments
    yAbove = np.max(avgMMDValues)+(np.max(avgMMDValues)/3)
    yStep = (np.max(avgMMDValues))/10
    
    plt.yticks(np.arange(0.0,yAbove,yStep))
    
    classNames = getClasses(primaryDomain)
    plt.title(primaryDomain+' - '+str(classNames[primaryClass]))
    
    # x-axis adjustments
    xReal = getClasses(helperDomain)
    ind = np.arange(0, len(xReal))
    ax.set_xticks(ind)
    ax.set_xticklabels(xReal)
    
    plt.xticks(rotation=45)
    plt.bar(ind, avgMMDValues, 0.50)   
    
    plotFolderName = resultDir+'mmdValues'+'/'+primaryDomain+'/'
    checkAndCreateFolder(plotFolderName)

    plotFileName = plotFolderName+primaryDomain+'_'+str(primaryClass)+'_'+str(helperDomain)+'_'+str(batchSize)+'.png'
    plt.savefig(plotFileName, bbox_inches='tight')
    plt.show()
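An illustrative call with assumed domain names: the function compares batches of one class from the primary domain against every class listed in helperClassList from the helper domain and plots the average MMD per helper class.

MMDhist('MNIST', 'USPS', primaryClass=3, helperClassList=list(range(10)),
        primaryInstance=100, helperInstance=1000, batchSize=64)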
Example #10
    def train(self):
        
        domainB_iter = iter(self.domainB_loader)    
        domainA_iter = iter(self.domainA_loader)
        
        
        iter_per_epoch = min(len(domainB_iter), len(domainA_iter))
        
        # fixed domainA and domainB for sampling
        fixed_domainB = to_var(domainB_iter.next()[0])
        fixed_domainA = to_var(domainA_iter.next()[0])
        
        # loss if use_labels = True
        criterion = nn.CrossEntropyLoss()
        
        for step in range(self.epochs+1):
            # reset data_iter for each epoch
            if (step+1) % iter_per_epoch == 0:
                domainA_iter = iter(self.domainA_loader)
                domainB_iter = iter(self.domainB_loader)
            
            # load domainB and domainA dataset
            domainB, s_labels = domainB_iter.next() 
            domainB, s_labels = to_var(domainB), to_var(s_labels).long().squeeze()
            domainA, m_labels = domainA_iter.next()
            domainA, m_labels = to_var(domainA), to_var(m_labels)

            if self.use_labels:
                domainA_fake_labels = to_var(
                    torch.Tensor([self.num_classes]*domainB.size(0)).long())
                domainB_fake_labels = to_var(
                    torch.Tensor([self.num_classes]*domainA.size(0)).long())
            
            #============ train D ============#
            
            # train with real images
            self.reset_grad()
            out = self.d1(domainA)
            if self.use_labels:
                d1_loss = criterion(out, m_labels)
            else:
                d1_loss = torch.mean((out-1)**2)
            
            out = self.d2(domainB)
            if self.use_labels:
                d2_loss = criterion(out, s_labels)
            else:
                d2_loss = torch.mean((out-1)**2)
            
            d_domainA_loss = d1_loss
            d_domainB_loss = d2_loss
            d_real_loss = d1_loss + d2_loss
            d_real_loss.backward()
            self.d_optimizer.step()
            
            # train with fake images
            self.reset_grad()
            fake_domainB = self.g12(domainA)
            out = self.d2(fake_domainB)
            if self.use_labels:
                d2_loss = criterion(out, domainB_fake_labels)
            else:
                d2_loss = torch.mean(out**2)
            
            fake_domainA = self.g21(domainB)
            out = self.d1(fake_domainA)
            if self.use_labels:
                d1_loss = criterion(out, domainA_fake_labels)
            else:
                d1_loss = torch.mean(out**2)
            
            d_fake_loss = d1_loss + d2_loss
            d_fake_loss.backward()
            self.d_optimizer.step()
            
            #============ train G ============#
            
            # train domainA-domainB-domainA cycle
            self.reset_grad()
            fake_domainB = self.g12(domainA)
            out = self.d2(fake_domainB)
            reconst_domainA = self.g21(fake_domainB)
            if self.use_labels:
                g_loss = criterion(out, m_labels) 
            else:
                g_loss = torch.mean((out-1)**2) 

            if self.use_reconst_loss:
                g_loss += torch.mean((domainA - reconst_domainA)**2)

            g_loss.backward()
            self.g_optimizer.step()

            # train domainB-domainA-domainB cycle
            self.reset_grad()
            fake_domainA = self.g21(domainB)
            out = self.d1(fake_domainA)
            reconst_domainB = self.g12(fake_domainA)
            if self.use_labels:
                g_loss = criterion(out, s_labels) 
            else:
                g_loss = torch.mean((out-1)**2) 

            if self.use_reconst_loss:
                g_loss += torch.mean((domainB - reconst_domainB)**2)

            g_loss.backward()
            self.g_optimizer.step()
            
            # print the log info
            if (step+1) % self.log_step == 0:
                print('Step [%d/%d], d_real_loss: %.4f, d_domainA_loss: %.4f, d_domainB_loss: %.4f, '
                      'd_fake_loss: %.4f, g_loss: %.4f' 
                      %(step+1, self.epochs, d_real_loss.data[0], d_domainA_loss.data[0], 
                        d_domainB_loss.data[0], d_fake_loss.data[0], g_loss.data[0]))

            # save the sampled images
            if (step+1) % self.sample_step == 0:
                
                fake_domainB = self.g12(fixed_domainA)
                fake_domainA = self.g21(fixed_domainB)
                
                domainA, fake_domainA = to_data(fixed_domainA), to_data(fake_domainA)
                domainB , fake_domainB = to_data(fixed_domainB), to_data(fake_domainB)
                                
                merged = self.merge_images(domainA, fake_domainB)
                path = os.path.join(self.sample_path, self.name)
                checkAndCreateFolder(path)
                path = os.path.join(self.sample_path, self.name, 'sample-%d-m-s.png' %(step+1))
                scipy.misc.imsave(path, merged)
                print ('saved %s' %path)
                
                merged = self.merge_images(domainB, fake_domainA)
                path = os.path.join(self.sample_path, self.name, 'sample-%d-s-m.png' %(step+1))
               
                scipy.misc.imsave(path, merged)
                print ('saved %s' %path)
            
            if (step+1) % self.log_step == 0:
                # save the model parameters for each epoch
                g12_path = os.path.join(self.model_path, self.name)
                checkAndCreateFolder(g12_path)
                g12_path = os.path.join(self.model_path, self.name,  'g12-%d.pkl' %(step+1))
                g21_path = os.path.join(self.model_path, self.name, 'g21-%d.pkl' %(step+1))
                d1_path = os.path.join(self.model_path, self.name, 'd1-%d.pkl' %(step+1))
                d2_path = os.path.join(self.model_path, self.name, 'd2-%d.pkl' %(step+1))
                torch.save(self.g12.state_dict(), g12_path)
                torch.save(self.g21.state_dict(), g21_path)
                torch.save(self.d1.state_dict(), d1_path)
                torch.save(self.d2.state_dict(), d2_path)
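The helpers to_var and to_data used above come from the surrounding repository and are not shown here; in this Variable-based PyTorch style they typically move tensors to the GPU when one is available, wrap them for autograd, and convert them back to numpy arrays. A minimal sketch under that assumption:

import torch
from torch.autograd import Variable

def to_var(x):
    # wrap a tensor in a Variable, on the GPU when one is available
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

def to_data(x):
    # move a Variable/tensor back to the CPU and return its numpy array
    if torch.cuda.is_available():
        x = x.cpu()
    return x.data.numpy()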
Example #11
def test(primaryDataSet, helperDataSet, primaryClass, helperClass, primaryInstances, helperInstances):
    '''
    Inputs :

    primaryDataSet, helperDataSet : string : datasets for which samples are to be generated
    primaryClass, helperClass : int : classes for which samples are to be generated
    primaryInstances, helperInstances : int : number of instances to be used from the original datasets

    Output :

    File with 1000 compressed images generated by the GAN
    '''
    helperClass = primaryClass
    
    modelFolder = resultDir + 'models/crossDataSetMMDall'+'/'+primaryDataSet+'/'
    print (primaryDataSet, helperDataSet, primaryClass, helperClass, primaryInstances, helperInstances, getEpochs(primaryDataSet,primaryInstances)-1)
    modelFile = modelFolder + primaryDataSet + '_' + helperDataSet + '_' + \
                str(primaryClass) + '_' + str(helperClass) + '_' + \
                str(primaryInstances) + '_' + str(helperInstances)+'_'+ \
                str(getEpochs(primaryDataSet,primaryInstances)-1)+'.pt'
        

    
    print ('Generating examples for Dataset: '+primaryDataSet+
           ' Primary Class: '+str(primaryClass)+
           ' Helper Class: '+str(helperClass)+
           ' Primary Instances: '+str(primaryInstances)+
           ' Helper Instances: '+str(helperInstances)+
           ' Epochs: '+str(getEpochs(primaryDataSet,primaryInstances)))
    
    numOutputChannels = getChannels(primaryDataSet)
    
    # load the model learnt during training
    G = Generator(numInputChannels, numGenFilter, numOutputChannels)
    G.load_state_dict(torch.load(modelFile))
    genImageConcat = np.empty(1)
    
    iterations = numOfSamples // batchSize
    
    for iteration in range(iterations):
        noise = torch.FloatTensor(batchSize,
                                  numInputChannels,
                                  1,
                                  1)
        noise.normal_(0,1)

        if cuda:
            G = G.cuda()
            noise = noise.cuda()
        noiseVariable = Variable(noise)

        genImage = G(noiseVariable)
        genImage = genImage.data
        genImage = genImage.cpu()
        genImage = genImage.numpy()
        
        
        if iteration==0:
            genImageConcat = genImage
        else:
            genImageConcat = np.concatenate((genImageConcat, genImage),
                                            axis=0)
            
        if iteration==(iterations-1):
            
            # normalize sets image pixels between 0 and 1
            genImage = torchvision.utils.make_grid(torch.from_numpy(genImage[:25]), nrow=5, normalize=True)
            
            # mapping between 0 to 1 as required by imshow,
            # otherwise the images are stored in the form -1 to 1
            # done through normalize=True
            
            genImage = genImage.permute(1,2,0)
            genImage = genImage.numpy()

            plt.imshow(genImage, cmap='gray')
            
            plotFolder = resultDir+'results'+'/'+'crossDataSetMMDall/samples'+'/'+primaryDataSet+'/'
            checkAndCreateFolder(plotFolder)
            plotFile = primaryDataSet + '_' + helperDataSet + '_' + \
                str(primaryClass) + '_' + 'all' + '_' + \
                str(primaryInstances) + '_' + str(helperInstances) 
            plotPath = plotFolder + plotFile
            
            plt.axis('off')
            plt.savefig(plotPath, bbox_inches='tight')
            plt.show()
            

    resultFolder = resultDir+'results/crossDataSetMMDall'+'/'+'compressed'+'/'+primaryDataSet+'/'
    checkAndCreateFolder(resultFolder)
    resultFile = primaryDataSet + '_' + helperDataSet + '_' + \
                 str(primaryClass) + '_' + 'all' + '_' + \
                 str(primaryInstances) + '_' + str(helperInstances)  + '.npy'
            
    resultPath = resultFolder + resultFile
    
    # save the image in some format
    with open(resultPath,'wb+') as fh:
        genImageConcat = np.squeeze(genImageConcat)
        np_save(fh, genImageConcat, allow_pickle=False)
        sync(fh)
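A hypothetical call with placeholder arguments; note that the function overwrites helperClass with primaryClass before assembling the model path, and it relies on module-level settings such as numOfSamples, batchSize, cuda, resultDir and the Generator class.

test('MNIST', 'USPS', primaryClass=3, helperClass=3,
     primaryInstances=100, helperInstances=1000)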
Example #12
def MMDhist(dataSet, primaryClass, primaryInstance, helperInstances,
            batchSize):

    # load the datasets for which the bar graph needs to be plotted
    x = loadDataset(dataSet, [primaryClass], [primaryInstance])

    primaryClasses = [i for i in range(10)]
    primaryClasses = primaryClasses[:primaryClass] + primaryClasses[
        primaryClass + 1:]
    primaryInstances = [primaryInstance for i in range(10)]

    y = loadDataset(dataSet, primaryClasses, primaryInstances)

    primaryTrainLoader = torch.utils.data.DataLoader(x,
                                                     batch_size=batchSize,
                                                     shuffle=True,
                                                     num_workers=4,
                                                     drop_last=True)
    helperTrainLoader = torch.utils.data.DataLoader(y,
                                                    batch_size=batchSize,
                                                    shuffle=True,
                                                    num_workers=4,
                                                    drop_last=True)

    mmdValues = []
    minMMDValue = 1.5
    maxMMDValue = -0.1

    minPrimaryDataInstance = torch.FloatTensor(batchSize, getChannels(dataSet),
                                               getImageSize(dataSet),
                                               getImageSize(dataSet)).zero_()
    minHelperDataInstance = minPrimaryDataInstance
    maxPrimaryDataInstance = minPrimaryDataInstance
    maxHelperDataInstance = minPrimaryDataInstance

    for i, primaryData in enumerate(primaryTrainLoader, 0):
        primaryDataInstance, primaryDataLabel = primaryData
        for j, helperData in enumerate(helperTrainLoader, 0):

            helperDataInstance, helperDataLabel = helperData

            mmdValue = kernelTwoSampleTest(primaryDataInstance,
                                           helperDataInstance)

            # choosing the pair with minimum MMD
            if minMMDValue > mmdValue:
                minMMDValue = mmdValue
                minPrimaryDataInstance = primaryDataInstance
                minHelperDataInstance = helperDataInstance

            # choosing the pair with maximum MMD
            if maxMMDValue < mmdValue:
                maxMMDValue = mmdValue
                maxPrimaryDataInstance = primaryDataInstance
                maxHelperDataInstance = helperDataInstance

            mmdValues.append(mmdValue)

            #mmdValues.append (1-kernelTwoSampleTest(primaryDataInstance, helperDataInstance))

    displayImage(maxPrimaryDataInstance, maxHelperDataInstance, maxMMDValue,
                 dataSet, primaryClass, primaryInstance, helperInstances,
                 'max')
    displayImage(minPrimaryDataInstance, minHelperDataInstance, minMMDValue,
                 dataSet, primaryClass, primaryInstance, helperInstances,
                 'min')

    mmdValues = np.asarray(mmdValues)
    plt.figure()
    plt.hist(mmdValues, ec='k')

    classes = getClasses(dataSet)

    plt.plot()
    plt.xlabel('$MMD^2$ between batch of Primary Class and Helper Class')
    plt.ylabel('Number of batch pairs with $MMD^2$ value in this range')
    # "\n".join(wrap())
    plt.title(
        ' Dataset: {} \n Primary Class: {}  \n   Primary Instances:{} \n  Helper Instances:{} \n Batch Size:{}'
        .format(dataSet, classes[primaryClass], primaryInstance,
                helperInstances, batchSize))

    saveFolder = resultDir + 'mmdValues' + '/' + 'hist' + '/' + dataSet + '/'
    checkAndCreateFolder(saveFolder)
    saveFile = saveFolder + dataSet + '_' + str(primaryClass) + '_' + str(
        primaryInstance) + '_' + str(helperInstances) + '_' + str(
            batchSize) + '.png'

    plt.savefig(saveFile, bbox_inches='tight')
    plt.show()