def test_procedure(self, test_data, distribution_op, inputX, inputY, mode):
    """Evaluate on a labelled test set and append the results to the run log.

    Runs ``distribution_op`` over ``test_data`` in mini-batches of ``self.bs``
    and accumulates the returned (row, col) index pairs into a
    ``num_class x num_class`` confusion matrix. It then writes, via
    ``da_utils.save2file``, the mode tag, the overall accuracy, the confusion
    matrix, the per-category accuracy (diagonal / row total) and
    ``self.f_value`` of the matrix.

    Args:
        test_data: pair ``(images, labels)``; both are sliced along axis 0.
        distribution_op: fetch that yields two equal-length sequences of
            class indices, used as confusion-matrix row and column.
        inputX: placeholder fed with the image batch.
        inputY: placeholder fed with the label batch.
        mode: tag written as the first line of the log entry.

    Raises:
        ValueError: if ``test_data`` contains no samples (accuracy undefined).
    """
    images, labels = test_data[0], test_data[1]
    confusion_matrics = np.zeros([self.num_class, self.num_class], dtype="int")
    tst_batch_num = int(np.ceil(images.shape[0] / self.bs))
    for step in range(tst_batch_num):
        start = step * self.bs
        matrix_row, matrix_col = self.sess.run(
            distribution_op,
            feed_dict={
                inputX: images[start:start + self.bs],
                inputY: labels[start:start + self.bs],
                self.is_training: False,
                # Evaluation must be deterministic: keep every unit, as the
                # training loops' eval feeds do (the old value 0.5 left
                # dropout active at test time).
                # NOTE(review): training feeds `self.keep_prob`; confirm
                # `self.keep_rate` is this model's dropout placeholder.
                self.keep_rate: 1.0,
            })
        # np.add.at counts repeated (row, col) pairs within one batch
        # correctly, unlike fancy-index `+=`.
        np.add.at(confusion_matrics,
                  (np.asarray(matrix_row), np.asarray(matrix_col)), 1)

    total = confusion_matrics.sum()
    if total == 0:
        raise ValueError("test_procedure: no test samples for mode %r" % mode)
    test_accuracy = float(np.trace(confusion_matrics)) / float(total)
    # Per-category accuracy; a category with no samples yields nan (as
    # before) but without emitting a divide-by-zero RuntimeWarning.
    with np.errstate(divide='ignore', invalid='ignore'):
        detail_test_accuracy = (np.diag(confusion_matrics)
                                / confusion_matrics.sum(axis=1))

    log0 = "Mode: " + mode
    log1 = "Test Accuracy : %g" % test_accuracy
    log2 = confusion_matrics.copy()
    log3 = '\n'.join(
        'category %s test accuracy : %g' % (da_utils.pulmonary_category[j],
                                            detail_test_accuracy[j])
        for j in range(self.num_class))
    log4 = 'F_Value : %g\n' % self.f_value(confusion_matrics)
    for entry in (log0, log1, log2, log3, log4):
        da_utils.save2file(entry, self.ckptDir, self.model)
