def displayImage(primaryArray, helperArray, mmdScore, dataSet, primaryClass,
                 primaryInstances, helperInstances, mode):
    '''
    Draw two image batches side by side and annotate them with their MMD^2.

    The first 49 images of each batch are tiled into a 7x7 grid with
    torchvision.utils.make_grid(normalize=True). The left panel is titled
    X_batch (primaryArray) and the right panel Y_batch (helperArray).
    The figure is saved as
    <resultDir>mmdValues/<mode>/<dataSet>/<dataSet>_<primaryClass>_
    <primaryInstances>_<helperInstances>_<batchSize>.png and then shown.

    primaryArray, helperArray : batches accepted by make_grid
        (presumably N x C x H x W tensors -- confirm with callers)
    mmdScore : value printed under the grids, rounded to 3 decimals
    mode : sub-folder of mmdValues/ (callers in this file pass 'max'/'min')

    Reads the module-level globals resultDir and batchSize.
    '''
    fig = plt.figure()
    panels = [('$X_{batch}$', primaryArray), ('$Y_{batch}$', helperArray)]
    for column, (panelTitle, batch) in enumerate(panels, start=1):
        axes = fig.add_subplot(1, 2, column)
        axes.set_title(panelTitle)
        grid = torchvision.utils.make_grid(batch[:49], nrow=7, normalize=True)
        # make_grid yields C x H x W; imshow expects H x W x C
        gridPixels = grid.permute(1, 2, 0).numpy()
        plt.axis('off')
        plt.imshow(gridPixels)

    # plt.text uses the data coordinates of the current (right-hand) panel
    plt.text(-100, 280,
             '$MMD^{2}(X_{batch},Y_{batch})$: ' + str(round(mmdScore, 3)))

    plotFolderName = resultDir + 'mmdValues' + '/' + mode + '/' + dataSet + '/'
    checkAndCreateFolder(plotFolderName)
    nameParts = [dataSet, str(primaryClass), str(primaryInstances),
                 str(helperInstances), str(batchSize)]
    plotFileName = plotFolderName + '_'.join(nameParts) + '.png'
    plt.savefig(plotFileName, bbox_inches='tight')
    plt.show()
def showTrainHist(trainHist, fileName, epoch):
    '''
    Plot Generator and Discriminator loss function.

    trainHist : dict holding loss sequences under 'discLoss' and 'genLoss'
    fileName : base name; the text before the first '_' selects the sub-folder
    epoch : appended to the saved file name

    Draws on the current pyplot figure, saves
    <resultDir>loss/nonMMD/<folder>/<fileName>_<epoch>.png and shows it.
    '''
    discLosses = trainHist['discLoss']
    genLosses = trainHist['genLoss']
    steps = range(len(discLosses))
    for losses, curveLabel in ((discLosses, 'D_loss'), (genLosses, 'G_loss')):
        plt.plot(steps, losses, label=curveLabel)

    plt.xlabel('Iter')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')
    plt.grid(True)
    plt.tight_layout()

    folder, _, _ = fileName.partition('_')
    lossFolderName = resultDir + 'loss/nonMMD/' + folder + '/'
    checkAndCreateFolder(lossFolderName)
    plt.savefig(lossFolderName + fileName + '_' + str(epoch) + '.png',
                bbox_inches='tight')
    plt.show()
def saveImageSamples(dataSet, classes, numOfInstances):
    '''
    Save (and show) a grid of real training images.

    dataset : string
    classes : list of classes forwarded to getRealData
    numOfInstances : forwarded to getRealData

    The image is written to samples/<dataSet>/<dataSet>_all.png when exactly
    9 classes are given, otherwise to
    samples/<dataSet>/<dataSet>_<classes[0]>.png.
    '''
    images, _ = getRealData(dataSet, classes, numOfInstances, mode='train')
    sampleImage = getImageSamples(images, numOfChannels=getChannels(dataSet))

    # plot the figure of samples and save
    plt.figure()
    plt.imshow(sampleImage, cmap='gray')
    plt.axis('off')

    plotFolderName = 'samples' + '/' + dataSet + '/'
    checkAndCreateFolder(plotFolderName)
    # NOTE(review): 'all' is used only when exactly 9 classes are passed;
    # confirm this matches the intended class counts.
    suffix = 'all' if len(classes) == 9 else str(classes[0])
    plt.savefig(plotFolderName + dataSet + '_' + suffix + '.png',
                bbox_inches='tight')
    plt.show()
def MMDBar(dataSet, yReal, primaryClass, primaryInstances, helperInstances,
           batchSize):
    '''
    Plot Avg. MMD value within a dataset as a bar chart (one bar per class).

    yReal : average MMD value per class; needs one entry per class returned
        by getClasses(dataSet). Negative values are displayed as 0. The
        caller's sequence/array is not modified.
    primaryClass : index into getClasses(dataSet), used in the title
    primaryInstances, helperInstances, batchSize : only used in the file name

    Saves <resultDir>mmdValues/<dataSet>/<dataSet>_<primaryClass>_
    <primaryInstances>_<helperInstances>_<batchSize>.png and shows it.
    '''
    plt.figure()
    ax = plt.subplot(111)
    plt.ylabel('Avg. MMD Value')
    plt.xlabel('Class')

    # np.array copies, so clipping never writes into an ndarray passed in by
    # the caller (np.asarray would alias it)
    heights = np.array(yReal, dtype=float)
    heights[heights < 0] = 0
    yAbove = np.max(heights) + (np.max(heights) / 3)
    yStep = (yAbove - np.min(heights)) / 10
    # all-zero values give a zero step, for which np.arange raises
    if np.isfinite(yStep) and yStep > 0:
        plt.yticks(np.arange(0.0, yAbove, yStep))

    classNames = getClasses(dataSet)
    plt.title(dataSet + ' - ' + str(classNames[primaryClass]))

    # one tick per class of the dataset, labelled with the class names
    ind = np.arange(len(classNames))
    ax.set_xticks(ind)
    ax.set_xticklabels(classNames)
    plt.xticks(rotation=45)
    plt.bar(ind, heights, 0.50)

    plotFolderName = resultDir + 'mmdValues' + '/' + dataSet + '/'
    checkAndCreateFolder(plotFolderName)
    plotFileName = plotFolderName + dataSet + '_' + str(
        primaryClass) + '_' + str(primaryInstances) + '_' + str(
        helperInstances) + '_' + str(batchSize) + '.png'
    plt.savefig(plotFileName, bbox_inches='tight')
    plt.show()
