Esempio n. 1
0
    def train(self):
        """Generate the image dataset, train the CNN and derive ROC thresholds.

        Steps, in order:
          1. (Re)create two temporary directories, one for image data and
             one for the saved model.
          2. Pick a hop per class with ``self.updateHop`` and generate the
             images with ``self.genImgDataset``.
          3. Augment (width shift) every class that is more than one batch
             short of ``self.LearningDict['t']`` images.
          4. Train the network on a shuffled train/validation split.
          5. Reload the weight file with the highest epoch number and
             accumulate TP/FP/TN/FN counts per threshold over the
             validation set using ``self.testCT``.
          6. Save a grid of sorted validation probabilities as
             ``validation-plots.png`` and fill ``self.ROCdata``.
          7. Select the upper threshold per call type (first threshold with
             FPR == 0) and, when ``self.autoThr`` is set, the lower threshold
             (ROC point closest to FPR=0, TPR=1).

        Attributes written: ``tmpdir1``, ``tmpdir2``, ``DataGen``,
        ``bestweight``, ``Thrs``, ``TPRs``, ``FPRs``, ``Precisions``,
        ``Accs``, ``ROCdata`` (keys ``TPR``/``FPR``/``thr``), ``bestThr``,
        ``bestThrInd`` and, when ``autoThr`` is set, ``thr_min_ind``.

        Returns:
            True once every step has completed.

        Raises:
            ValueError: if no weight file is found in the model directory,
                or if the validation split contains no samples.
        """
        import re  # local import: only needed to parse checkpoint file names

        # One class per call type plus a trailing "Noise" class.
        nclasses = len(self.calltypes) + 1

        def count_per_class(labelvec):
            """Return the number of entries of labelvec equal to each class id."""
            return [len(np.where(labelvec == k)[0]) for k in range(nclasses)]

        def safe_ratio(num, den):
            """Return num / den, or 0.0 when the denominator is zero."""
            return num / den if den != 0 else 0.0

        # Create temp dirs to hold img data and model. Directories left over
        # from a previous run are removed first; each one is handled on its
        # own so a failure on the first does not skip the second. Cleanup is
        # best-effort: a failure is reported but never aborts training.
        for attr in ('tmpdir1', 'tmpdir2'):
            olddir = getattr(self, attr, None)
            if olddir:
                try:
                    olddir.cleanup()
                except Exception as err:
                    print('Warning: could not clean up previous %s: %s'
                          % (attr, err))
        self.tmpdir1 = tempfile.TemporaryDirectory(prefix='CNN_')
        print('Temporary img dir:', self.tmpdir1.name)
        self.tmpdir2 = tempfile.TemporaryDirectory(prefix='CNN_')
        print('Temporary model dir:', self.tmpdir2.name)

        # Find train segments belong to each class
        self.DataGen = CNN.GenerateData(self.currfilt, self.imgWidth,
                                        self.windowWidth, self.windowInc,
                                        self.imgsize[0], self.imgsize[1],
                                        self.f1, self.f2)

        # Find how many images with default hop (=imgWidth), adjust hop to
        # make a good number of images also keep space for some in-built
        # augmenting (width-shift)
        hop = [self.imgWidth for _ in range(nclasses)]
        imgN = self.DataGen.getImgCount(dirName=self.tmpdir1.name,
                                        dataset=self.traindata,
                                        hop=hop)
        print('Expected number of images when no overlap: ', imgN)
        print('Updating hop...')
        hop = self.updateHop(imgN, hop)
        imgN = self.DataGen.getImgCount(dirName=self.tmpdir1.name,
                                        dataset=self.traindata,
                                        hop=hop)
        print('Expected number of images with updated hop: ', imgN)

        print('Generating CNN images...')
        self.genImgDataset(hop)
        # NOTE(review): self.Nimg is presumably filled by genImgDataset with
        # one count per class, Noise last - confirm.
        print('\nGenerated images:\n')
        for name, count in zip(self.calltypes, self.Nimg):
            print("\t%s:\t%d\n" % (name, count))
        print("\t%s:\t%d\n" % ("Noise", self.Nimg[-1]))

        # CNN training
        cnn = CNN.CNN(self.configdir, self.species, self.calltypes, self.fs,
                      self.imgWidth, self.windowWidth, self.windowInc,
                      self.imgsize[0], self.imgsize[1])

        # 1. Data augmentation
        print('Data augmenting...')
        _, labels = cnn.getImglist(self.tmpdir1.name)
        ns = count_per_class(np.argmax(labels, axis=1))
        # create image data augmentation generator in-build
        datagen = ImageDataGenerator(width_shift_range=0.3,
                                     fill_mode='nearest')
        batchsize = self.LearningDict['batchsize']
        # Data augmentation for each call type
        for ct in range(nclasses):
            shortfall = self.LearningDict['t'] - ns[ct]
            if shortfall <= batchsize:
                continue
            ctdir = os.path.join(self.tmpdir1.name, str(ct))
            # load this ct images
            samples = cnn.loadCTImg(ctdir)
            # prepare iterator
            it = datagen.flow(samples, batch_size=batchsize)
            # generate samples: one initial batch plus one per missing batch.
            # The batches are collected and stacked once instead of
            # re-copying the growing array on every iteration.
            batches = [next(it) for _ in range(int(shortfall / batchsize) + 1)]
            batch = np.vstack(batches)
            # Save augmented data
            for k, sgRaw in enumerate(batch):
                np.save(os.path.join(ctdir, '%d_aug%06d.npy' % (ct, k)),
                        sgRaw)
            # Release the large arrays before the next class is loaded.
            del batch, batches, samples
            gc.collect()

        # 2. TRAIN - use custom image generator
        filenamesall, labelsall = cnn.getImglist(self.tmpdir1.name)
        print('Final CNN images...')
        ns = count_per_class(np.argmax(labelsall, axis=1))
        for name, count in zip(self.calltypes, ns):
            print("\t%s:\t%d\n" % (name, count))
        print("\t%s:\t%d\n" % ("Noise", ns[-1]))

        filenamesall, labelsall = shuffle(filenamesall, labelsall)

        X_train_filenames, X_val_filenames, y_train, y_val = train_test_split(
            filenamesall,
            labelsall,
            test_size=self.LearningDict['test_size'],
            random_state=1)
        training_batch_generator = CNN.CustomGenerator(
            X_train_filenames, y_train, batchsize,
            self.tmpdir1.name, cnn.imageheight, cnn.imagewidth, 1)
        validation_batch_generator = CNN.CustomGenerator(
            X_val_filenames, y_val, batchsize,
            self.tmpdir1.name, cnn.imageheight, cnn.imagewidth, 1)

        print('Creating CNN architecture...')
        cnn.createArchitecture()

        print('Training...')
        cnn.train(modelsavepath=self.tmpdir2.name,
                  training_batch_generator=training_batch_generator,
                  validation_batch_generator=validation_batch_generator)
        print('Training complete!')

        self.bestThr = [[0, 0] for _ in self.calltypes]
        self.bestThrInd = [0 for _ in self.calltypes]

        # 3. Prepare ROC plots
        print('Generating ROC statistics...')
        # Find the weight file with the highest epoch number. The epoch is
        # parsed as the full run of digits after "weights." so that one- and
        # three-digit epochs are read correctly, and the path keeps the
        # directory the file was actually found in.
        epochpattern = re.compile(r'weights\.(\d+)')
        candidates = []
        for root, _, files in os.walk(self.tmpdir2.name):
            for fname in files:
                if not (fname.endswith('.h5') and 'weights' in fname):
                    continue
                found = epochpattern.search(fname)
                if found:
                    candidates.append((int(found.group(1)),
                                       os.path.join(root, fname)))
        if not candidates:
            raise ValueError('No weight file (weights.<epoch>*.h5) found in '
                             + self.tmpdir2.name)
        self.bestweight = max(candidates, key=lambda cand: cand[0])[1]

        # Load the model and prepare
        modeljson = os.path.join(self.tmpdir2.name, 'model.json')
        with open(modeljson, 'r') as jsonfile:
            loadedmodeljson = jsonfile.read()
        model = model_from_json(loadedmodeljson)
        # Load weights into new model
        model.load_weights(self.bestweight)
        # Compile the model
        model.compile(loss=self.LearningDict['loss'],
                      optimizer=self.LearningDict['optimizer'],
                      metrics=self.LearningDict['metrics'])
        print('Loaded CNN model from ', self.tmpdir2.name)

        # Per-class confusion counts; each entry becomes a list with one
        # value per threshold once the first batch has been processed.
        TPs = [None] * nclasses
        FPs = [None] * nclasses
        TNs = [None] * nclasses
        FNs = [None] * nclasses
        # CTps[ct][j] collects probabilities returned by testCT for class ct.
        CTps = [[[] for _ in range(nclasses)] for _ in range(nclasses)]
        # Do all the plots based on Validation set (eliminate augmented?)
        N = len(X_val_filenames)
        y_val = np.argmax(y_val, axis=1)
        print('Validation data: ', N)
        if os.path.isdir(self.tmpdir2.name):
            print('Model directory exists')
        else:
            print('Model directory DOES NOT exist')
        if os.path.isdir(self.tmpdir1.name):
            print('Img directory exists')
        else:
            print('Img directory DOES NOT exist')

        rocbatch = self.LearningDict['batchsize_ROC']
        thrs = None
        for start in range(0, N, rocbatch):
            stop = min(start + rocbatch, N)
            imagesb = cnn.loadImgBatch(X_val_filenames[start:stop])
            labelsb = y_val[start:stop]
            for ct in range(nclasses):
                # res=[thrlist, TPs, FPs, TNs, FNs],
                # ctp=[[0to0 probs], [0to1 probs], [0to2 probs]]
                res, ctp = self.testCT(ct, imagesb, labelsb, model)
                thrs = res[0]
                for j in range(nclasses):
                    CTps[ct][j] += ctp[j]
                if TPs[ct] is None:
                    TPs[ct] = res[1]
                    FPs[ct] = res[2]
                    TNs[ct] = res[3]
                    FNs[ct] = res[4]
                else:
                    TPs[ct] = [a + b for a, b in zip(TPs[ct], res[1])]
                    FPs[ct] = [a + b for a, b in zip(FPs[ct], res[2])]
                    TNs[ct] = [a + b for a, b in zip(TNs[ct], res[3])]
                    FNs[ct] = [a + b for a, b in zip(FNs[ct], res[4])]
        if thrs is None:
            raise ValueError('Validation set is empty; '
                             'cannot compute ROC statistics')
        self.Thrs = thrs
        print('Thrs: ', self.Thrs)
        print('validation TPs[0]: ', TPs[0])

        # Rates per class and threshold. A zero denominator (e.g. a class
        # with no positive or no negative validation samples) yields 0.0,
        # the same convention already used for precision.
        nthr = len(self.Thrs)
        self.TPRs = []
        self.FPRs = []
        self.Precisions = []
        self.Accs = []
        for ct in range(nclasses):
            tp, fp, tn, fn = TPs[ct], FPs[ct], TNs[ct], FNs[ct]
            self.TPRs.append(
                [safe_ratio(tp[k], tp[k] + fn[k]) for k in range(nthr)])
            self.FPRs.append(
                [safe_ratio(fp[k], tn[k] + fp[k]) for k in range(nthr)])
            self.Precisions.append(
                [safe_ratio(tp[k], tp[k] + fp[k]) for k in range(nthr)])
            self.Accs.append(
                [safe_ratio(tp[k] + tn[k], tp[k] + tn[k] + fp[k] + fn[k])
                 for k in range(nthr)])

        plt.style.use('ggplot')
        # squeeze=False keeps axs two-dimensional even for a 1x1 grid.
        fig, axs = plt.subplots(nclasses,
                                nclasses,
                                sharey=True,
                                sharex='col',
                                squeeze=False)
        noise = len(self.calltypes)  # index of the Noise class
        for ct in range(nclasses):
            # Temp plot is saved in train data directory - prediction
            # probabilities for instances of current ct
            for i in range(nclasses):
                CTps[ct][i] = sorted(CTps[ct][i], key=float)
                axs[i, ct].plot(CTps[ct][i], 'k')
                axs[i, ct].plot(CTps[ct][i], 'bo')
                # Row/column captions are set once, from the diagonal cell.
                if ct == i == noise:
                    axs[i, 0].set_ylabel('Noise')
                    axs[0, ct].set_title('Noise')
                elif ct == i:
                    axs[i, 0].set_ylabel(str(self.calltypes[ct]))
                    axs[0, ct].set_title(str(self.calltypes[ct]))
                if i == noise:
                    axs[i, ct].set_xlabel('Number of samples')
        fig.suptitle('Human')
        plotdir = self.folderTrain1 if self.folderTrain1 else self.folderTrain2
        plotpath = os.path.join(plotdir, 'validation-plots.png')
        fig.savefig(plotpath)
        print('Validation plot is saved: ', plotpath)
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)

        # Collate ROC data
        self.ROCdata["TPR"] = self.TPRs
        self.ROCdata["FPR"] = self.FPRs
        self.ROCdata["thr"] = self.Thrs
        print('TPR: ', self.ROCdata["TPR"])
        print('FPR: ', self.ROCdata["FPR"])

        # 4. Auto select the upper threshold (fpr = 0)
        for ct in range(len(self.calltypes)):
            try:
                self.bestThr[ct][1] = self.Thrs[self.FPRs[ct].index(0.0)]
            except ValueError:
                # No threshold reaches FPR == 0: fall back to the last one.
                self.bestThr[ct][1] = self.Thrs[len(self.FPRs[ct]) - 1]

        # 5. Auto select lower threshold IF the user asked so
        if self.autoThr:
            for ct in range(len(self.calltypes)):
                # Get min distance to ROC from (0 FPR, 1 TPR)
                tpr = np.asarray(self.TPRs[ct])
                fpr = np.asarray(self.FPRs[ct])
                distarr = (np.float64(1) - tpr)**2 + (np.float64(0) - fpr)**2
                self.thr_min_ind = np.argmin(distarr)
                self.bestThr[ct][0] = self.Thrs[self.thr_min_ind]
                self.bestThrInd[ct] = self.thr_min_ind
        return True