Example #1
 def train(self):
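     # Resume from a checkpoint if requested, then alternate between
     # validation, summarized training steps, and plain training steps.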
     if self.conf.reload_step > 0:
         self.reload(self.conf.reload_step)
     train_reader = H5DataLoader(self.conf.data_dir + self.conf.train_data)
     valid_reader = H5DataLoader(self.conf.data_dir + self.conf.valid_data)
     for epoch_num in range(self.conf.max_step + 1):
         if epoch_num and epoch_num % self.conf.test_interval == 0:
             inputs, labels = valid_reader.next_batch(self.conf.batch)
             feed_dict = {self.inputs: inputs, self.labels: labels}
             loss, summary = self.sess.run(
                 [self.loss_op, self.valid_summary], feed_dict=feed_dict)
             self.save_summary(summary, epoch_num + self.conf.reload_step)
             print('----testing loss', loss)
         if epoch_num and epoch_num % self.conf.summary_interval == 0:
             inputs, labels = train_reader.next_batch(self.conf.batch)
             feed_dict = {self.inputs: inputs, self.labels: labels}
             loss, _, summary = self.sess.run(
                 [self.loss_op, self.train_op, self.train_summary],
                 feed_dict=feed_dict)
             self.save_summary(summary, epoch_num + self.conf.reload_step)
         else:
             inputs, labels = train_reader.next_batch(self.conf.batch)
             feed_dict = {self.inputs: inputs, self.labels: labels}
             loss, _ = self.sess.run([self.loss_op, self.train_op],
                                     feed_dict=feed_dict)
             print('----training loss', loss)
         if epoch_num and epoch_num % self.conf.save_interval == 0:
             self.save(epoch_num + self.conf.reload_step)
Example #2
 def train(self):
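     # Same training loop, extended with 2D/3D loader selection and
     # pixel/dice accuracy reporting for segmentation.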
     if self.conf.reload_step > 0:
         self.reload(self.conf.reload_step)
     if self.conf.data_type == '2D':
         train_reader = H5DataLoader(self.conf.data_dir +
                                     self.conf.train_data)
         valid_reader = H5DataLoader(self.conf.data_dir +
                                     self.conf.valid_data)
     else:
         train_reader = H53DDataLoader(
             self.conf.data_dir + self.conf.train_data, self.input_shape)
         valid_reader = H53DDataLoader(
             self.conf.data_dir + self.conf.valid_data, self.input_shape)
     for epoch_num in range(self.conf.max_step):
         if epoch_num % self.conf.test_interval == 0:
             inputs, annotations = valid_reader.next_batch(self.conf.batch)
             feed_dict = {
                 self.inputs: inputs,
                 self.annotations: annotations
             }
             loss, summary, accuracy, dice_accuracy = self.sess.run(
                 [
                     self.loss_op, self.valid_summary, self.accuracy_op,
                     self.dice_accuracy_op
                 ],
                 feed_dict=feed_dict)
             self.save_summary(summary, epoch_num + self.conf.reload_step)
             print('----valid loss', loss)
             print('----valid accuracy', accuracy)
             print('----valid dice accuracy', dice_accuracy)
         elif epoch_num % self.conf.summary_interval == 0:
             inputs, annotations = train_reader.next_batch(self.conf.batch)
             feed_dict = {
                 self.inputs: inputs,
                 self.annotations: annotations
             }
             loss, _, summary = self.sess.run(
                 [self.loss_op, self.train_op, self.train_summary],
                 feed_dict=feed_dict)
             self.save_summary(summary, epoch_num + self.conf.reload_step)
         else:
             inputs, annotations = train_reader.next_batch(self.conf.batch)
             feed_dict = {
                 self.inputs: inputs,
                 self.annotations: annotations
             }
             loss, summary, _, accuracy, dice_accuracy = self.sess.run(
                 [self.loss_op, self.train_summary, self.train_op,
                  self.accuracy_op, self.dice_accuracy_op],
                 feed_dict=feed_dict)
             print('----train loss', loss)
             print('----train accuracy', accuracy)
             print('----train dice accuracy', dice_accuracy)
             self.save_summary(summary, epoch_num + self.conf.reload_step)
         if epoch_num % self.conf.save_interval == 0:
             self.save(epoch_num + self.conf.reload_step)
Example #3
 def train(self):
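     # Adversarial training: alternate discriminator (d_*) and generator (g_*)
     # updates, running validation and saving a checkpoint whenever the
     # training loader's iteration counter advances.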
     if self.conf.reload_step > 0:
         self.reload(self.conf.reload_step)
     train_reader = H5DataLoader(
         self.conf.data_dir+self.conf.train_data)
     valid_reader = H5DataLoader(
         self.conf.data_dir+self.conf.valid_data)
     iteration = train_reader.iter + self.conf.reload_step
     pre_iter = iteration
     epoch_num = 0
     while iteration < self.conf.max_step:
         if pre_iter != iteration:
             pre_iter = iteration
             inputs, labels, catgory = valid_reader.next_batch(self.conf.batch)
             feed_dict = {self.inputs: inputs,
                          self.labels: labels,
                          self.catgory: catgory}
             loss, summary = self.sess.run(
                 [self.d_loss_total, self.valid_summary], feed_dict=feed_dict)
             self.save_summary(summary, iteration)
             print('----testing d loss', loss)
             loss, summary = self.sess.run(
                 [self.g_loss_total, self.valid_summary], feed_dict=feed_dict)
             self.save_summary(summary, iteration)
             self.save(iteration)
             print('----testing g loss', loss)
         elif epoch_num % self.conf.summary_interval == 0:
             inputs, labels, catgory = train_reader.next_batch(self.conf.batch)
             feed_dict = {self.inputs: inputs,
                          self.labels: labels,
                          self.catgory: catgory}
             loss, _, summary = self.sess.run(
                 [self.d_loss_total, self.d_train, self.train_summary], feed_dict=feed_dict)
             self.save_summary(summary, epoch_num+self.conf.reload_step)
             loss, _, summary = self.sess.run(
                 [self.g_loss_total, self.g_train, self.train_summary],
                 feed_dict=feed_dict)
             self.save_summary(summary, epoch_num+self.conf.reload_step)
         else:
             inputs, labels, catgory = train_reader.next_batch(self.conf.batch)
             feed_dict = {self.inputs: inputs,
                          self.labels: labels,
                          self.catgory: catgory}
             loss, _, summary = self.sess.run(
                 [self.d_loss_total, self.d_train, self.train_summary], feed_dict=feed_dict)
             print('----training d loss', loss)
             loss, _, summary = self.sess.run(
                 [self.g_loss_total, self.g_train, self.train_summary], feed_dict=feed_dict)
             print('----training g loss', loss)
         iteration = train_reader.iter + self.conf.reload_step
         epoch_num += 1
Example #4
 def test(self):
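     # Evaluate loss, pixel accuracy, and mean IoU over the test set; m_iou
     # accumulates across batches, so only the final value is reported.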
     sig = True
     print('---->testing ', self.conf.test_step)
     if self.conf.test_step > 0:
         self.reload(self.conf.test_step)
     else:
         print("please set a reasonable test_step")
         return
     if self.conf.data_type == '2D':
         test_reader = H5DataLoader(
             self.conf.data_dir + self.conf.test_data, False)
     else:
         test_reader = H53DDataLoader(
             self.conf.data_dir + self.conf.test_data, self.input_shape)
     self.sess.run(tf.local_variables_initializer())
     count = 0
     losses = []
     accuracies = []
     m_ious = []
     while sig:
         sig, inputs, labels = test_reader.next_batch(self.conf.batch)
         if inputs is None:
             break
         feed_dict = {self.inputs: inputs, self.labels: labels}
         loss, accuracy, m_iou, _ = self.sess.run(
             [self.loss_op, self.accuracy_op, self.m_iou, self.miou_op],
             feed_dict=feed_dict)
         print('values----->', loss, accuracy, m_iou)
         count += 1
         losses.append(loss)
         accuracies.append(accuracy)
         m_ious.append(m_iou)
     print('Loss: ', np.mean(losses))
     print('Accuracy: ', np.mean(accuracies))
     print('M_iou: ', m_ious[-1])
Example #5
 def store(self):
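     # Export the test set: inputs as JPEGs under JPEGImages/ and the
     # corresponding annotations as PNGs under Annotations/.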
     print('---->storing ', self.conf.test_step)
     test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data,
                                False)
     images = []
     ground_truth = []
     while True:
         inputs, annotations = test_reader.next_batch(self.conf.batch)
         if inputs.shape[0] < self.conf.batch:
             break
         images.append(inputs)
         ground_truth.append(annotations)
     print(images)
     print('----->saving inputs and annotations')
     for index, image in enumerate(images):
         print(index)
         for i in range(image.shape[0]):
             scipy.misc.imsave(
                 "JPEGImages/" + str(index * image.shape[0] + i) + '.jpg',
                 image[i])
     print("Done storing JPEG imeges")
     for index, annotation in enumerate(ground_truth):
         print(index)
         for i in range(annotation.shape[0]):
             imsave(
                 annotation[i], 'Annotations/' +
                 str(index * annotation.shape[0] + i) + '.png')
     print("Done storing annotations")
Example #6
 def test(self):
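     # Run one pass over the test data, collect decoded predictions and labels,
     # save them to 'temp.npz', and report dice ratios for the first two samples.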
     print('---->predicting ', self.conf.test_step)
     if self.conf.test_step > 0:
         self.reload(self.conf.test_step)
     else:
         print("please set a reasonable test_step")
         return
     test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data,
                                True)
     predictions = []
     labels = []
     while test_reader.iter < 1:
         inputs, annotations = test_reader.next_batch(self.conf.batch)
         feed_dict = {self.inputs: inputs, self.annotations: annotations}
         res, acc = self.sess.run(
             [self.decoded_predictions, self.accuracy_op],
             feed_dict=feed_dict)
         print(acc)
         #res = np.concatenate(res,axis=0)
         predictions.append(res)
         labels.append(annotations)
     predictions = np.concatenate(predictions, axis=0)
     labels = np.concatenate(labels, axis=0)
     print('---', predictions.shape)
     print(labels.shape)
     np.savez('temp', predictions, labels)
     print(predictions.shape)
     print(ops.dice_ratio(predictions[0], labels[0]))
     print(ops.dice_ratio(predictions[1], labels[1]))
Example #7
 def predict(self):
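     # Decode predictions for every full test batch and save them as PNGs
     # in the sample directory.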
     print('---->predicting ', self.conf.test_step)
     if self.conf.test_step > 0:
         self.reload(self.conf.test_step)
     else:
         print("please set a reasonable test_step")
         return
     if self.conf.data_type == '2D':
         test_reader = H5DataLoader(
             self.conf.data_dir+self.conf.test_data, False)
     else:
         test_reader = H53DDataLoader(
             self.conf.data_dir+self.conf.test_data, self.input_shape)
     predictions = []
     while True:
         inputs, labels = test_reader.next_batch(self.conf.batch)
         if inputs.shape[0] < self.conf.batch:
             break
         feed_dict = {self.inputs: inputs, self.labels: labels}
         predictions.append(self.sess.run(
             self.decoded_preds, feed_dict=feed_dict))
     print('----->saving predictions')
     for index, prediction in enumerate(predictions):
         for i in range(prediction.shape[0]):
             imsave(prediction[i], self.conf.sampledir +
                    str(index*prediction.shape[0]+i)+'.png')
Example #8
 def train(self):
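     # Validate and checkpoint whenever the training loader's iteration counter
     # advances; otherwise run summarized or plain training steps.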
     if self.conf.reload_step > 0:
         self.reload(self.conf.reload_step)
     train_reader = H5DataLoader(self.conf.data_dir + self.conf.train_data)
     valid_reader = H5DataLoader(self.conf.data_dir + self.conf.valid_data)
     iteration = train_reader.iter
     pre_iter = iteration
     epoch_num = 0
     while iteration < self.conf.max_step:
         if pre_iter != iteration:
             pre_iter = iteration
             inputs, annotations = valid_reader.next_batch(self.conf.batch)
             feed_dict = {
                 self.inputs: inputs,
                 self.annotations: annotations
             }
             loss, summary = self.sess.run(
                 [self.loss_op, self.valid_summary], feed_dict=feed_dict)
             self.save_summary(summary, iteration)
             self.save(iteration)
             print('----testing loss', loss)
         elif epoch_num % self.conf.summary_interval == 0:
             inputs, annotations = train_reader.next_batch(self.conf.batch)
             feed_dict = {
                 self.inputs: inputs,
                 self.annotations: annotations
             }
             loss, _, summary = self.sess.run(
                 [self.loss_op, self.train_op, self.train_summary],
                 feed_dict=feed_dict)
             self.save_summary(summary, epoch_num + self.conf.reload_step)
         else:
             inputs, annotations = train_reader.next_batch(self.conf.batch)
             feed_dict = {
                 self.inputs: inputs,
                 self.annotations: annotations
             }
             loss, _ = self.sess.run([self.loss_op, self.train_op],
                                     feed_dict=feed_dict)
             print('----training loss', loss)
         iteration = train_reader.iter
         epoch_num += 1
Example #9
 def train(self):
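     # Restore from a saved checkpoint and resume from the stored global step
     # when one is available, logging validation and training losses periodically.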
     self.restore()
     self.sess.run(tf.local_variables_initializer())
     train_reader = H5DataLoader(self.conf.data_dir + self.conf.train_data)
     valid_reader = H5DataLoader(self.conf.data_dir + self.conf.valid_data)
     start_step = 0 if self.global_step is None else self.global_step + 1
     for epoch_num in range(start_step, self.conf.max_step + 1):
         print(epoch_num)
         if epoch_num % self.conf.test_interval == 0:
             inputs, annotations = valid_reader.next_batch(self.conf.batch)
             feed_dict = {
                 self.inputs: inputs,
                 self.annotations: annotations
             }
             loss, summary = self.sess.run(
                 [self.loss_op, self.valid_summary], feed_dict=feed_dict)
             self.save_summary(summary, epoch_num + self.conf.reload_step)
             print("Step: %d, Test_loss:%g" % (epoch_num, loss))
         if epoch_num % self.conf.summary_interval == 0:
             inputs, annotations = train_reader.next_batch(self.conf.batch)
             feed_dict = {
                 self.inputs: inputs,
                 self.annotations: annotations
             }
             loss, _, summary = self.sess.run(
                 [self.loss_op, self.train_op, self.train_summary],
                 feed_dict=feed_dict)
             self.save_summary(summary, epoch_num + self.conf.reload_step)
         else:
             inputs, annotations = train_reader.next_batch(self.conf.batch)
             feed_dict = {
                 self.inputs: inputs,
                 self.annotations: annotations
             }
             loss, _ = self.sess.run([self.loss_op, self.train_op],
                                     feed_dict=feed_dict)
             print("Step: %d, Train_loss:%g" % (epoch_num, loss))
         if epoch_num % self.conf.save_interval == 0:
             self.save(epoch_num + self.conf.reload_step)
Example #10
    def test(self):
        '''test the metric learning part in the end-to-end model'''
        super(base_model_metric, self).test()
        print('testing metric learning results...')
        train_reader = GenDataLoader(model_type='unpaired',
                                     conf=self.conf,
                                     portion=self.conf.portion)
        test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data,
                                   model_type='paired',
                                   is_train=False)

        acc, f1, RI, purity = self.calculate_cost(self.sess, train_reader,
                                                  test_reader)
        print(
            'knn accuracy f1, and cluster purity, RI &%0.4f &%0.4f &&%0.4f &%0.4f'
            % (acc, f1, purity, RI))
Example #11
    def predict(self):
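        # Decode predictions for the test set and save the prediction, input,
        # and scaled label images to the sample directory.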
        print('---->predicting ', self.conf.test_step)
        if self.conf.test_step > 0:
            self.reload(self.conf.test_step)
        else:
            print("please set a reasonable test_step")
            return
        if self.conf.data_type == '2D':
            test_reader = H5DataLoader(
                self.conf.data_dir + self.conf.test_data, False)
        else:
            test_reader = H53DDataLoader(
                self.conf.data_dir + self.conf.test_data, self.input_shape)

        predictions = []
        sig = True
        final_inputs = []
        final_labels = []
        while sig:
            sig, inputs, labels = test_reader.next_batch(self.conf.batch)
            if inputs is None:
                break

            final_inputs.append(inputs)
            final_labels.append(labels)
            feed_dict = {self.inputs: inputs, self.labels: labels}
            predictions.append(
                self.sess.run(self.decoded_preds, feed_dict=feed_dict))
        final_inputs = np.array(final_inputs, 'uint8')
        final_labels = np.array(final_labels, 'uint8')

        print('----->saving predictions')
        for index, prediction in enumerate(predictions):
            for i in range(prediction.shape[0]):
                imsave(
                    prediction[i], self.conf.sampledir +
                    str(index * prediction.shape[0] + i) + '.png')
                scipy.misc.imsave(
                    self.conf.sampledir +
                    str(index * prediction.shape[0] + i) + '.jpg',
                    final_inputs[index][i])
                scipy.misc.imsave(
                    self.conf.sampledir +
                    str(index * prediction.shape[0] + i) + '_label.png',
                    final_labels[index][i] * 45)
Example #12
 def test(self):
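     # Test loop that additionally accumulates a per-class confusion matrix
     # and pickles it to 'data_best.pickle'.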
     print('---->testing ', self.conf.test_step)
     if self.conf.test_step > 0:
         self.reload(self.conf.test_step)
     else:
         print("please set a reasonable test_step")
         return
     test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data,
                                False)
     self.sess.run(tf.local_variables_initializer())
     count = 0
     losses = []
     accuracies = []
     m_ious = []
     # accumulate the confusion matrix in numpy so the loop below does not
     # keep adding new ops to the TensorFlow graph
     confusion_matrix_total = np.zeros(
         [self.conf.class_num, self.conf.class_num], np.int32)
     #start = time.time()
     while True:
         inputs, annotations = test_reader.next_batch(self.conf.batch)
         if inputs.shape[0] < self.conf.batch:
             break
         feed_dict = {self.inputs: inputs, self.annotations: annotations}
         loss, accuracy, m_iou, _, confusion_matrix, decoded_predictions = self.sess.run(
             [
                 self.loss_op, self.accuracy_op, self.m_iou, self.miou_op,
                 self.confusion_matrix, self.decoded_predictions
             ],
             feed_dict=feed_dict)
         print('values----->', loss, accuracy, m_iou)
         count += 1
         losses.append(loss)
         accuracies.append(accuracy)
         m_ious.append(m_iou)
         confusion_matrix_total = confusion_matrix_total + confusion_matrix
     print('Loss: ', np.mean(losses))
     print('Accuracy: ', np.mean(accuracies))
     print('M_iou: ', m_ious[-1])
     print('Confusion Matrix:')
     print('Dumping confusion matrix')
     with open('data_best.pickle', 'wb') as f:
         pickle.dump(confusion_matrix_total, f, pickle.HIGHEST_PROTOCOL)
     #end = time.time()
     print(confusion_matrix_total)
Example #13
 def test(self):
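     # Report mean classification accuracy over all full test batches.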
     print('---->testing ', self.conf.test_step)
     if self.conf.test_step > 0:
         self.reload(self.conf.test_step)
     else:
         print("please set a reasonable test_step")
         return
     test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data,
                                False)
     accuracies = []
     while True:
         inputs, labels = test_reader.next_batch(self.conf.batch)
         if inputs is None or inputs.shape[0] < self.conf.batch:
             break
         feed_dict = {self.inputs: inputs, self.labels: labels}
         accur = self.sess.run(self.accuracy_op, feed_dict=feed_dict)
         accuracies.append(accur)
     print('accuracy is ', sum(accuracies) / len(accuracies))
Example #14
    def train(self):
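        # Joint adversarial and metric-learning training on paired/unpaired
        # batches; every 1000 steps it validates PSNR/SSIM and clustering
        # metrics, checkpointing on best accuracy with early stopping.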
        if self.conf.reload_step > 0:
            self.reload(self.conf.reload_step)
        train_pair_reader = H5DataLoader(self.conf.data_dir +
                                         self.conf.train_pair,
                                         model_type='paired_metric',
                                         portion=self.conf.portion)
        train_unpair_reader = H5DataLoader(self.conf.data_dir +
                                           self.conf.train_unpair,
                                           conf=self.conf,
                                           model_type='unpaired_metric')
        valid_reader = H5DataLoader(self.conf.data_dir + self.conf.valid_data,
                                    model_type='paired_metric',
                                    is_train=False)

        train_p2 = GenDataLoader(model_type='unpaired',
                                 conf=self.conf,
                                 portion=self.conf.portion)
        valid_p2 = H5DataLoader(self.conf.data_dir + self.conf.valid_data,
                                model_type='paired',
                                is_train=False)

        iteration = train_pair_reader.iter + self.conf.reload_step
        pre_iter = iteration
        epoch_num = self.conf.reload_step
        start_time = time.time()
        bestsim = 0
        best_acc = 0
        epochs_no_performance_gain = 0
        while epoch_num < self.conf.max_step:
            Ap, Bp, p_label = train_pair_reader.next_batch(self.conf.batch)
            Au, Bu, up_label = train_unpair_reader.next_batch(self.conf.batch)
            pl_lab = get_pairwiselabel(np.concatenate((p_label, up_label), 0))
            feed_dict = {
                self.a_p: Ap,
                self.b_p: Bp,
                self.cat: p_label,
                self.a_u: Au,
                self.b_u: Bu,
                self.pl_total: pl_lab
            }
            d_loss, m_loss_d, _ = self.sess.run(
                [self.d_loss_total, self.m_loss, self.d_train],
                feed_dict=feed_dict)
            g_loss, g1, g2, m_loss_g, _, tmp = self.sess.run(
                [
                    self.g_loss_total, self.g1_loss, self.g2_loss, self.m_loss,
                    self.g_train, self.tmp
                ],
                feed_dict=feed_dict)

            if epoch_num % 200 == 0:
                print(
                    'epoch %d, duration %0.2f, d_loss %0.3f, m_loss_d %0.3f, g_loss %0.3f, g1 %0.3f, g2 %0.3f, m_loss_g %0.3f'
                    % (epoch_num, time.time() - start_time, d_loss, m_loss_d,
                       g_loss, g1, g2, m_loss_g))
                start_time = time.time()
            if epoch_num % 1000 == 0:
                ######################### validation costs ################################
                Ap, Bp, p_label = valid_reader.next_batch(
                    self.conf.batch)  #### mri, pet, categ_labels
                pl_lab = get_pairwiselabel(
                    np.concatenate((p_label, p_label, p_label), 0))
                feed_dict = {
                    self.a_p: Ap,
                    self.b_p: Bp,
                    self.cat: p_label,
                    self.a_u: Ap,
                    self.b_u: Bp,  ## tensors except for a_p, b_p are used for placeholder only
                    self.pl_total: pl_lab
                d_loss = self.sess.run(self.d_loss_total, feed_dict=feed_dict)
                g_loss, gen_a, gen_b = self.sess.run(
                    [self.g_loss_total, self.fake_ap, self.fake_bp],
                    feed_dict=feed_dict)
                ps = 0.
                ss = 0.
                a_p = np.squeeze(Ap, axis=-1)
                b_p = np.squeeze(Bp, axis=-1)
                gen_a = np.squeeze(gen_a, axis=-1)
                gen_b = np.squeeze(gen_b, axis=-1)
                for ap, ga, bp, gb in zip(a_p, gen_a, b_p, gen_b):
                    ap, bp = ops.normalize(ap), ops.normalize(bp)
                    ga, gb = ops.normalize(ga), ops.normalize(gb)
                    ps += (ops.psnr(ap, ga) + ops.psnr(bp, gb)) / 2
                    ss += (ops.ssim(ap, ga) + ops.ssim(bp, gb)) / 2
                acc, f1, RI, purity = self.calculate_cost(
                    self.sess, train_p2, valid_p2)
                print(
                    'proposed_model valid d loss %0.3f, g loss %0.3f, PSNR %0.3f, SSIM %0.3f, acc %0.4f, f1 %0.4f, RI %0.4f, purity %0.4f'
                    % (d_loss, g_loss, ps / self.conf.batch,
                       ss / self.conf.batch, acc, f1, RI, purity))

                if best_acc < acc:
                    best_acc = acc
                    best_f1, best_purity, best_RI = f1, purity, RI
                    self.save(epoch_num)
                    epochs_no_performance_gain = 0
                else:
                    epochs_no_performance_gain += 1
                    if epochs_no_performance_gain > 5:
                        print('stop since no improvement for %d epochs' %
                              epochs_no_performance_gain)
                        # print('acc, f1, purity, RI &%.4f &%.4f & &%.4f &%.4f'% (best_acc, best_f1, best_purity, best_RI))
                        print('acc, f1, purity, RI [%.4f, %.4f, %.4f, %.4f]' %
                              (best_acc, best_f1, best_purity, best_RI))
                        break

            epoch_num += 1
Example #15
    def test(self):
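        # With test_all, evaluate accuracy for every saved checkpoint and write
        # a CSV; otherwise test one checkpoint, saving misclassified images
        # (CIFAR-10 classes) and printing a per-class confusion table.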
        if self.conf.test_all:
            model_name = os.path.basename(self.conf.modeldir)
            if not os.path.exists('./test'):
                os.makedirs('./test')
            f = open(os.path.join('./test', model_name + '.csv'), "w+")
            print('testing all')
            latest_checkpoint_path = tf.train.latest_checkpoint(
                self.conf.modeldir)
            latest_checkpoint = int(latest_checkpoint_path.rsplit('-', 1)[1])
            checkpoint = 1 + self.conf.save_interval
            while checkpoint <= latest_checkpoint:
                self.reload(checkpoint)
                accuracies = []
                test_reader = H5DataLoader(
                    self.conf.data_dir + self.conf.test_data, False)
                while True:
                    inputs, labels = test_reader.next_batch(self.conf.batch)
                    if inputs is None or inputs.shape[0] < self.conf.batch:
                        break
                    feed_dict = {self.inputs: inputs, self.labels: labels}
                    accur = self.sess.run(self.accuracy_op,
                                          feed_dict=feed_dict)
                    accuracies.append(accur)
                f.write('%d, \t %f' %
                        (checkpoint, sum(accuracies) / len(accuracies)))
                print('%d, \t %f' %
                      (checkpoint, sum(accuracies) / len(accuracies)))
                checkpoint += self.conf.save_interval
            f.close()

        else:
            print('---->testing ', self.conf.test_step)
            if self.conf.test_step > 0:
                self.reload(self.conf.test_step)
            else:
                print("please set a reasonable test_step")
                return
            test_reader = H5DataLoader(
                self.conf.data_dir + self.conf.test_data, False)
            accuracies = []
            model_name = os.path.basename(self.conf.modeldir)
            test_dir = './test/' + model_name
            if not os.path.exists(test_dir):
                os.makedirs(test_dir)
            index = 0
            wrong_preds = []
            debug_stat = [[0] * 10 for _ in range(10)]
            while True:
                inputs, labels = test_reader.next_batch(self.conf.batch)
                if inputs is None or inputs.shape[0] < self.conf.batch:
                    break
                feed_dict = {self.inputs: inputs, self.labels: labels}
                accur, preds = self.sess.run(
                    [self.accuracy_op, self.decoded_preds],
                    feed_dict=feed_dict)
                for i in range(len(preds)):
                    if preds[i] != labels[i]:
                        img = inputs[i]
                        img[:, :, 0] = (img[:, :, 0] * 0.24703233 + 0.49139968) * 255
                        img[:, :, 1] = (img[:, :, 1] * 0.24348505 + 0.48215827) * 255
                        img[:, :, 2] = (img[:, :, 2] * 0.26158768 + 0.44653118) * 255
                        img = np.array(np.uint8(img))
                        pil_image = Image.fromarray(img)
                        pil_image.save(
                            os.path.join(test_dir,
                                         str(index) + '.png'))
                        debug_stat[int(labels[i])][int(preds[i])] += 1
                        wrong_preds.append({
                            index: {
                                'label': int(labels[i]),
                                'pred': int(preds[i])
                            }
                        })
                        index += 1
                accuracies.append(accur)
            label_list = [
                'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
                'horse', 'ship', 'truck'
            ]
            print('\t', end='')
            for item in label_list:
                print(item + '\t', end=' ')
            print("total \t percentage")
            for i in range(10):
                print(label_list[i], end='\t')
                for j in range(10):
                    print(debug_stat[i][j], end='\t')
                print("%d \t %f" % (sum(debug_stat[i]), sum(debug_stat[i]) /
                                    sum([sum(i) for i in debug_stat])))

            json.dump(wrong_preds,
                      open(os.path.join(test_dir,
                                        str(index) + '.json'), 'w+'),
                      sort_keys=True,
                      indent=4)
            # write the confusion counts to a separate file so the two dumps
            # do not overwrite each other
            json.dump(debug_stat,
                      open(os.path.join(test_dir,
                                        str(index) + '_stat.json'), 'w+'),
                      sort_keys=True,
                      indent=4)
            print('step: %d, accuracy %f' %
                  (self.conf.test_step, sum(accuracies) / len(accuracies)))
Example #16
    def train(self):
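        # CIFAR-style training with random flip/crop augmentation and a
        # scheduled learning rate fed through a placeholder; resumes from the
        # current global step.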
        def random_flip(image):
            return random.choice([image, np.fliplr(image)])

        def random_crop(image):
            pad_width = ((4, 4), (4, 4), (0, 0))
            image = np.lib.pad(image,
                               pad_width=pad_width,
                               mode='constant',
                               constant_values=0)
            start_h = random.randint(0, 8)
            start_w = random.randint(0, 8)
            return image[start_h:start_h + 32, start_w:start_w + 32]

        train_reader = H5DataLoader(self.conf.data_dir + self.conf.train_data)
        valid_reader = H5DataLoader(self.conf.data_dir + self.conf.valid_data)

        for epoch_num in range(
                tf.train.global_step(self.sess, self.global_step),
                self.conf.max_step + 1):
            if epoch_num and epoch_num % self.conf.test_interval == 0:
                inputs, labels = valid_reader.next_batch(self.conf.batch)
                feed_dict = {
                    self.inputs: inputs,
                    self.labels: labels,
                    self.learning_rate_placeholder:
                        self.learning_rate_schedule(epoch_num)
                }
                loss, summary = self.sess.run(
                    [self.loss_op, self.valid_summary], feed_dict=feed_dict)
                self.save_summary(summary, epoch_num)
                print('global step: %d; validation loss %f' % (epoch_num, loss))
            if epoch_num and epoch_num % self.conf.summary_interval == 0:
                inputs, labels = train_reader.next_batch(self.conf.batch)
                inputs = np.array(list(map(random_flip, inputs)))
                inputs = np.array(list(map(random_crop, inputs)))
                feed_dict = {
                    self.inputs: inputs,
                    self.labels: labels,
                    self.learning_rate_placeholder:
                        self.learning_rate_schedule(epoch_num)
                }
                loss, _, summary = self.sess.run(
                    [self.loss_op, self.train_op, self.train_summary],
                    feed_dict=feed_dict)
                self.save_summary(summary, epoch_num)
            else:
                inputs, labels = train_reader.next_batch(self.conf.batch)
                inputs = np.array(list(map(random_flip, inputs)))
                inputs = np.array(list(map(random_crop, inputs)))
                feed_dict = {
                    self.inputs: inputs,
                    self.labels: labels,
                    self.learning_rate_placeholder:
                        self.learning_rate_schedule(epoch_num)
                }
                loss, _ = self.sess.run([self.loss_op, self.train_op],
                                        feed_dict=feed_dict)
                print('global step: %d; training loss %f' % (epoch_num, loss))
            if epoch_num and epoch_num % self.conf.save_interval == 0:
                self.save(epoch_num)