def train(fileName, trainLoader, primaryInstanceList, numClasses,
          numOutputChannels=1, learningRate=0.0002, optimBetas=(0.5, 0.999),
          epochs=5):
    '''
    Training for a class-conditional Deep Convolutional Generative
    Adversarial Network.

    fileName : base name for outputs; the text before the first '_' selects
        the sub-folder under plots/, animation/ and models/ (nonMMD)
    trainLoader : yields (images, classIndices) batches
    primaryInstanceList : instance count per class; its sum bounds the buffer
        size, and getProbDist(primaryInstanceList) is the distribution from
        which the classes of fake samples are drawn
    numClasses : number of conditioning classes
    numOutputChannels : image channels produced by the generator
    learningRate, optimBetas : Adam hyper-parameters for both networks
    epochs : number of passes over trainLoader

    Side effects: writes a per-epoch preview to y.png. After the final epoch
    it saves the preview grid (.png), an animated .gif of all previews and
    the generator state_dict (.pt), then plots the loss history via
    showTrainHist. Reads the module-level globals numInputChannels,
    numGenFilter, numDiscFilter, batchSize, imageSize, cuda and resultDir.
    '''
    folder = fileName.split('_')[0]
    instances = sum(primaryInstanceList)
    # the fixed-size buffers hold one batch, or every instance when there are
    # fewer instances than one batch
    bufferSize = min(instances, batchSize)

    # generator takes input channels, number of labels, number of generative
    # filters, number of output channels
    G = Generator(numInputChannels, numClasses, numGenFilter, numOutputChannels)
    D = Discriminator(numOutputChannels, numClasses, numDiscFilter)
    G.weight_init(mean=0.0, std=0.02)
    D.weight_init(mean=0.0, std=0.02)

    lossFunction = nn.BCELoss()
    genOptimiser = optim.Adam(G.parameters(), lr=learningRate, betas=optimBetas)
    disOptimiser = optim.Adam(D.parameters(), lr=learningRate, betas=optimBetas)

    # fixed noise reused after every epoch so the preview grids are comparable
    numElementsNeededPerClass = 10
    fixedNoise = torch.randn(numElementsNeededPerClass * numClasses,
                             numInputChannels, 1, 1)
    # class from which the GAN should output a distribution: rows
    # [k*10, (k+1)*10) are one-hot for class k
    fixedNoiseClass = torch.zeros(numElementsNeededPerClass * numClasses,
                                  numClasses, 1, 1)
    classIndex = torch.zeros(numElementsNeededPerClass, 1)
    for i in range(numClasses - 1):
        temp = torch.ones(numElementsNeededPerClass, 1) + i
        classIndex = torch.cat([classIndex, temp], 0)
    fixedNoiseClass = fixedNoiseClass.squeeze().scatter_(
        1, classIndex.type(torch.LongTensor), 1)
    fixedNoiseClass = fixedNoiseClass.view(-1, numClasses, 1, 1)

    discRealInput = torch.FloatTensor(bufferSize, numOutputChannels,
                                      imageSize, imageSize)
    discRealInputClass = torch.zeros(bufferSize, numClasses, imageSize,
                                     imageSize)
    discFakeInput = torch.FloatTensor(bufferSize, numInputChannels, 1, 1)
    discFakeInputClass = torch.zeros(bufferSize, numClasses, 1, 1)
    discRealLabel = torch.FloatTensor(bufferSize)
    discRealLabel.fill_(1)
    discFakeLabel = torch.FloatTensor(bufferSize)
    discFakeLabel.fill_(0)

    if cuda:
        G = G.cuda()
        D = D.cuda()
        lossFunction = lossFunction.cuda()
        discRealInput = discRealInput.cuda()
        discFakeInput = discFakeInput.cuda()
        discRealInputClass = discRealInputClass.cuda()
        discFakeInputClass = discFakeInputClass.cuda()
        discRealLabel = discRealLabel.cuda()
        discFakeLabel = discFakeLabel.cuda()
        fixedNoise = fixedNoise.cuda()
        fixedNoiseClass = fixedNoiseClass.cuda()

    fixedNoiseVariable = Variable(fixedNoise)
    fixedNoiseClassVariable = Variable(fixedNoiseClass)

    # one-hot class inputs for the generator: oneHotGen[k] has shape
    # numClasses x 1 x 1 with a 1 at position k
    oneHotGen = torch.eye(numClasses).view(numClasses, numClasses, 1, 1)

    # one-hot class planes for the discriminator: oneHotDisc[k] has plane k
    # filled with ones
    oneHotDisc = torch.zeros([numClasses, numClasses, imageSize, imageSize])
    for i in range(numClasses):
        oneHotDisc[i, i, :, :] = 1

    # class sampling distribution for fake samples (loop-invariant)
    probDist = getProbDist(primaryInstanceList)

    trainHist = {}
    trainHist['discLoss'] = []
    trainHist['genLoss'] = []
    trainHist['perEpochTime'] = []
    trainHist['totalTime'] = []

    imageList = []

    for epoch in range(epochs):
        generatorLosses = []
        discriminatorLosses = []
        epochStartTime = time.time()

        for i, data in enumerate(trainLoader, 0):
            # train discriminator on real data
            D.zero_grad()
            dataInstance, dataClass = data
            # one-hot encoding for discriminator class input
            dataClass = oneHotDisc[dataClass]
            if cuda:
                dataInstance = dataInstance.cuda()
                dataClass = dataClass.cuda()
            discRealInput.copy_(dataInstance)
            discRealInputClass.copy_(dataClass)
            discRealInputVariable = Variable(discRealInput)
            discRealInputClassVariable = Variable(discRealInputClass)
            discRealLabelVariable = Variable(discRealLabel)
            discRealOutput = D(discRealInputVariable, discRealInputClassVariable)
            lossRealDisc = lossFunction(discRealOutput, discRealLabelVariable)
            lossRealDisc.backward()

            # train discriminator on fake data
            discFakeInput.normal_(0, 1)
            # draw one class per fake sample, proportionally to the
            # per-class instance counts
            dataFakeClass = torch.from_numpy(
                np.random.choice(numClasses, bufferSize, p=probDist))
            discFakeInputClass = oneHotDisc[dataFakeClass]
            genFakeInputClass = oneHotGen[dataFakeClass]
            if cuda:
                discFakeInputClass = discFakeInputClass.cuda()
                genFakeInputClass = genFakeInputClass.cuda()

            discFakeInputVariable = Variable(discFakeInput)
            discFakeInputClassVariable = Variable(discFakeInputClass)
            genFakeInputClassVariable = Variable(genFakeInputClass)
            discFakeLabelVariable = Variable(discFakeLabel)
            discFakeInputGen = G(discFakeInputVariable, genFakeInputClassVariable)
            # detach: this step updates the discriminator only
            discFakeOutput = D(discFakeInputGen.detach(), discFakeInputClassVariable)
            lossFakeDisc = lossFunction(discFakeOutput, discFakeLabelVariable)
            lossFakeDisc.backward()
            disOptimiser.step()
            # log the loss for discriminator
            # NOTE(review): .data[0] targets the old Variable API; newer
            # PyTorch needs .item() for 0-dim tensors.
            discriminatorLosses.append((lossRealDisc + lossFakeDisc).data[0])

            # train generator based on discriminator
            G.zero_grad()
            genInputVariable = discFakeInputGen
            genOutputDisc = D(genInputVariable, discFakeInputClassVariable)
            lossGen = lossFunction(genOutputDisc, discRealLabelVariable)
            lossGen.backward()
            genOptimiser.step()
            # log the loss for generator
            generatorLosses.append(lossGen.data[0])

        # create an image for every epoch from the fixed noise
        genImage = G(fixedNoiseVariable, fixedNoiseClassVariable)
        genImage = genImage.data.cpu()
        genImage = torchvision.utils.make_grid(genImage, nrow=10)
        # assumes generator output lies in [-1, 1]; maps it to [0, 1] for imshow
        genImage = (genImage / 2) + 0.5
        genImage = genImage.permute(1, 2, 0).numpy()
        fig = plt.figure(figsize=(20, 10))
        plt.imshow(genImage)
        plt.axis('off')
        fig.text(.45, .05, 'Epoch: ' + str(epoch + 1))
        plt.savefig('y.png', bbox_inches='tight')
        imageList.append(imageio.imread('y.png'))

        epochEndTime = time.time()
        perEpochTime = epochEndTime - epochStartTime
        discLoss = torch.mean(torch.FloatTensor(discriminatorLosses))
        genLoss = torch.mean(torch.FloatTensor(generatorLosses))
        print('Epoch : [%d/%d] time: %.2f, loss_d: %.3f, loss_g: %.3f' %
              (epoch + 1, epochs, perEpochTime, discLoss, genLoss))

        if epoch == (epochs - 1):
            print('Completed processing ' + str(instances) + ' for ' +
                  str(epochs) + ' epochs.')
            plotFolderName = resultDir + 'plots/nonMMD' + '/' + folder + '/'
            checkAndCreateFolder(plotFolderName)
            plotFileName = plotFolderName + fileName + '_' + str(epochs) + '.png'
            # the final preview figure is still current, so save it directly
            plt.savefig(plotFileName, bbox_inches='tight')
            plt.close('all')
            # create gif animation
            animFolderName = resultDir + 'animation/nonMMD' + '/' + folder + '/'
            checkAndCreateFolder(animFolderName)
            animFileName = animFolderName + fileName + '_' + str(epochs) + '.gif'
            imageio.mimsave(animFileName, imageList, fps=5)
        else:
            # release this epoch's preview so figures do not pile up
            plt.close(fig)

        trainHist['discLoss'].append(discLoss)
        trainHist['genLoss'].append(genLoss)

    # save the model parameters in a file
    modelFolderName = resultDir + 'models/nonMMD' + '/' + folder + '/'
    checkAndCreateFolder(modelFolderName)
    modelFileName = modelFolderName + fileName + '_' + str(epochs) + '.pt'
    torch.save(G.state_dict(), modelFileName)

    showTrainHist(trainHist, fileName, epoch)
