Code Example #1
File: vnet3d_predict.py Project: neo-cc/LUNA16
def predict():
    # Path() drops a trailing slash when it normalizes the path, so the slice
    # file names are joined with the / operator instead of string concatenation.
    src_path = Path(ROOT_DIR) / "data/LIDC/LUNA16/segmentation/Image/3_98"
    mask_path = Path(ROOT_DIR) / "data/LIDC/LUNA16/segmentation/Mask/3_98"
    imges = []
    masks = []
    # load the 16 grayscale slices 0.bmp .. 15.bmp and their masks
    for z in range(16):
        img = cv2.imread(str(src_path / ("%d.bmp" % z)), cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(str(mask_path / ("%d.bmp" % z)), cv2.IMREAD_GRAYSCALE)
        imges.append(img)
        masks.append(mask)

    test_imges = np.array(imges)
    test_imges = np.reshape(test_imges, (16, 96, 96))

    test_masks = np.array(masks)
    test_masks = np.reshape(test_masks, (16, 96, 96))
    # 96x96x16 single-channel V-Net, dice loss, inference-only graph
    Vnet3d = Vnet3dModule(
        96,
        96,
        16,
        channels=1,
        costname=("dice coefficient", ),
        inference=True,
        model_path=Path(ROOT_DIR +
                        "/model/trained/segmeation/model/Vnet3d.pd-50000"))
    predict = Vnet3d.prediction(test_imges)
    # scale the 8-bit inputs to [0, 1] before writing the preview mosaics
    test_images = np.multiply(test_imges, 1.0 / 255.0)
    test_masks = np.multiply(test_masks, 1.0 / 255.0)
    # tile the 16 slices into 4x4 preview mosaics
    save_images(test_images, [4, 4], "test_src.bmp")
    save_images(test_masks, [4, 4], "test_mask.bmp")
    save_images(predict, [4, 4], "test_predict.bmp")
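save_images is a project helper that is not shown in these snippets. As a rough illustration only, here is a minimal standalone sketch of the kind of mosaic the [4, 4] argument suggests; tile_volume is a hypothetical name, not the project's save_images, and it assumes the slice values are already scaled to [0, 1].

import numpy as np
import cv2

def tile_volume(volume, grid, path):
    # volume: (depth, height, width) float array in [0, 1]; grid: (rows, cols)
    depth, h, w = volume.shape
    rows, cols = grid
    assert rows * cols == depth, "grid must cover every slice"
    mosaic = np.zeros((rows * h, cols * w), dtype=np.float32)
    for idx in range(depth):
        r, c = divmod(idx, cols)
        mosaic[r * h:(r + 1) * h, c * w:(c + 1) * w] = volume[idx]
    cv2.imwrite(path, (mosaic * 255.0).clip(0, 255).astype(np.uint8))

# usage mirroring the calls above: tile_volume(test_images, (4, 4), "test_src.bmp")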
Code Example #2
File: vnet3d_predict.py Project: larry-11/CT-GAN
def predict():
    npy_path = '/data/shanyx/larry/LUNA16-Lung-Nodule-Analysis-2016-Challenge/gen_stage4_iter4999.npy'
    # npy_path = '/data/shanyx/larry/LUNA16-Lung-Nodule-Analysis-2016-Challenge/inp_stage4_iter249.npy'
    # load the generated volume, scale it by 255, and crop the region of interest
    test = np.load(npy_path) * 255
    test = test[:, :, 5:20, 55:115, 40:100]
    # resample the 5-D (N, C, D, H, W) crop to the network's 16x96x96 input size
    test = torch.from_numpy(test)
    x_tmp = F.interpolate(test, (16, 96, 96),
                          mode='trilinear', align_corners=True)
    test_imges = x_tmp.numpy()
    test_imges = np.reshape(test_imges, (16, 96, 96))

    # 96x96x16 single-channel V-Net, dice loss, inference-only graph
    Vnet3d = Vnet3dModule(
        96,
        96,
        16,
        channels=1,
        costname=("dice coefficient", ),
        inference=True,
        model_path=
        "/data/shanyx/larry/LUNA16-Lung-Nodule-Analysis-2016-Challenge/segmeation/model/Vnet3d.pd-50000"
    )
    predict = Vnet3d.prediction(test_imges)
    print(predict.shape)
    # scale the input back to [0, 1] before writing the preview mosaics
    test_images = np.multiply(test_imges, 1.0 / 255.0)
    save_images(test_images, [4, 4], "test_src.bmp")
    save_images(predict, [4, 4], "test_predict.bmp")
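The interesting step in this example is the resampling: the generated volume is cropped and then resized to the network's 16x96x96 input with trilinear interpolation. A self-contained sketch of just that step, with a random array standing in for the .npy file, looks like this:

import numpy as np
import torch
import torch.nn.functional as F

vol = np.random.rand(1, 1, 32, 128, 128).astype(np.float32)   # placeholder (N, C, D, H, W) volume
vol = vol[:, :, 5:20, 55:115, 40:100]                          # crop the region of interest
resized = F.interpolate(torch.from_numpy(vol), size=(16, 96, 96),
                        mode='trilinear', align_corners=True)
test_imges = resized.numpy().reshape(16, 96, 96)               # drop the N and C axes
print(test_imges.shape)                                        # (16, 96, 96)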
Code Example #3
def predict():
    src_path = "G:\Data\LIDC\LUNA16\segmentation\Image\\3_98\\"
    mask_path = "G:\Data\LIDC\LUNA16\segmentation\Mask\\3_98\\"
    imges = []
    masks = []
    for z in range(16):
        img = cv2.imread(src_path + str(z) + ".bmp", cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(mask_path + str(z) + ".bmp", cv2.IMREAD_GRAYSCALE)
        imges.append(img)
        masks.append(mask)

    test_imges = np.array(imges)
    test_imges = np.reshape(test_imges, (16, 96, 96))

    test_masks = np.array(masks)
    test_masks = np.reshape(test_masks, (16, 96, 96))
    Vnet3d = Vnet3dModule(96, 96, 16, channels=1, costname=("dice coefficient",), inference=True,
                          model_path="log\segmeation\model\Vnet3d.pd-50000")
    predict = Vnet3d.prediction(test_imges)
    test_images = np.multiply(test_imges, 1.0 / 255.0)
    test_masks = np.multiply(test_masks, 1.0 / 255.0)
    save_images(test_images, [4, 4], "test_src.bmp")
    save_images(test_masks, [4, 4], "test_mask.bmp")
    save_images(predict, [4, 4], "test_predict.bmp")
Code Example #4
    def train(self, train_images, train_labels, model_path, logs_path, learning_rate,
              dropout_conv=0.8, train_epochs=5, batch_size=1):
        if not os.path.exists(logs_path):
            os.makedirs(logs_path)
        if not os.path.exists(logs_path + "model/"):
            os.makedirs(logs_path + "model/")
        #model_path = logs_path + "model\\" + model_path # this doesn't make sense.
        model_path = logs_path + "model/"
        train_op = tf.train.AdamOptimizer(self.lr).minimize(self.cost)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver(tf.global_variables(scope=None), max_to_keep=10) # changed from tf.all_variables()

        tf.summary.scalar("loss", self.cost)
        tf.summary.scalar("accuracy", self.accuracy)
        merged_summary_op = tf.summary.merge_all()
        sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
        summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
        sess.run(init)

        DISPLAY_STEP = 1
        index_in_epoch = 0

        train_epochs = train_images.shape[0] * train_epochs
        for i in range(train_epochs):
            # get new batch
            batch_xs_path, batch_ys_path, index_in_epoch = _next_batch(train_images, train_labels, batch_size,
                                                                       index_in_epoch)
            batch_xs = np.empty((len(batch_xs_path), self.image_depth, self.image_height, self.image_width,
                                 self.channels))
            batch_ys = np.empty((len(batch_ys_path), self.image_depth, self.image_height, self.image_width,
                                 self.channels))
            for num in range(len(batch_xs_path)):
                index = 0
                for _ in os.listdir(batch_xs_path[num][0]):
                    image = cv2.imread(batch_xs_path[num][0] + "/" + str(index) + ".bmp", cv2.IMREAD_GRAYSCALE)
                    label = cv2.imread(batch_ys_path[num][0] + "/" + str(index) + ".bmp", cv2.IMREAD_GRAYSCALE)

                    batch_xs[num, index, :, :, :] = np.reshape(image, (self.image_height, self.image_width,
                                                                       self.channels))
                    batch_ys[num, index, :, :, :] = np.reshape(label, (self.image_height, self.image_width,
                                                                       self.channels))
                    index += 1
            # Extracting images and labels from given data
            batch_xs = batch_xs.astype(np.float64)  # np.float was removed in NumPy 1.24
            batch_ys = batch_ys.astype(np.float64)
            # Normalize from [0:255] => [0.0:1.0]
            batch_xs = np.multiply(batch_xs, 1.0 / 255.0)
            batch_ys = np.multiply(batch_ys, 1.0 / 255.0)
            # check progress on every 1st,2nd,...,10th,20th,...,100th... step
            if i % DISPLAY_STEP == 0 or (i + 1) == train_epochs:
                train_loss, train_accuracy = sess.run([self.cost, self.accuracy],
                                                      feed_dict={self.X: batch_xs,
                                                                 self.Y_gt: batch_ys,
                                                                 self.lr: learning_rate,
                                                                 self.phase: 1,
                                                                 self.drop: dropout_conv})
                print('epochs %d Training_loss ,Training_accuracy => %.5f,%.5f ' % (i, train_loss, train_accuracy))

                pred = sess.run(self.Y_pred, feed_dict={self.X: batch_xs,
                                                        self.Y_gt: batch_ys,
                                                        self.phase: 1,
                                                        self.drop: 1})

                gt = np.reshape(batch_xs[0], (self.image_depth, self.image_height, self.image_width))

                gt = gt.astype(np.float32)
                save_images(gt, [8, 8], path=logs_path + 'src_%d_epoch.png' % (i)) # changed from [4 4]

                gt = np.reshape(batch_ys[0], (self.image_depth, self.image_height, self.image_width))
                gt = gt.astype(np.float32)
                save_images(gt, [8, 8], path=logs_path + 'gt_%d_epoch.png' % (i)) # changed from [4 4]

                result = np.reshape(pred[0], (self.image_depth, self.image_height, self.image_width))
                result = result.astype(np.float32)
                save_images(result, [8, 8], path=logs_path + 'predict_%d_epoch.png' % (i)) # changed from [4 4]

                save_path = saver.save(sess, model_path, global_step=i)
                print("Model saved in file:", save_path)
                if i % (DISPLAY_STEP * 10) == 0 and i:
                    DISPLAY_STEP *= 10

                    # train on batch
            _, summary = sess.run([train_op, merged_summary_op], feed_dict={self.X: batch_xs,
                                                                            self.Y_gt: batch_ys,
                                                                            self.lr: learning_rate,
                                                                            self.phase: 1,
                                                                            self.drop: dropout_conv})
            summary_writer.add_summary(summary, i)
        summary_writer.close()

        save_path = saver.save(sess, model_path)
        print("Model saved in file:", save_path)
Code Example #5
    def train(self, train_images, train_labels, model_path, logs_path, learning_rate,
              dropout_conv=0.8, train_epochs=5, batch_size=1, showwind=[6, 8]):
        num_sample = 1
        if not os.path.exists(logs_path):
            os.makedirs(logs_path)
        if not os.path.exists(logs_path + "model\\"):
            os.makedirs(logs_path + "model\\")
        model_path = logs_path + "model\\" + model_path
        train_op = tf.train.AdamOptimizer(self.lr).minimize(self.cost)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)  # changed from tf.all_variables()

        tf.summary.scalar("loss", self.cost)
        tf.summary.scalar("accuracy", self.accuracy)
        merged_summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
        self.sess.run(init)

        ckpt = tf.train.get_checkpoint_state(logs_path + "model\\")
        if ckpt and ckpt.model_checkpoint_path:
            print('Checkpoint file: {}'.format(ckpt.model_checkpoint_path))
            saver.restore(self.sess, ckpt.model_checkpoint_path)

        DISPLAY_STEP = 1
        num_sample_index_in_epoch = 0
        index_in_epoch = 0
        train_epochs = train_images.shape[0] * train_epochs
        for i in range(train_epochs):
            # Extracting num_sample images and labels from given data
            if i % num_sample == 0 or i == 0:
                subbatch_xs, subbatch_ys, num_sample_index_in_epoch = self.__loadnumtraindata(train_images,
                                                                                              train_labels, num_sample,
                                                                                              num_sample_index_in_epoch)
            # get new batch
            batch_xs, batch_ys, index_in_epoch = _next_batch(subbatch_xs, subbatch_ys, batch_size, index_in_epoch)
            # convert label to one hot type
            batch_ys_onehot = convert_to_one_hot(batch_ys, self.numclass)
            batch_ys_onehot = np.reshape(batch_ys_onehot,
                                         (-1, self.image_depth, self.image_height, self.image_width, self.numclass))
            # check progress on every 1st,2nd,...,10th,20th,...,100th... step
            if i % DISPLAY_STEP == 0 or (i + 1) == train_epochs:
                train_loss, train_accuracy = self.sess.run([self.cost, self.accuracy],
                                                           feed_dict={self.X: batch_xs,
                                                                      self.Y_gt: batch_ys_onehot,
                                                                      self.lr: learning_rate,
                                                                      self.phase: 1,
                                                                      self.drop: dropout_conv})
                print('epochs %d training_loss ,Training_accuracy => %.5f,%.5f ' % (i, train_loss, train_accuracy))

                pred_arg = self.sess.run(self.Y_pred_arg, feed_dict={self.X: batch_xs,
                                                                     self.Y_gt: batch_ys_onehot,
                                                                     self.phase: 1,
                                                                     self.drop: 1})
                batch_ys_tmp = np.argmax(batch_ys_onehot, axis=-1)
                gt = np.reshape(batch_ys_tmp[0], (self.image_depth, self.image_height, self.image_width))
                gt = gt.astype(np.float64)  # np.float was removed in NumPy 1.24
                save_images(gt, showwind, path=logs_path + 'gt_%d_epoch.png' % (i), pixelvalue=85)

                result = np.reshape(pred_arg[0], (self.image_depth, self.image_height, self.image_width))
                result = result.astype(np.float64)
                save_images(result, showwind, path=logs_path + 'predict_%d_epoch.png' % (i), pixelvalue=85)

                save_path = saver.save(self.sess, model_path, global_step=i)
                print("Model saved in file:", save_path)
                if i % (DISPLAY_STEP * 10) == 0 and i:
                    DISPLAY_STEP *= 10

                    # train on batch
            _, summary = self.sess.run([train_op, merged_summary_op], feed_dict={self.X: batch_xs,
                                                                                 self.Y_gt: batch_ys_onehot,
                                                                                 self.lr: learning_rate,
                                                                                 self.phase: 1,
                                                                                 self.drop: dropout_conv})
            summary_writer.add_summary(summary, i)
        summary_writer.close()

        save_path = saver.save(self.sess, model_path)
        print("Model saved in file:", save_path)
Code Example #6
    def train(self, train_images, train_labels, model_path, logs_path, learning_rate,
              dropout_conv=0.8, train_epochs=5, batch_size=1, showwind=[6, 8]):
        num_sample = 100
        if not os.path.exists(logs_path):
            os.makedirs(logs_path)
        if not os.path.exists(logs_path + "model\\"):
            os.makedirs(logs_path + "model\\")
        model_path = logs_path + "model\\" + model_path
        train_op = tf.train.AdamOptimizer(self.lr).minimize(self.cost)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)  # changed from tf.all_variables()

        tf.summary.scalar("loss", self.cost)
        tf.summary.scalar("accuracy", self.accuracy)
        merged_summary_op = tf.summary.merge_all()
        sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
        summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
        sess.run(init)

        if os.path.exists(model_path):
            saver.restore(sess, model_path)

        # load data and show result param
        DISPLAY_STEP = 1
        num_sample_index_in_epoch = 0
        index_in_epoch = 0

        train_epochs = train_images.shape[0] * train_epochs

        subbatch_xs = np.empty((num_sample, self.image_depth, self.image_height, self.image_width, self.channels))
        subbatch_ys = np.empty((num_sample, self.image_depth, self.image_height, self.image_width, self.numclass))

        for i in range(train_epochs):
            # Extracting num_sample images and labels from given data
            if i % num_sample == 0 or i == 0:
                batch_xs_path, batch_ys_path, num_sample_index_in_epoch = _next_batch(train_images, train_labels,
                                                                                      num_sample,
                                                                                      num_sample_index_in_epoch)
                for num in range(len(batch_xs_path)):
                    image = np.load(batch_xs_path[num])
                    label = np.load(batch_ys_path[num])
                    # prepare 3 model output
                    batch_ys1 = label.copy()
                    batch_ys1[label == 1.] = 1.
                    batch_ys1[label != 1.] = 0.
                    batch_ys2 = label.copy()
                    batch_ys2[label == 2.] = 1.
                    batch_ys2[label != 2.] = 0.
                    batch_ys3 = label.copy()
                    batch_ys3[label == 4.] = 1.
                    batch_ys3[label != 4.] = 0.
                    subbatch_xs[num, :, :, :, :] = np.reshape(image,
                                                              (self.image_depth, self.image_height, self.image_width,
                                                               self.channels))
                    label_ys = np.empty((self.image_depth, self.image_height, self.image_width, self.numclass))
                    label_ys[:, :, :, 0] = batch_ys1
                    label_ys[:, :, :, 1] = batch_ys2
                    label_ys[:, :, :, 2] = batch_ys3
                    subbatch_ys[num, :, :, :, :] = np.reshape(label_ys,
                                                              (self.image_depth, self.image_height, self.image_width,
                                                               self.numclass))

                subbatch_xs = subbatch_xs.astype(np.float64)  # np.float was removed in NumPy 1.24
                subbatch_ys = subbatch_ys.astype(np.float64)
            # get new batch
            batch_xs, batch_ys, index_in_epoch = _next_batch(subbatch_xs, subbatch_ys, batch_size, index_in_epoch)
            # check progress on every 1st,2nd,...,10th,20th,...,100th... step
            if i % DISPLAY_STEP == 0 or (i + 1) == train_epochs:
                train_loss, train_accuracy = sess.run(
                    [self.cost, self.accuracy], feed_dict={self.X: batch_xs,
                                                           self.Y_gt: batch_ys,
                                                           self.lr: learning_rate,
                                                           self.phase: 1,
                                                           self.drop: dropout_conv})
                print('epochs %d training_loss ,training_accuracy => %.5f,%.5f ' % (i, train_loss, train_accuracy))

                pred = sess.run(self.Y_pred, feed_dict={self.X: batch_xs,
                                                        self.Y_gt: batch_ys,
                                                        self.phase: 1,
                                                        self.drop: 1})
                gt = np.reshape(batch_ys[0], (self.image_depth, self.image_height, self.image_width, self.numclass))
                gt1 = gt[:, :, :, 0]
                gt1 = np.reshape(gt1, (self.image_depth, self.image_height, self.image_width))
                gt1 = gt1.astype(np.float64)
                save_images(gt1, showwindow, path=logs_path + 'gt1_%d_epoch.png' % i)
                gt2 = gt[:, :, :, 1]
                gt2 = np.reshape(gt2, (self.image_depth, self.image_height, self.image_width))
                gt2 = gt2.astype(np.float64)
                save_images(gt2, showwindow, path=logs_path + 'gt2_%d_epoch.png' % i)
                gt3 = gt[:, :, :, 2]
                gt3 = np.reshape(gt3, (self.image_depth, self.image_height, self.image_width))
                gt3 = gt3.astype(np.float64)
                save_images(gt3, showwindow, path=logs_path + 'gt3_%d_epoch.png' % i)

                result = np.reshape(pred[0], (self.image_depth, self.image_height, self.image_width, self.numclass))
                result1 = result[:, :, :, 0]
                result1 = np.reshape(result1, (self.image_depth, self.image_height, self.image_width))
                result1 = result1.astype(np.float64)
                save_images(result1, showwindow, path=logs_path + 'predict1_%d_epoch.png' % i)
                result2 = result[:, :, :, 1]
                result2 = np.reshape(result2, (self.image_depth, self.image_height, self.image_width))
                result2 = result2.astype(np.float64)
                save_images(result2, showwindow, path=logs_path + 'predict2_%d_epoch.png' % i)
                result3 = result[:, :, :, 2]
                result3 = np.reshape(result3, (self.image_depth, self.image_height, self.image_width))
                result3 = result3.astype(np.float64)
                save_images(result3, showwindow, path=logs_path + 'predict3_%d_epoch.png' % i)

                save_path = saver.save(sess, model_path, global_step=i)
                print("Model saved in file:", save_path)
                if i % (DISPLAY_STEP * 10) == 0 and i:
                    DISPLAY_STEP *= 10

                    # train on batch
            _, summary = sess.run([train_op, merged_summary_op], feed_dict={self.X: batch_xs,
                                                                            self.Y_gt: batch_ys,
                                                                            self.lr: learning_rate,
                                                                            self.phase: 1,
                                                                            self.drop: dropout_conv})
            summary_writer.add_summary(summary, i)
        summary_writer.close()

        save_path = saver.save(sess, model_path)
        print("Model saved in file:", save_path)
Code Example #7
File: model_vnet3d.py Project: kirangpcet/KRCC
    def train(self,
              train_images,
              train_labels,
              model_path,
              logs_path,
              learning_rate,
              dropout_conv=0.8,
              train_epochs=5,
              batch_size=1,
              imagenum=[4, 8]):
        if not os.path.exists(logs_path):
            os.makedirs(logs_path)
        if not os.path.exists(logs_path + "model\\"):
            os.makedirs(logs_path + "model\\")
        model_path = logs_path + "model\\" + model_path
        train_op = tf.train.AdamOptimizer(self.lr).minimize(self.cost)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)  # changed from tf.all_variables()

        tf.summary.scalar("loss", self.cost)
        tf.summary.scalar("accuracy", self.accuracy)
        merged_summary_op = tf.summary.merge_all()
        sess = tf.InteractiveSession(config=tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False))
        summary_writer = tf.summary.FileWriter(logs_path,
                                               graph=tf.get_default_graph())
        sess.run(init)
        saver.restore(
            sess,
            "E:\\junqiangchen\\project\\KiTS19Challege\\log\\segmeation\\model\\Vnet3d.pd-100000"
        )

        DISPLAY_STEP = 1
        index_in_epoch = 0

        train_epochs = train_images.shape[0] * train_epochs
        for i in range(train_epochs):
            # get new batch
            batch_xs_path, batch_ys_path, index_in_epoch = _next_batch(
                train_images, train_labels, batch_size, index_in_epoch)
            batch_xs = np.empty(
                (len(batch_xs_path), self.image_depth, self.image_height,
                 self.image_width, self.channels))
            batch_ys = np.empty(
                (len(batch_ys_path), self.image_depth, self.image_height,
                 self.image_width, self.channels))
            for num in range(len(batch_xs_path)):
                image = np.load(batch_xs_path[num])
                label = np.load(batch_ys_path[num])
                batch_xs[num, :, :, :, :] = np.reshape(
                    image, (self.image_depth, self.image_height,
                            self.image_width, self.channels))
                batch_ys[num, :, :, :, :] = np.reshape(
                    label, (self.image_depth, self.image_height,
                            self.image_width, self.channels))
            # Extracting images and labels from given data
            batch_xs = batch_xs.astype(np.float64)  # np.float was removed in NumPy 1.24
            batch_ys = batch_ys.astype(np.float64)
            # Normalize from [0:255] => [0.0:1.0]
            batch_xs = np.multiply(batch_xs, 1.0 / 255.0)
            batch_ys = np.multiply(batch_ys, 1.0 / 255.0)
            # check progress on every 1st,2nd,...,10th,20th,...,100th... step
            if i % DISPLAY_STEP == 0 or (i + 1) == train_epochs:
                train_loss, train_accuracy = sess.run(
                    [self.cost, self.accuracy],
                    feed_dict={
                        self.X: batch_xs,
                        self.Y_gt: batch_ys,
                        self.lr: learning_rate,
                        self.phase: 1,
                        self.drop: dropout_conv
                    })
                print(
                    'epochs %d training_loss ,Training_accuracy => %.5f,%.5f '
                    % (i, train_loss, train_accuracy))

                pred = sess.run(self.Y_pred,
                                feed_dict={
                                    self.X: batch_xs,
                                    self.Y_gt: batch_ys,
                                    self.phase: 1,
                                    self.drop: 1
                                })

                gt_src = np.reshape(
                    batch_xs[0],
                    (self.image_depth, self.image_height, self.image_width))
                gt_src = gt_src.astype(np.float32)
                save_images(gt_src,
                            imagenum,
                            path=logs_path + 'src_%d_epoch.png' % (i))

                gt = np.reshape(
                    batch_ys[0],
                    (self.image_depth, self.image_height, self.image_width))
                gt = gt.astype(np.float32)
                save_images(gt,
                            imagenum,
                            path=logs_path + 'gt_%d_epoch.png' % (i))

                result = np.reshape(
                    pred[0],
                    (self.image_depth, self.image_height, self.image_width))
                result = result.astype(np.float32)
                save_images(result,
                            imagenum,
                            path=logs_path + 'predict_%d_epoch.png' % (i))

                save_path = saver.save(sess, model_path, global_step=i)
                print("Model saved in file:", save_path)
                if i % (DISPLAY_STEP * 10) == 0 and i:
                    DISPLAY_STEP *= 10

                    # train on batch
            _, summary = sess.run(
                [train_op, merged_summary_op],
                feed_dict={
                    self.X: batch_xs,
                    self.Y_gt: batch_ys,
                    self.lr: learning_rate,
                    self.phase: 1,
                    self.drop: dropout_conv
                })
            summary_writer.add_summary(summary, i)
        summary_writer.close()

        save_path = saver.save(sess, model_path)
        print("Model saved in file:", save_path)
Code Example #8
    def train(self,
              train_images,
              pos_data,
              model_path,
              logs_path,
              learning_rate,
              dropout_conv=0.8,
              train_epochs=10,
              batch_size=1):
        if not os.path.exists(logs_path):
            os.makedirs(logs_path)
        if not os.path.exists(logs_path + "model/"):
            os.makedirs(logs_path + "model/")
        model_path = logs_path + "model/" + model_path
        train_op = tf.train.AdamOptimizer(self.lr).minimize(self.cost)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)  # changed from tf.all_variables()

        tf.summary.scalar("loss", self.cost)
        tf.summary.scalar("accuracy", self.accuracy)
        merged_summary_op = tf.summary.merge_all()
        sess = tf.InteractiveSession(config=tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False))
        summary_writer = tf.summary.FileWriter(logs_path,
                                               graph=tf.get_default_graph())
        sess.run(init)

        DISPLAY_STEP = 1
        index_in_epoch = 0

        train_epochs = train_images.shape[0] * train_epochs
        for i in range(train_epochs):
            # get new batch
            batch_xs_path, batch_ys_path, index_in_epoch = _next_batch(
                train_images, batch_size, index_in_epoch)
            batch_xs = np.empty(
                (len(batch_xs_path), self.image_depth, self.image_height,
                 self.image_width, self.channels))
            batch_ys = np.empty(
                (len(batch_ys_path), self.image_depth, self.image_height,
                 self.image_width, self.channels))
            for num in range(len(batch_xs_path)):
                image, label = get_denoise_data(batch_xs_path[num], pos_data)
                batch_xs[num, :, :, :, :] = np.reshape(
                    image, (self.image_depth, self.image_height,
                            self.image_width, self.channels))
                batch_ys[num, :, :, :, :] = np.reshape(
                    label, (self.image_depth, self.image_height,
                            self.image_width, self.channels))
            # clip the label volume to [0.0, 1.0]; np.float was removed in NumPy 1.24
            batch_ys = np.clip(batch_ys, 0, 1).astype(np.float64)
            # this variant only writes previews and a checkpoint at the final step
            if (i + 1) == train_epochs:
                train_loss, train_accuracy = sess.run(
                    [self.cost, self.accuracy],
                    feed_dict={
                        self.X: batch_xs,
                        self.Y_gt: batch_ys,
                        self.lr: learning_rate,
                        self.phase: 1,
                        self.drop: dropout_conv
                    })
                print(
                    'epochs %d training_loss ,Training_accuracy => %.5f,%.5f '
                    % (i, train_loss, train_accuracy))

                pred = sess.run(self.Y_pred,
                                feed_dict={
                                    self.X: batch_xs,
                                    self.Y_gt: batch_ys,
                                    self.phase: 1,
                                    self.drop: 1
                                })

                gt = np.reshape(
                    batch_xs[0],
                    (self.image_depth, self.image_height, self.image_width))
                gt = gt.astype(np.float32)
                save_images(gt, [4, 4],
                            path=logs_path + 'src_%d_epoch.png' % (i))

                gt = np.reshape(
                    batch_ys[0],
                    (self.image_depth, self.image_height, self.image_width))
                gt = gt.astype(np.float32)
                save_images(gt, [4, 4],
                            path=logs_path + 'gt_%d_epoch.png' % (i))

                result = np.reshape(
                    pred[0],
                    (self.image_depth, self.image_height, self.image_width))
                result = result.astype(np.float32)
                save_images(result, [4, 4],
                            path=logs_path + 'predict_%d_epoch.png' % (i))

                save_path = saver.save(sess, model_path, global_step=i)
                print("Model saved in file:", save_path)
                if i % (DISPLAY_STEP * 10) == 0 and i:
                    DISPLAY_STEP *= 10

            # train on batch
            print("[Batch %d/%d]" % (i, train_epochs))
            _, summary = sess.run(
                [train_op, merged_summary_op],
                feed_dict={
                    self.X: batch_xs,
                    self.Y_gt: batch_ys,
                    self.lr: learning_rate,
                    self.phase: 1,
                    self.drop: dropout_conv
                })
            summary_writer.add_summary(summary, i)
        summary_writer.close()

        save_path = saver.save(sess, model_path)
        print("Model saved in file:", save_path)