def train(self, stage1_end=100, stage2_end=200):
    """Run the staged domain-adaptation schedule for ``self.eps`` epochs.

    Initialises all variables and restores ``self.g2_preTrained_var`` from
    ``self.preTrained_path``. Then, for each epoch ``e`` in ``1..self.eps``,
    it trains on freshly shuffled source/target data:

    * ``e < stage1_end``: runs ``g_train_op_step1`` only;
    * ``stage1_end <= e < stage2_end``: runs ``g_train_op_step2`` and
      ``d_train_op_step1``;
    * ``e >= stage2_end`` (including the final epoch): runs
      ``g_train_op_step3`` and ``d_train_op_step2``.

    After each epoch it validates on the source validation set, logs and
    plots the epoch's averaged metrics, writes the merged summary, saves a
    checkpoint, and runs ``test_procedure`` on the source and target test
    sets.

    Args:
        stage1_end: first epoch of the second stage (default 100, the value
            previously hard-coded).
        stage2_end: first epoch of the third stage (default 200).

    Raises:
        ValueError: if the source training set holds fewer than ``self.bs``
            samples, so no training iteration could run.
    """
    print('Start to run in mode [Domain Adaptation Across Source and Target Domain]')
    self.sess.run(tf.global_variables_initializer())
    self.preTrained_saver = tf.train.Saver(var_list=self.g2_preTrained_var)
    self.preTrained_saver.restore(self.sess, self.preTrained_path)
    print('Pre-trained model has been successfully restored !')

    self.train_itr = len(self.source_training_data[0]) // self.bs
    if self.train_itr == 0:
        # Zero iterations would divide by zero below and leave `summary` unset.
        raise ValueError(
            'source training set (%d samples) is smaller than the batch size (%d)'
            % (len(self.source_training_data[0]), self.bs))

    for e in range(1, self.eps + 1):
        _src_tr_img, _src_tr_lab = DA_init.shuffle_data(
            self.source_training_data[0], self.source_training_data[1])
        _tar_tr_img = DA_init.shuffle_data_nolabel(self.target_training_data)

        # Choose the stage once per epoch. The last stage is an unconditional
        # `else`: the previous `elif e < self.eps` skipped all training in the
        # final epoch and logged zero accuracy/loss for it.
        if e < stage1_end:
            train_ops = [self.g_train_op_step1]
            eval_ops = [self.accuracy_source, self.loss_source]
        elif e < stage2_end:
            train_ops = [self.g_train_op_step2, self.d_train_op_step1]
            eval_ops = [self.accuracy_source, self.loss_source,
                        self.g_loss_step2, self.d_loss_step1]
        else:
            train_ops = [self.g_train_op_step3, self.d_train_op_step2]
            eval_ops = [self.accuracy_source, self.loss_source,
                        self.g_loss_step3, self.d_loss_step2]

        source_training_acc = 0.0
        source_training_loss = 0.0
        g_loss = 0.0  # stays 0.0 in the first stage (no G/D losses fetched)
        d_loss = 0.0
        for itr in range(self.train_itr):
            _src_tr_img_batch, _src_tr_lab_batch = DA_init.next_batch(
                _src_tr_img, _src_tr_lab, self.bs, itr)
            _tar_tr_img_batch = DA_init.next_batch_nolabel(_tar_tr_img, self.bs)
            feed_dict = {
                self.x_source: _src_tr_img_batch,
                self.y_source: _src_tr_lab_batch,
                self.x_target: _tar_tr_img_batch,
                self.is_training: True,
                self.keep_prob: self.kp,
            }
            feed_dict_eval = {
                self.x_source: _src_tr_img_batch,
                self.y_source: _src_tr_lab_batch,
                self.x_target: _tar_tr_img_batch,
                self.is_training: False,
                self.keep_prob: 1.0,
            }
            self.sess.run(train_ops, feed_dict=feed_dict)
            _training_accuracy, _training_loss, *adv_losses = self.sess.run(
                eval_ops, feed_dict=feed_dict_eval)
            source_training_acc += _training_accuracy
            source_training_loss += _training_loss
            if adv_losses:  # second/third stage: generator and discriminator
                g_loss += adv_losses[0]
                d_loss += adv_losses[1]

        # Summary for the epoch's last batch, evaluated once per epoch.
        summary = self.sess.run(self.merged, feed_dict=feed_dict_eval)

        source_training_acc = float(source_training_acc / self.train_itr)
        source_training_loss = float(source_training_loss / self.train_itr)
        g_loss = float(g_loss / self.train_itr)
        d_loss = float(d_loss / self.train_itr)

        source_validation_acc, source_validation_loss = self.validation_procedure(
            validation_data=self.source_validation_data,
            distribution_op=self.distribution_source,
            loss_op=self.loss_source,
            inputX=self.x_source,
            inputY=self.y_source)

        log1 = "Epoch: [%d], Domain: Source, Training Accuracy: [%g], Validation Accuracy: [%g], " \
               "Training Loss: [%g], Validation Loss: [%g], generator Loss: [%g], Discriminator Loss: [%g], " \
               "Time: [%s]" % (
                   e, source_training_acc, source_validation_acc, source_training_loss,
                   source_validation_loss, g_loss, d_loss, time.ctime(time.time()))

        self.plt_epoch.append(e)
        self.plt_training_accuracy.append(source_training_acc)
        self.plt_training_loss.append(source_training_loss)
        self.plt_validation_accuracy.append(source_validation_acc)
        self.plt_validation_loss.append(source_validation_loss)
        self.plt_d_loss.append(d_loss)
        self.plt_g_loss.append(g_loss)

        da_utils.plotAccuracy(x=self.plt_epoch, y1=self.plt_training_accuracy,
                              y2=self.plt_validation_accuracy, figName=self.model,
                              line1Name='training', line2Name='validation',
                              savePath=self.ckptDir)
        da_utils.plotLoss(x=self.plt_epoch, y1=self.plt_training_loss,
                          y2=self.plt_validation_loss, figName=self.model,
                          line1Name='training', line2Name='validation',
                          savePath=self.ckptDir)
        da_utils.plotLoss(x=self.plt_epoch, y1=self.plt_d_loss, y2=self.plt_g_loss,
                          figName=self.model + '_GD_Loss', line1Name='D_Loss',
                          line2Name='G_Loss', savePath=self.ckptDir)

        da_utils.save2file(log1, self.ckptDir, self.model)
        self.writer.add_summary(summary, e)
        self.saver.save(self.sess, self.ckptDir + self.model + '-' + str(e))

        self.test_procedure(self.source_test_data,
                            distribution_op=self.distribution_source,
                            inputX=self.x_source, inputY=self.y_source,
                            mode='source')
        self.test_procedure(self.target_test_data,
                            distribution_op=self.distribution_target,
                            inputX=self.x_target, inputY=self.y_target,
                            mode='target')