def plotAccuracyBar(dataSet, classifier, primaryInstances, helperInstances,
                    accuracyArray, showImage=1):
    '''
    Grouped bar chart of classifier accuracy, one group per primary-instance
    count, plus a .npy dump of the raw accuracies.

    accuracyArray : 2-D array; rows correspond to entries of primaryInstances,
        columns to the series [Original, GAN with 0 helper instances,
        GAN with helperInstances[0] helper instances, ...]
    showImage : when truthy the figure is displayed with plt.show() after
        it has been saved

    Writes plots/crossDataSetMMDall/accuracy/<dataSet>/<dataSet>_<classifier>.png
    and plots/crossDataSetMMDall/accuracyValues/<dataSet>/<dataSet>_<classifier>.npy.
    '''
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ind = np.arange(len(primaryInstances))
    width = 0.15
    delta = 0.02
    color = ['black', 'red', 'blue', 'green']
    histList = []
    legendList = []
    # one series per COLUMN; each series has one bar per row (i.e. per entry
    # of primaryInstances), so the loop must run over shape[1]
    for i in range(accuracyArray.shape[1]):
        hist = ax.bar(ind + i * (width + delta), list(accuracyArray[:, i]),
                      width, color=color[i % len(color)])
        histList.append(hist)
        if i == 0:
            legendList.append('Original')
        elif i == 1:
            legendList.append('GAN - 0 Helper Instances')
        else:
            legendList.append('GAN - ' + str(helperInstances[i - 2]) +
                              ' Helper Instances')
    print(accuracyArray)

    ax.set_xlim(-4 * width, len(ind) + width)
    ax.set_ylim(0, 1.5)
    ax.set_ylabel('Accuracy [out of 1]')
    ax.set_xlabel('Number of Primary Instances')
    ax.set_title(dataSet)
    ax.set_xticks(ind + width)
    xtickNames = ax.set_xticklabels(primaryInstances)
    plt.setp(xtickNames, rotation=45, fontsize=10)
    ax.legend(tuple(histList), tuple(legendList), loc='upper left')
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

    plotFolderName = 'plots/crossDataSetMMDall/accuracy' + '/' + dataSet + '/'
    checkAndCreateFolder(plotFolderName)
    plotFileName = plotFolderName + dataSet + '_' + classifier + '.png'
    plt.savefig(plotFileName, bbox_inches='tight')
    if showImage:
        plt.show()
    plt.close()

    accuracyFolderName = 'plots/crossDataSetMMDall/accuracyValues' + '/' + dataSet + '/'
    checkAndCreateFolder(accuracyFolderName)
    accuracyFileName = accuracyFolderName + dataSet + '_' + classifier + '.npy'
    with open(accuracyFileName, 'wb+') as fh:
        np_save(fh, accuracyArray, allow_pickle=False)
        # push Python's buffer to the OS first; sync(fh) is presumably an
        # os.fsync-style call (confirm) and only covers already-flushed data
        fh.flush()
        sync(fh)
