Example 1
    def test_procedure(self, test_data, distribution_op, inputX, inputY, mode):
        # Accumulate a confusion matrix over the whole test set, then report the
        # overall accuracy, the per-category accuracy and the F-value.
        confusion_matrics = np.zeros([self.num_class, self.num_class], dtype="int")

        tst_batch_num = int(np.ceil(test_data[0].shape[0] / self.bs))
        for step in range(tst_batch_num):
            _testImg = test_data[0][step * self.bs:step * self.bs + self.bs]
            _testLab = test_data[1][step * self.bs:step * self.bs + self.bs]

            matrix_row, matrix_col = self.sess.run(distribution_op, feed_dict={inputX: _testImg,
                                                                               inputY: _testLab,
                                                                               self.is_training: False,
                                                                               self.keep_rate: 1.0})  # dropout disabled at evaluation time
            for m, n in zip(matrix_row, matrix_col):
                # rows index the true class, columns the predicted class
                confusion_matrics[m][n] += 1

        test_accuracy = float(np.sum([confusion_matrics[q][q] for q in range(self.num_class)])) / float(
            np.sum(confusion_matrics))
        detail_test_accuracy = [confusion_matrics[i][i] / np.sum(confusion_matrics[i]) for i in
                                range(self.num_class)]
        log0 = "Mode: " + mode
        log1 = "Test Accuracy : %g" % test_accuracy
        log2 = np.array(confusion_matrics.tolist())
        log3 = ''
        for j in range(self.num_class):
            log3 += 'category %s test accuracy : %g\n' % (da_utils.pulmonary_category[j], detail_test_accuracy[j])
        log3 = log3[:-1]
        log4 = 'F_Value : %g\n' % self.f_value(confusion_matrics)

        da_utils.save2file(log0, self.ckptDir, self.model)
        da_utils.save2file(log1, self.ckptDir, self.model)
        da_utils.save2file(log2, self.ckptDir, self.model)
        da_utils.save2file(log3, self.ckptDir, self.model)
        da_utils.save2file(log4, self.ckptDir, self.model)
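
The f_value helper used above is not listed in this example. As a rough illustration, the sketch below computes a macro-averaged F1 score from such a confusion matrix, treating rows as true labels and columns as predictions; the function name, signature and averaging scheme are assumptions rather than the authors' implementation.

import numpy as np

def f_value(confusion_matrix):
    # Macro-averaged F1 over a square confusion matrix whose rows are true
    # labels and whose columns are predicted labels (illustrative sketch).
    cm = np.asarray(confusion_matrix, dtype=np.float64)
    tp = np.diag(cm)                                     # per-class true positives
    precision = tp / np.maximum(cm.sum(axis=0), 1e-12)   # column sums = predicted counts
    recall = tp / np.maximum(cm.sum(axis=1), 1e-12)      # row sums = true counts
    f1 = 2.0 * precision * recall / np.maximum(precision + recall, 1e-12)
    return float(np.mean(f1))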
Example 2
    def train(self):
        print(
            'Start to run in mode [Domain Adaptation Across Source and Target Domain]'
        )
        self.sess.run(tf.global_variables_initializer())
        self.preTrained_saver = tf.train.Saver(var_list=self.g2_preTrained_var)
        self.preTrained_saver.restore(self.sess, self.preTrained_path)
        print('Pre-trained model has been successfully restored!')

        self.train_itr = len(self.source_training_data[0]) // self.bs

        for e in range(1, self.eps + 1):
            _src_tr_img, _src_tr_lab = DA_init.shuffle_data(
                self.source_training_data[0], self.source_training_data[1])
            _tar_tr_img = DA_init.shuffle_data_nolabel(
                self.target_training_data)

            source_training_acc = 0.0
            source_training_loss = 0.0
            g_loss = 0.0
            d_loss = 0.0

            for itr in range(self.train_itr):
                _src_tr_img_batch, _src_tr_lab_batch = DA_init.next_batch(
                    _src_tr_img, _src_tr_lab, self.bs, itr)
                _tar_tr_img_batch = DA_init.next_batch_nolabel(
                    _tar_tr_img, self.bs)

                feed_dict = {
                    self.x_source: _src_tr_img_batch,
                    self.y_source: _src_tr_lab_batch,
                    self.x_target: _tar_tr_img_batch,
                    self.is_training: True,
                    self.keep_prob: self.kp
                }
                feed_dict_eval = {
                    self.x_source: _src_tr_img_batch,
                    self.y_source: _src_tr_lab_batch,
                    self.x_target: _tar_tr_img_batch,
                    self.is_training: False,
                    self.keep_prob: 1.0
                }

                if e < 100:
                    # Stage 1 (epochs 1-99): update the generator/classifier only (step-1 op).
                    _ = self.sess.run(self.g_train_op_step1,
                                      feed_dict=feed_dict)
                    _training_accuracy, _training_loss = self.sess.run(
                        [self.accuracy_source, self.loss_source],
                        feed_dict=feed_dict_eval)

                    source_training_acc += _training_accuracy
                    source_training_loss += _training_loss

                elif e < 200:
                    # Stage 2 (epochs 100-199): adversarial training with the
                    # step-2 generator op and step-1 discriminator op.
                    _, _ = self.sess.run(
                        [self.g_train_op_step2, self.d_train_op_step1],
                        feed_dict=feed_dict)
                    _training_accuracy, _training_loss, _g_loss, _d_loss = self.sess.run(
                        [
                            self.accuracy_source, self.loss_source,
                            self.g_loss_step2, self.d_loss_step1
                        ],
                        feed_dict=feed_dict_eval)

                    source_training_acc += _training_accuracy
                    source_training_loss += _training_loss
                    g_loss += _g_loss
                    d_loss += _d_loss

                else:
                    # Stage 3 (from epoch 200 onwards, including the final epoch):
                    # step-3 generator op together with the step-2 discriminator op.
                    _, _ = self.sess.run(
                        [self.g_train_op_step3, self.d_train_op_step2],
                        feed_dict=feed_dict)
                    _training_accuracy, _training_loss, _g_loss, _d_loss = self.sess.run(
                        [
                            self.accuracy_source, self.loss_source,
                            self.g_loss_step3, self.d_loss_step2
                        ],
                        feed_dict=feed_dict_eval)

                    source_training_acc += _training_accuracy
                    source_training_loss += _training_loss
                    g_loss += _g_loss
                    d_loss += _d_loss

            summary = self.sess.run(self.merged, feed_dict=feed_dict_eval)

            source_training_acc = float(source_training_acc / self.train_itr)
            source_training_loss = float(source_training_loss / self.train_itr)
            g_loss = float(g_loss / self.train_itr)
            d_loss = float(d_loss / self.train_itr)

            source_validation_acc, source_validation_loss = self.validation_procedure(
                validation_data=self.source_validation_data,
                distribution_op=self.distribution_source,
                loss_op=self.loss_source,
                inputX=self.x_source,
                inputY=self.y_source)

            log1 = "Epoch: [%d], Domain: Source, Training Accuracy: [%g], Validation Accuracy: [%g], " \
                   "Training Loss: [%g], Validation Loss: [%g], Generator Loss: [%g], Discriminator Loss: [%g], " \
                   "Time: [%s]" % (
                       e, source_training_acc, source_validation_acc, source_training_loss, source_validation_loss,
                       g_loss, d_loss, time.ctime(time.time()))

            self.plt_epoch.append(e)
            self.plt_training_accuracy.append(source_training_acc)
            self.plt_training_loss.append(source_training_loss)
            self.plt_validation_accuracy.append(source_validation_acc)
            self.plt_validation_loss.append(source_validation_loss)
            self.plt_d_loss.append(d_loss)
            self.plt_g_loss.append(g_loss)

            da_utils.plotAccuracy(x=self.plt_epoch,
                                  y1=self.plt_training_accuracy,
                                  y2=self.plt_validation_accuracy,
                                  figName=self.model,
                                  line1Name='training',
                                  line2Name='validation',
                                  savePath=self.ckptDir)

            da_utils.plotLoss(x=self.plt_epoch,
                              y1=self.plt_training_loss,
                              y2=self.plt_validation_loss,
                              figName=self.model,
                              line1Name='training',
                              line2Name='validation',
                              savePath=self.ckptDir)

            da_utils.plotLoss(x=self.plt_epoch,
                              y1=self.plt_d_loss,
                              y2=self.plt_g_loss,
                              figName=self.model + '_GD_Loss',
                              line1Name='D_Loss',
                              line2Name='G_Loss',
                              savePath=self.ckptDir)

            da_utils.save2file(log1, self.ckptDir, self.model)

            self.writer.add_summary(summary, e)

            self.saver.save(self.sess,
                            self.ckptDir + self.model + '-' + str(e))

            self.test_procedure(self.source_test_data,
                                distribution_op=self.distribution_source,
                                inputX=self.x_source,
                                inputY=self.y_source,
                                mode='source')
            self.test_procedure(self.target_test_data,
                                distribution_op=self.distribution_target,
                                inputX=self.x_target,
                                inputY=self.y_target,
                                mode='target')
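
The DA_init helpers called in this loop (shuffle_data, shuffle_data_nolabel, next_batch, next_batch_nolabel) are not shown on this page. A minimal sketch for numpy arrays is given below; the sequential slicing of labeled batches and the random sampling of unlabeled target batches are assumptions inferred from how the helpers are called, not the authors' code.

import numpy as np

def shuffle_data(images, labels):
    # Shuffle images and labels with one shared permutation.
    perm = np.random.permutation(len(images))
    return images[perm], labels[perm]

def shuffle_data_nolabel(images):
    # Shuffle an unlabeled image array.
    return images[np.random.permutation(len(images))]

def next_batch(images, labels, batch_size, itr):
    # Sequential mini-batch selected by the iteration index.
    start = itr * batch_size
    return images[start:start + batch_size], labels[start:start + batch_size]

def next_batch_nolabel(images, batch_size):
    # Random mini-batch from the unlabeled target set (assumes batch_size <= len(images)).
    idx = np.random.choice(len(images), batch_size, replace=False)
    return images[idx]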
Example 3
    def saveConfiguration(self):
        # Write the training configuration to the model's log file.
        da_utils.save2file('epoch : %d' % self.eps, self.ckptDir, self.model)
        da_utils.save2file('restore epoch : %d' % self.res_eps, self.ckptDir,
                           self.model)
        da_utils.save2file('model : %s' % self.model, self.ckptDir, self.model)
        da_utils.save2file('learning rate : %g' % self.lr, self.ckptDir,
                           self.model)
        da_utils.save2file('batch size : %d' % self.bs, self.ckptDir,
                           self.model)
        da_utils.save2file('image height : %d' % self.img_h, self.ckptDir,
                           self.model)
        da_utils.save2file('image width : %d' % self.img_w, self.ckptDir,
                           self.model)
        da_utils.save2file('num class : %d' % self.num_class, self.ckptDir,
                           self.model)
        da_utils.save2file('train phase : %s' % self.train_phase, self.ckptDir,
                           self.model)
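
save2file itself is not included in these examples. A plausible minimal version, assuming it echoes each message and appends it to a per-model log file inside the checkpoint directory (the file name and layout are guesses), could look like this:

import os

def save2file(message, ckpt_dir, model_name):
    # Print the message and append it to <ckpt_dir>/<model_name>.txt.
    print(message)
    os.makedirs(ckpt_dir, exist_ok=True)
    with open(os.path.join(ckpt_dir, model_name + '.txt'), 'a') as f:
        f.write(str(message) + '\n')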
Example 4
    def train(self):
        print('Start to run in mode [Domain Adaptation Across Source and Target Domain]')
        self.sess.run(tf.global_variables_initializer())
        self.reloadSaver = tf.train.Saver(var_list=self.g_var)
        self.reloadSaver.restore(self.sess, self.reloadPath)
        print('Pre-trained G2 model has been successfully reloaded!')

        self.itr_epoch = len(self.source_training_data[0]) // self.bs

        source_training_acc = 0.0
        source_training_loss = 0.0

        for e in range(1, self.eps + 1):
            for itr in range(self.itr_epoch):
                # Discriminator and generator are updated in turn, each on a freshly drawn batch.
                feed_dict_train, feed_dict_eval = self.getBatchData()
                _ = self.sess.run(self.d_train_op, feed_dict=feed_dict_train)

                feed_dict_train, feed_dict_eval = self.getBatchData()
                _ = self.sess.run(self.g_train_op, feed_dict=feed_dict_train)

                _training_accuracy, _training_loss = self.sess.run([self.accuracy_source, self.supervised_loss],
                                                                   feed_dict=feed_dict_eval)

                source_training_acc += _training_accuracy
                source_training_loss += _training_loss

            summary = self.sess.run(self.merged, feed_dict=feed_dict_eval)

            source_training_acc = float(source_training_acc / self.itr_epoch)
            source_training_loss = float(source_training_loss / self.itr_epoch)

            source_validation_acc, source_validation_loss = self.validation_procedure(
                validation_data=self.source_validation_data, distribution_op=self.distribution_source,
                loss_op=self.supervised_loss, inputX=self.x_source, inputY=self.y_source)

            log1 = "Epoch: [%d], Domain: Source, Training Accuracy: [%g], Validation Accuracy: [%g], " \
                   "Training Loss: [%g], Validation Loss: [%g], Time: [%s]" % (
                       e, source_training_acc, source_validation_acc, source_training_loss, source_validation_loss,
                       time.ctime(time.time()))

            self.plt_epoch.append(e)
            self.plt_training_accuracy.append(source_training_acc)
            self.plt_training_loss.append(source_training_loss)
            self.plt_validation_accuracy.append(source_validation_acc)
            self.plt_validation_loss.append(source_validation_loss)

            da_utils.plotAccuracy(x=self.plt_epoch,
                                  y1=self.plt_training_accuracy,
                                  y2=self.plt_validation_accuracy,
                                  figName=self.model,
                                  line1Name='training',
                                  line2Name='validation',
                                  savePath=self.ckptDir)

            da_utils.plotLoss(x=self.plt_epoch,
                              y1=self.plt_training_loss,
                              y2=self.plt_validation_loss,
                              figName=self.model,
                              line1Name='training',
                              line2Name='validation',
                              savePath=self.ckptDir)

            da_utils.save2file(log1, self.ckptDir, self.model)

            self.writer.add_summary(summary, e)

            self.saver.save(self.sess, self.ckptDir + self.model + '-' + str(e))

            self.test_procedure(self.source_test_data, distribution_op=self.distribution_source,
                                inputX=self.x_source,
                                inputY=self.y_source, mode='source')
            self.test_procedure(self.target_test_data, distribution_op=self.distribution_target,
                                inputX=self.x_target,
                                inputY=self.y_target, mode='target')

            # Reset the running totals for the next epoch.
            source_training_acc = 0.0
            source_training_loss = 0.0
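
validation_procedure is called in Examples 2 and 4 but its body is not listed here. The sketch below mirrors the batching pattern of test_procedure in Example 1, written as a free function that takes the session and placeholders explicitly; all names and the (true label, prediction) convention are assumptions rather than the authors' code.

import numpy as np

def validation_procedure(sess, validation_data, distribution_op, loss_op,
                         inputX, inputY, is_training, keep_prob, batch_size):
    # Mini-batch pass over the validation split; returns (mean accuracy, mean loss).
    num_batches = int(np.ceil(validation_data[0].shape[0] / batch_size))
    correct, total, loss_sum = 0, 0, 0.0
    for step in range(num_batches):
        img = validation_data[0][step * batch_size:(step + 1) * batch_size]
        lab = validation_data[1][step * batch_size:(step + 1) * batch_size]
        (rows, cols), batch_loss = sess.run([distribution_op, loss_op],
                                            feed_dict={inputX: img,
                                                       inputY: lab,
                                                       is_training: False,
                                                       keep_prob: 1.0})
        correct += int(np.sum(np.equal(rows, cols)))  # count matches between labels and predictions
        total += len(rows)
        loss_sum += float(batch_loss)
    return correct / total, loss_sum / num_batches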