def saveConfiguration(self):
    """Append the run's configuration to the log, one ``name : value`` line each.

    Entries are written in a fixed order through ``da_utils.save2file``.
    Each attribute is read just before its own line is written.
    """
    # (format template, attribute name) in output order.
    config_lines = (
        ('epoch : %d', 'eps'),
        ('restore epoch : %d', 'res_eps'),
        ('model : %s', 'model'),
        ('learning rate : %g', 'lr'),
        ('batch size : %d', 'bs'),
        ('image height : %d', 'img_h'),
        ('image width : %d', 'img_w'),
        ('num class : %d', 'num_class'),
        ('train phase : %s', 'train_phase'),
    )
    for template, attr_name in config_lines:
        text = template % getattr(self, attr_name)
        da_utils.save2file(text, self.ckptDir, self.model)
def train(self):
    """Alternate discriminator and generator updates for ``self.eps`` epochs.

    Initialises all variables and reloads ``self.g_var`` from
    ``self.reloadPath``. Each training iteration draws one batch via
    ``self.getBatchData`` for a ``d_train_op`` step and a second batch for a
    ``g_train_op`` step. Source accuracy and ``supervised_loss`` are then
    evaluated on that second batch.

    After each epoch it validates on the source validation set, logs and
    plots the averaged metrics, writes the merged summary, saves a checkpoint,
    and runs ``test_procedure`` on the source and target test sets.

    Raises:
        ValueError: if the source training set holds fewer than ``self.bs``
            samples, so no training iteration could run.
    """
    print('Start to run in mode [Domain Adaptation Across Source and Target Domain]')
    self.sess.run(tf.global_variables_initializer())
    self.reloadSaver = tf.train.Saver(var_list=self.g_var)
    self.reloadSaver.restore(self.sess, self.reloadPath)
    print('Pre-trained G2 model has been successfully reloaded !')

    self.itr_epoch = len(self.source_training_data[0]) // self.bs
    if self.itr_epoch == 0:
        # Zero iterations would divide by zero below and leave `summary` unset.
        raise ValueError(
            'source training set (%d samples) is smaller than the batch size (%d)'
            % (len(self.source_training_data[0]), self.bs))

    for e in range(1, self.eps + 1):
        source_training_acc = 0.0
        source_training_loss = 0.0
        for itr in range(self.itr_epoch):
            # Discriminator step on its own batch (its eval feed is unused).
            feed_dict_train, _ = self.getBatchData()
            self.sess.run(self.d_train_op, feed_dict=feed_dict_train)

            # Generator step on a fresh batch, then evaluate on that batch.
            feed_dict_train, feed_dict_eval = self.getBatchData()
            self.sess.run(self.g_train_op, feed_dict=feed_dict_train)
            _training_accuracy, _training_loss = self.sess.run(
                [self.accuracy_source, self.supervised_loss],
                feed_dict=feed_dict_eval)
            source_training_acc += _training_accuracy
            source_training_loss += _training_loss

        # Summary for the epoch's last batch, evaluated once per epoch.
        summary = self.sess.run(self.merged, feed_dict=feed_dict_eval)

        source_training_acc = float(source_training_acc / self.itr_epoch)
        source_training_loss = float(source_training_loss / self.itr_epoch)

        source_validation_acc, source_validation_loss = self.validation_procedure(
            validation_data=self.source_validation_data,
            distribution_op=self.distribution_source,
            loss_op=self.supervised_loss,
            inputX=self.x_source,
            inputY=self.y_source)

        log1 = "Epoch: [%d], Domain: Source, Training Accuracy: [%g], Validation Accuracy: [%g], " \
               "Training Loss: [%g], Validation Loss: [%g], Time: [%s]" % (
                   e, source_training_acc, source_validation_acc, source_training_loss,
                   source_validation_loss, time.ctime(time.time()))

        self.plt_epoch.append(e)
        self.plt_training_accuracy.append(source_training_acc)
        self.plt_training_loss.append(source_training_loss)
        self.plt_validation_accuracy.append(source_validation_acc)
        self.plt_validation_loss.append(source_validation_loss)

        da_utils.plotAccuracy(x=self.plt_epoch, y1=self.plt_training_accuracy,
                              y2=self.plt_validation_accuracy, figName=self.model,
                              line1Name='training', line2Name='validation',
                              savePath=self.ckptDir)
        da_utils.plotLoss(x=self.plt_epoch, y1=self.plt_training_loss,
                          y2=self.plt_validation_loss, figName=self.model,
                          line1Name='training', line2Name='validation',
                          savePath=self.ckptDir)

        da_utils.save2file(log1, self.ckptDir, self.model)
        self.writer.add_summary(summary, e)
        self.saver.save(self.sess, self.ckptDir + self.model + '-' + str(e))

        self.test_procedure(self.source_test_data,
                            distribution_op=self.distribution_source,
                            inputX=self.x_source, inputY=self.y_source,
                            mode='source')
        self.test_procedure(self.target_test_data,
                            distribution_op=self.distribution_target,
                            inputX=self.x_target, inputY=self.y_target,
                            mode='target')