def plotConfusionMatrix(dataSet, classifier, classes, trainSet, numOfInstances,
                        cm, normalize=False, numOfHelperInstances=-1,
                        title='Confusion matrix', cmap='Reds'):
    """
    Print a message about and plot the confusion matrix, then save it.

    trainSet: can be 'Real','Fake' or 'FakeMMD'
    cm : confusion matrix (NumPy array) with rows = true label and
        columns = predicted label, one row/column per entry of classes
    normalize : when True each row is divided by its sum; rows whose sum is
        0 (no true samples of that class) are left as 0 instead of NaN
    numOfHelperInstances : appended to the file name unless it is -1

    Saves plots/crossDataSetMMDall/cm/<dataSet>/<fileName>.png.
    """
    plt.figure()
    fileName = dataSet + '_' + trainSet + '_' + classifier + '_' + str(numOfInstances)
    if numOfHelperInstances != -1:
        fileName += '_' + str(numOfHelperInstances)
    plotFolderName = 'plots/crossDataSetMMDall/cm' + '/' + dataSet + '/'
    checkAndCreateFolder(plotFolderName)
    plotFileName = plotFolderName + fileName + '.png'

    if normalize:
        rowSums = cm.sum(axis=1)[:, np.newaxis]
        # 0/0 would give NaN, which also turns cm.max() (the text-colour
        # threshold below) into NaN; leave empty rows at 0 instead
        cm = np.divide(cm.astype('float'), rowSums,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=rowSums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # light text on dark cells, dark text on light cells
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        w = ' ' + str(format(cm[i, j], fmt)) + ' '
        plt.text(j, i, w,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig(plotFileName, bbox_inches='tight')
    # NOTE(review): `showImage` is a module-level global (not a parameter),
    # and the figure is shown only when it equals 0 -- the opposite of the
    # `showImage == 1` checks elsewhere. Confirm which is intended.
    if showImage == 0:
        plt.show()
    plt.close()
def train(primaryFileName, helperFileName, primaryTrainLoader,
          helperTrainLoader, primaryInstances, numOutputChannels=1,
          learningRate=0.0002, optimBetas=(0.5, 0.999), epochs=5):
    '''
    Training for a Deep Convolutional Generative Adversarial Network with
    MMD-weighted helper data.

    Phase 1 (10 epochs): for every (primary batch, helper batch) pair the
    discriminator's real-data loss on the helper batch is scaled by
    lambdaMMD * (1 - kernelTwoSampleTest(primary, helper)). The fake-data and
    generator steps are the usual DCGAN updates.
    Phase 2 (`epochs` epochs): plain DCGAN training on the primary class only.

    primaryFileName : '<folder>_<primaryClass>_<x>' (must split into 3 parts)
    helperFileName : '<x>_<helperClass>_<helperInstances>'
    primaryInstances : number of primary instances; when smaller than
        batchSize the phase-2 buffers are resized to it
    After the last phase-2 epoch the generator state_dict (.pt) and a 5x5
    sample grid (.png) are written under <resultDir>models/MMDall/<folder>/
    and <resultDir>plots/MMDall/<folder>/.

    Reads the module-level globals numInputChannels, numGenFilter,
    numDiscFilter, batchSize, imageSize, cuda, resultDir and showImage.
    '''
    # define the model
    G = Generator(numInputChannels, numGenFilter, numOutputChannels)
    D = Discriminator(numOutputChannels, numDiscFilter)
    lossFunction = nn.BCELoss()
    genOptimiser = optim.Adam(G.parameters(), lr=learningRate, betas=optimBetas)
    disOptimiser = optim.Adam(D.parameters(), lr=learningRate, betas=optimBetas)

    discRealInput = torch.FloatTensor(batchSize, numOutputChannels, imageSize, imageSize)
    discFakeInput = torch.FloatTensor(batchSize, numInputChannels, 1, 1)
    fixedNoise = torch.FloatTensor(25, numInputChannels, 1, 1)
    fixedNoise.normal_(0, 1)
    discRealLabel = torch.FloatTensor(batchSize)
    discFakeLabel = torch.FloatTensor(batchSize)
    discRealLabel.fill_(1)
    discFakeLabel.fill_(0)

    # weight given to the MMD term; constant, so it is built once
    lambdaMMD = torch.from_numpy(np.asarray([1.0])).float()

    # for processing on a GPU
    if cuda:
        G = G.cuda()
        D = D.cuda()
        lossFunction = lossFunction.cuda()
        discRealInput = discRealInput.cuda()
        discFakeInput = discFakeInput.cuda()
        discRealLabel = discRealLabel.cuda()
        discFakeLabel = discFakeLabel.cuda()
        fixedNoise = fixedNoise.cuda()
        lambdaMMD = lambdaMMD.cuda()

    fixedNoiseVariable = Variable(fixedNoise)
    lambdaMMDVariable = Variable(lambdaMMD)

    # assume that helper instances are always more than primary instances
    # define primary Epochs and helper Epochs
    primaryEpochs = epochs
    helperEpochs = 10

    print("Starting training with helper classes.")
    plt.figure()
    # training with helper class
    for epoch in range(helperEpochs):
        for i, primaryData in enumerate(primaryTrainLoader, 0):
            for j, helperData in enumerate(helperTrainLoader, 0):
                primaryDataInstance, primaryDataLabel = primaryData
                helperDataInstance, helperDataLabel = helperData

                # weight for this batch pair, computed on the CPU batches
                mmd = 1 - kernelTwoSampleTest(primaryDataInstance, helperDataInstance)
                mmd = torch.from_numpy(np.asarray([mmd])).float()
                # only move to the GPU when cuda is enabled (CPU runs
                # must not call .cuda())
                if cuda:
                    mmd = mmd.cuda()
                mmdVariable = Variable(mmd)

                D.zero_grad()
                # train GAN using helper data instance
                if cuda:
                    helperDataInstance = helperDataInstance.cuda()
                discRealInput.copy_(helperDataInstance)
                discRealInputVariable = Variable(discRealInput)
                # should we treat this as 1 ??
                discRealLabelVariable = Variable(discRealLabel)
                discRealOutput = D(discRealInputVariable)
                lossRealDisc = lambdaMMDVariable * mmdVariable * lossFunction(
                    discRealOutput, discRealLabelVariable)
                lossRealDisc.backward()

                # train discriminator on fake data
                discFakeInput.normal_(0, 1)
                discFakeInputVariable = Variable(discFakeInput)
                discFakeInputGen = G(discFakeInputVariable)
                discFakeLabelVariable = Variable(discFakeLabel)
                # detach: this step updates the discriminator only
                discFakeOutput = D(discFakeInputGen.detach())
                lossFakeDisc = lossFunction(discFakeOutput, discFakeLabelVariable)
                lossFakeDisc.backward()
                disOptimiser.step()

                # train generator based on discriminator: the generator is
                # rewarded when its samples are labelled real
                G.zero_grad()
                genInputVariable = discFakeInputGen
                genOutputDisc = D(genInputVariable)
                lossGen = lossFunction(genOutputDisc, discRealLabelVariable)
                lossGen.backward()
                genOptimiser.step()

                if (i == 0) and epoch == (helperEpochs - 1):
                    # preview samples from the generator after helper training
                    # NOTE(review): runs for every helper batch j while i == 0,
                    # creating one unsaved figure per helper batch; they are
                    # only released by plt.close('all') at the end of phase 2.
                    genImage = G(fixedNoiseVariable)
                    genImage = genImage.data.cpu()
                    genImage = torchvision.utils.make_grid(genImage, nrow=5)
                    genImage = genImage / 2 + 0.5
                    genImage = genImage.permute(1, 2, 0).numpy()
                    plt.figure()
                    plt.imshow(genImage, cmap='gray')
                    plt.axis('off')

    if primaryInstances < batchSize:
        # shrink buffers so a single batch holds every primary instance
        discRealInput = torch.FloatTensor(primaryInstances, numOutputChannels,
                                          imageSize, imageSize)
        discFakeInput = torch.FloatTensor(primaryInstances, numInputChannels, 1, 1)
        discRealLabel = torch.FloatTensor(primaryInstances)
        discFakeLabel = torch.FloatTensor(primaryInstances)
        discRealLabel.fill_(1)
        discFakeLabel.fill_(0)
        if cuda:
            discRealInput = discRealInput.cuda()
            discFakeInput = discFakeInput.cuda()
            discRealLabel = discRealLabel.cuda()
            discFakeLabel = discFakeLabel.cuda()

    print("Ending training with helper classes.")
    print("Starting training with primary class.")

    # training with primary class
    for epoch in range(primaryEpochs):
        for i, data in enumerate(primaryTrainLoader, 0):
            # per-epoch cap: stop once batch indices exceed 10000
            if i > 10000:
                print("Stopping epoch after 10001 iterations")
                break

            # train discriminator on real data
            D.zero_grad()
            dataInstance, dataLabel = data
            if cuda:
                dataInstance = dataInstance.cuda()
            discRealInput.copy_(dataInstance)
            discRealInputVariable = Variable(discRealInput)
            discRealLabelVariable = Variable(discRealLabel)
            discRealOutput = D(discRealInputVariable)
            lossRealDisc = lossFunction(discRealOutput, discRealLabelVariable)
            lossRealDisc.backward()

            # train discriminator on fake data
            discFakeInput.normal_(0, 1)
            discFakeInputVariable = Variable(discFakeInput)
            discFakeInputGen = G(discFakeInputVariable)
            discFakeLabelVariable = Variable(discFakeLabel)
            discFakeOutput = D(discFakeInputGen.detach())
            lossFakeDisc = lossFunction(discFakeOutput, discFakeLabelVariable)
            lossFakeDisc.backward()
            disOptimiser.step()

            # train generator based on discriminator
            G.zero_grad()
            genInputVariable = discFakeInputGen
            genOutputDisc = D(genInputVariable)
            lossGen = lossFunction(genOutputDisc, discRealLabelVariable)
            lossGen.backward()
            genOptimiser.step()

            if (i == 0) and epoch == (primaryEpochs - 1):
                folder, primaryClass, _ = primaryFileName.split('_')
                _, helperClass, helperInstances = helperFileName.split('_')
                fileName = folder + '_' + str(primaryClass) + '_' + str(helperClass) + '_' + \
                    str(primaryInstances) + '_' + str(helperInstances)
                modelFolder = resultDir + 'models/MMDall' + '/' + folder + '/'
                plotFolder = resultDir + 'plots/MMDall' + '/' + folder + '/'
                checkAndCreateFolder(modelFolder)
                checkAndCreateFolder(plotFolder)
                modelFileName = modelFolder + fileName + '_' + str(epoch) + '.pt'
                plotFileName = plotFolder + fileName + '_' + str(epoch) + '.png'

                # save the model parameters in a file
                torch.save(G.state_dict(), modelFileName)

                # generate samples from trained generator
                genImage = G(fixedNoiseVariable)
                genImage = genImage.data.cpu()
                genImage = torchvision.utils.make_grid(genImage, nrow=5)
                genImage = genImage / 2 + 0.5
                genImage = genImage.permute(1, 2, 0).numpy()

                # plot the figure of generated samples and save
                fig = plt.figure()
                plt.imshow(genImage, cmap='gray')
                plt.axis('off')
                fig.text(.45, .05, 'Epoch: ' + str(epoch))
                # save BEFORE showing: plt.show() may release the figure, after
                # which savefig would write an empty image
                plt.savefig(plotFileName, bbox_inches='tight')
                if showImage == 1:
                    plt.show()
                plt.close('all')

    print("Done training with primary class instances.")
def MMDhist(primaryDomain, helperDomain, primaryClass, helperClassList,
            primaryInstance, helperInstance, batchSize):
    '''
    Bar chart of the average MMD between one class of primaryDomain and each
    class in helperClassList taken from helperDomain.

    For every helper class, kernelTwoSampleTest is evaluated on every
    (primary batch, helper batch) pair (incomplete batches are dropped) and
    the values are averaged. Negative averages are displayed as 0. The chart
    is saved under <resultDir>mmdValues/<primaryDomain>/ and shown.

    NOTE(review): bars are placed at positions 0..len(getClasses(helperDomain))-1,
    so helperClassList presumably has to list every helper class in order.
    '''
    # load the datasets for which the bar graph needs to be plotted
    x = loadDataset(primaryDomain, [primaryClass], [primaryInstance], mode='train')
    primaryTrainLoader = torch.utils.data.DataLoader(
        x, batch_size=batchSize, shuffle=True, num_workers=4, drop_last=True)

    avgMMDValues = []
    for helperClass in helperClassList:
        y = loadDataset(helperDomain, [helperClass], [helperInstance], mode='train')
        helperTrainLoader = torch.utils.data.DataLoader(
            y, batch_size=batchSize, shuffle=True, num_workers=4, drop_last=True)
        mmdValues = []
        for primaryData in primaryTrainLoader:
            primaryDataInstance, primaryDataLabel = primaryData
            for helperData in helperTrainLoader:
                helperDataInstance, helperDataLabel = helperData
                mmdValues.append(
                    kernelTwoSampleTest(primaryDataInstance, helperDataInstance))
        # NaN (with a RuntimeWarning) if the loaders produced no batch pairs
        avgMMDValues.append(np.mean(np.asarray(mmdValues)))

    # plot the average MMD Values
    plt.figure()
    ax = plt.subplot(111)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    plt.ylabel('Avg. MMD Value')
    plt.xlabel('Class')

    avgMMDValues = np.asarray(avgMMDValues)
    avgMMDValues[avgMMDValues < 0] = 0

    # plot adjustments
    yAbove = np.max(avgMMDValues) + (np.max(avgMMDValues) / 3)
    yStep = np.max(avgMMDValues) / 10
    # an all-zero (or NaN) maximum gives an unusable step for which
    # np.arange raises; fall back to matplotlib's default ticks
    if np.isfinite(yStep) and yStep > 0:
        plt.yticks(np.arange(0.0, yAbove, yStep))

    classNames = getClasses(primaryDomain)
    plt.title(primaryDomain + ' - ' + str(classNames[primaryClass]))

    # x-axis adjustments
    xReal = getClasses(helperDomain)
    ind = np.arange(0, len(xReal))
    ax.set_xticks(ind)
    ax.set_xticklabels(xReal)
    plt.xticks(rotation=45)
    plt.bar(ind, avgMMDValues, 0.50)

    plotFolderName = resultDir + 'mmdValues' + '/' + primaryDomain + '/'
    checkAndCreateFolder(plotFolderName)
    plotFileName = plotFolderName + primaryDomain + '_' + str(primaryClass) + '_' + \
        str(helperDomain) + '_' + str(batchSize) + '.png'
    plt.savefig(plotFileName, bbox_inches='tight')
    plt.show()
def train(self):
    '''
    Train the cross-domain GAN for self.epochs + 1 steps (one batch from each
    domain per step).

    g12 maps domainA -> domainB and g21 maps domainB -> domainA; d1 scores
    domainA images and d2 scores domainB images. With self.use_labels the
    discriminators use CrossEntropyLoss: real images keep their class label,
    and fakes get the extra label self.num_classes. Otherwise a least-squares
    loss is used. With self.use_reconst_loss the generator loss also includes
    the mean squared cycle-reconstruction error.

    Every self.log_step steps the losses are printed and all four
    state_dicts are saved under self.model_path/self.name. Every
    self.sample_step steps, translations of one fixed batch from each domain
    are written under self.sample_path/self.name.
    '''
    domainB_iter = iter(self.domainB_loader)
    domainA_iter = iter(self.domainA_loader)
    iter_per_epoch = min(len(domainB_iter), len(domainA_iter))

    # fixed domainA and domainB batches used for the periodic samples.
    # The builtin next() is used because the Python-2 style .next() method
    # is not available on DataLoader iterators in recent PyTorch releases.
    fixed_domainB = to_var(next(domainB_iter)[0])
    fixed_domainA = to_var(next(domainA_iter)[0])

    # loss if use_labels = True
    criterion = nn.CrossEntropyLoss()

    for step in range(self.epochs + 1):
        # restart both iterators once per pass over the shorter loader
        if (step + 1) % iter_per_epoch == 0:
            domainA_iter = iter(self.domainA_loader)
            domainB_iter = iter(self.domainB_loader)

        # load domainB and domainA dataset
        domainB, s_labels = next(domainB_iter)
        domainB, s_labels = to_var(domainB), to_var(s_labels).long().squeeze()
        domainA, m_labels = next(domainA_iter)
        domainA, m_labels = to_var(domainA), to_var(m_labels)

        if self.use_labels:
            # fakes get the extra class index self.num_classes
            domainA_fake_labels = to_var(
                torch.Tensor([self.num_classes] * domainB.size(0)).long())
            domainB_fake_labels = to_var(
                torch.Tensor([self.num_classes] * domainA.size(0)).long())

        # ============ train D ============ #

        # train with real images
        self.reset_grad()
        out = self.d1(domainA)
        if self.use_labels:
            d1_loss = criterion(out, m_labels)
        else:
            d1_loss = torch.mean((out - 1) ** 2)

        out = self.d2(domainB)
        if self.use_labels:
            d2_loss = criterion(out, s_labels)
        else:
            d2_loss = torch.mean((out - 1) ** 2)

        d_domainA_loss = d1_loss
        d_domainB_loss = d2_loss
        d_real_loss = d1_loss + d2_loss
        d_real_loss.backward()
        self.d_optimizer.step()

        # train with fake images
        self.reset_grad()
        fake_domainB = self.g12(domainA)
        out = self.d2(fake_domainB)
        if self.use_labels:
            d2_loss = criterion(out, domainB_fake_labels)
        else:
            d2_loss = torch.mean(out ** 2)

        fake_domainA = self.g21(domainB)
        out = self.d1(fake_domainA)
        if self.use_labels:
            d1_loss = criterion(out, domainA_fake_labels)
        else:
            d1_loss = torch.mean(out ** 2)

        d_fake_loss = d1_loss + d2_loss
        d_fake_loss.backward()
        self.d_optimizer.step()

        # ============ train G ============ #

        # train domainA-domainB-domainA cycle
        self.reset_grad()
        fake_domainB = self.g12(domainA)
        out = self.d2(fake_domainB)
        reconst_domainA = self.g21(fake_domainB)
        if self.use_labels:
            g_loss = criterion(out, m_labels)
        else:
            g_loss = torch.mean((out - 1) ** 2)

        if self.use_reconst_loss:
            g_loss += torch.mean((domainA - reconst_domainA) ** 2)

        g_loss.backward()
        self.g_optimizer.step()

        # train domainB-domainA-domainB cycle
        self.reset_grad()
        fake_domainA = self.g21(domainB)
        out = self.d1(fake_domainA)
        reconst_domainB = self.g12(fake_domainA)
        if self.use_labels:
            g_loss = criterion(out, s_labels)
        else:
            g_loss = torch.mean((out - 1) ** 2)

        if self.use_reconst_loss:
            g_loss += torch.mean((domainB - reconst_domainB) ** 2)

        g_loss.backward()
        self.g_optimizer.step()

        # print the log info
        if (step + 1) % self.log_step == 0:
            # NOTE(review): .data[0] targets the old Variable API; newer
            # PyTorch needs .item() for 0-dim tensors.
            print('Step [%d/%d], d_real_loss: %.4f, d_domainA_loss: %.4f, d_domainB_loss: %.4f, '
                  'd_fake_loss: %.4f, g_loss: %.4f'
                  % (step + 1, self.epochs, d_real_loss.data[0], d_domainA_loss.data[0],
                     d_domainB_loss.data[0], d_fake_loss.data[0], g_loss.data[0]))

        # save the sampled images
        if (step + 1) % self.sample_step == 0:
            fake_domainB = self.g12(fixed_domainA)
            fake_domainA = self.g21(fixed_domainB)

            domainA, fake_domainA = to_data(fixed_domainA), to_data(fake_domainA)
            domainB, fake_domainB = to_data(fixed_domainB), to_data(fake_domainB)

            merged = self.merge_images(domainA, fake_domainB)
            path = os.path.join(self.sample_path, self.name)
            checkAndCreateFolder(path)
            path = os.path.join(self.sample_path, self.name,
                                'sample-%d-m-s.png' % (step + 1))
            scipy.misc.imsave(path, merged)
            print('saved %s' % path)

            merged = self.merge_images(domainB, fake_domainA)
            path = os.path.join(self.sample_path, self.name,
                                'sample-%d-s-m.png' % (step + 1))
            scipy.misc.imsave(path, merged)
            print('saved %s' % path)

        if (step + 1) % self.log_step == 0:
            # save the model parameters
            g12_path = os.path.join(self.model_path, self.name)
            checkAndCreateFolder(g12_path)
            g12_path = os.path.join(self.model_path, self.name, 'g12-%d.pkl' % (step + 1))
            g21_path = os.path.join(self.model_path, self.name, 'g21-%d.pkl' % (step + 1))
            d1_path = os.path.join(self.model_path, self.name, 'd1-%d.pkl' % (step + 1))
            d2_path = os.path.join(self.model_path, self.name, 'd2-%d.pkl' % (step + 1))
            torch.save(self.g12.state_dict(), g12_path)
            torch.save(self.g21.state_dict(), g21_path)
            torch.save(self.d1.state_dict(), d1_path)
            torch.save(self.d2.state_dict(), d2_path)
def test(primaryDataSet, helperDataSet, primaryClass, helperClass,
         primaryInstances, helperInstances):
    '''
    Generate samples from a trained cross-dataset MMD generator and save them.

    Loads the generator state_dict from
    <resultDir>models/crossDataSetMMDall/<primaryDataSet>/<primaryDataSet>_
    <helperDataSet>_<primaryClass>_<helperClass>_<primaryInstances>_
    <helperInstances>_<getEpochs(...)-1>.pt, then draws
    numOfSamples // batchSize batches of noise. It saves a 5x5 preview of the
    last batch under <resultDir>results/crossDataSetMMDall/samples/<primaryDataSet>/
    and every generated image (np.squeeze'd) as a .npy file under
    <resultDir>results/crossDataSetMMDall/compressed/<primaryDataSet>/.

    NOTE: the helperClass argument is ignored and replaced by primaryClass.
    Raises ValueError if numOfSamples // batchSize is 0 (nothing to generate).
    Reads the module-level globals resultDir, numInputChannels, numGenFilter,
    numOfSamples, batchSize and cuda.
    '''
    helperClass = primaryClass
    numEpochs = getEpochs(primaryDataSet, primaryInstances)
    modelFolder = resultDir + 'models/crossDataSetMMDall' + '/' + primaryDataSet + '/'
    print(primaryDataSet, helperDataSet, primaryClass, helperClass,
          primaryInstances, helperInstances, numEpochs - 1)
    modelFile = modelFolder + primaryDataSet + '_' + helperDataSet + '_' + \
        str(primaryClass) + '_' + str(helperClass) + '_' + \
        str(primaryInstances) + '_' + str(helperInstances) + '_' + \
        str(numEpochs - 1) + '.pt'
    print('Generating examples for Dataset: ' + primaryDataSet +
          ' Primary Class: ' + str(primaryClass) +
          ' Helper Class: ' + str(helperClass) +
          ' Primary Instances: ' + str(primaryInstances) +
          ' Helper Instances: ' + str(helperInstances) +
          ' Epochs: ' + str(numEpochs))

    # integer division: range() needs an int, and true division yields a
    # float on Python 3
    iterations = numOfSamples // batchSize
    if iterations < 1:
        raise ValueError('numOfSamples ({}) must be at least batchSize ({})'
                         .format(numOfSamples, batchSize))

    numOutputChannels = getChannels(primaryDataSet)

    # load the model learnt during training
    G = Generator(numInputChannels, numGenFilter, numOutputChannels)
    G.load_state_dict(torch.load(modelFile))
    if cuda:
        G = G.cuda()

    generatedBatches = []
    for iteration in range(iterations):
        noise = torch.FloatTensor(batchSize, numInputChannels, 1, 1)
        noise.normal_(0, 1)
        if cuda:
            noise = noise.cuda()
        noiseVariable = Variable(noise)
        genImage = G(noiseVariable).data.cpu().numpy()
        generatedBatches.append(genImage)
    # one concatenation instead of re-copying the growing array every batch
    genImageConcat = np.concatenate(generatedBatches, axis=0)

    # preview the first 25 images of the last batch; normalize=True maps
    # pixels into [0, 1] as required by imshow
    preview = torchvision.utils.make_grid(torch.from_numpy(genImage[:25]),
                                          nrow=5, normalize=True)
    preview = preview.permute(1, 2, 0).numpy()
    plt.imshow(preview, cmap='gray')
    plotFolder = resultDir + 'results' + '/' + 'crossDataSetMMDall/samples' + '/' + primaryDataSet + '/'
    checkAndCreateFolder(plotFolder)
    plotFile = primaryDataSet + '_' + helperDataSet + '_' + \
        str(primaryClass) + '_' + 'all' + '_' + \
        str(primaryInstances) + '_' + str(helperInstances)
    plotPath = plotFolder + plotFile
    plt.axis('off')
    plt.savefig(plotPath, bbox_inches='tight')
    plt.show()

    resultFolder = resultDir + 'results/crossDataSetMMDall' + '/' + 'compressed' + '/' + primaryDataSet + '/'
    checkAndCreateFolder(resultFolder)
    resultFile = primaryDataSet + '_' + helperDataSet + '_' + \
        str(primaryClass) + '_' + 'all' + '_' + \
        str(primaryInstances) + '_' + str(helperInstances) + '.npy'
    resultPath = resultFolder + resultFile
    with open(resultPath, 'wb+') as fh:
        genImageConcat = np.squeeze(genImageConcat)
        np_save(fh, genImageConcat, allow_pickle=False)
        # push Python's buffer to the OS first; sync(fh) is presumably an
        # os.fsync-style call (confirm) and only covers already-flushed data
        fh.flush()
        sync(fh)
def MMDhist(dataSet, primaryClass, primaryInstance, helperInstances, batchSize):
    '''
    Histogram of MMD^2 between batches of one class and batches of all other
    classes of the same dataset.

    Every batch of primaryClass is compared with every batch drawn from the
    remaining classes via kernelTwoSampleTest (incomplete batches dropped).
    The batch pairs with the largest and smallest value are rendered with
    displayImage (modes 'max' / 'min'). A histogram of all values is saved
    under <resultDir>mmdValues/hist/<dataSet>/ and shown.

    primaryInstance : instances loaded for the primary class and for each
        of the other classes
    helperInstances : only used in displayImage, the title and the file name
    batchSize : DataLoader batch size
    '''
    classes = getClasses(dataSet)

    # load the datasets for which the histogram needs to be plotted
    x = loadDataset(dataSet, [primaryClass], [primaryInstance])
    # every class except primaryClass acts as a helper, with one instance
    # count per helper class
    helperClasses = [c for c in range(len(classes)) if c != primaryClass]
    y = loadDataset(dataSet, helperClasses, [primaryInstance] * len(helperClasses))

    primaryTrainLoader = torch.utils.data.DataLoader(
        x, batch_size=batchSize, shuffle=True, num_workers=4, drop_last=True)
    helperTrainLoader = torch.utils.data.DataLoader(
        y, batch_size=batchSize, shuffle=True, num_workers=4, drop_last=True)

    mmdValues = []
    # +/-inf sentinels so the first pair always becomes both min and max
    minMMDValue = float('inf')
    maxMMDValue = float('-inf')
    # zero placeholders, used only if the loaders yield no batch pairs
    minPrimaryDataInstance = torch.FloatTensor(
        batchSize, getChannels(dataSet), getImageSize(dataSet),
        getImageSize(dataSet)).zero_()
    minHelperDataInstance = minPrimaryDataInstance
    maxPrimaryDataInstance = minPrimaryDataInstance
    maxHelperDataInstance = minPrimaryDataInstance

    for primaryData in primaryTrainLoader:
        primaryDataInstance, primaryDataLabel = primaryData
        for helperData in helperTrainLoader:
            helperDataInstance, helperDataLabel = helperData
            mmdValue = kernelTwoSampleTest(primaryDataInstance, helperDataInstance)
            # choosing the pair with minimum MMD
            if mmdValue < minMMDValue:
                minMMDValue = mmdValue
                minPrimaryDataInstance = primaryDataInstance
                minHelperDataInstance = helperDataInstance
            # choosing the pair with maximum MMD
            if mmdValue > maxMMDValue:
                maxMMDValue = mmdValue
                maxPrimaryDataInstance = primaryDataInstance
                maxHelperDataInstance = helperDataInstance
            mmdValues.append(mmdValue)

    displayImage(maxPrimaryDataInstance, maxHelperDataInstance, maxMMDValue,
                 dataSet, primaryClass, primaryInstance, helperInstances, 'max')
    displayImage(minPrimaryDataInstance, minHelperDataInstance, minMMDValue,
                 dataSet, primaryClass, primaryInstance, helperInstances, 'min')

    mmdValues = np.asarray(mmdValues)
    plt.figure()
    plt.hist(mmdValues, ec='k')
    plt.xlabel('$MMD^2$ between batch of Primary Class and Helper Class')
    plt.ylabel('Number of Batch Pairs have $MMD^2$ value in that range')
    plt.title(
        ' Dataset: {} \n Primary Class: {} \n Primary Instances:{} \n Helper Instances:{} \n Batch Size:{}'
        .format(dataSet, classes[primaryClass], primaryInstance,
                helperInstances, batchSize))

    saveFolder = resultDir + 'mmdValues' + '/' + 'hist' + '/' + dataSet + '/'
    checkAndCreateFolder(saveFolder)
    saveFile = saveFolder + dataSet + '_' + str(primaryClass) + '_' + str(
        primaryInstance) + '_' + str(helperInstances) + '_' + str(
        batchSize) + '.png'
    plt.savefig(saveFile, bbox_inches='tight')
    plt.show()