Exemplo n.º 1
0
    def valid_one_epoch(self, sess, data_loader, data_num, epoch, step, *,
                        view_every=0):
        '''
        Run one validation pass and print averaged losses / error metrics.

        For every sequence fetched with ``data_loader.get_next_sequence_valid()``
        the optical flow is computed with ``self.flow_tensor`` for both the
        ``images`` and the ``images_da`` frames, the flow is normalised and
        combined with the mask, and the inpainting graph is evaluated.
        All means are taken over the samples that were actually evaluated;
        skipped samples (missing previous data, or an empty mask after
        gridding) do not dilute them.

        Args:
            sess: TensorFlow-style session; every graph evaluation goes
                through ``sess.run``.
            data_loader: validation data loader providing
                ``get_next_sequence_valid()``.
            data_num: number of data. NOTE(review): the loop runs over
                ``range(1, data_num)``, i.e. ``data_num - 1`` fetches —
                confirm this is intended.
            epoch: epoch index, used only in the printed report.
            step: global step counter, incremented once per fetch.
            view_every: if > 0, dump debug visualisations every
                ``view_every`` loop iterations. The default 0 disables them,
                which was the previous hard-coded behaviour.

        Returns:
            The updated ``step``.
        '''
        total_loss, total_im_warp_loss, total_flow_loss = 0, 0, 0
        total_l1, total_im_l1, total_l2, total_im_l2, total_psnr = 0, 0, 0, 0, 0
        t0 = time.time()

        # Number of samples that actually reached the loss computation.
        # Every mean is taken over this, not over the loop counter.
        t_count = 0

        for count in range(1, data_num):
            step += 1

            images, images_da, mask, fn, flag, rect_param = \
                data_loader.get_next_sequence_valid()
            if not flag:
                # Can't find prev data.
                continue

            # Optical flow on the original frames. Each iteration overwrites
            # ``flow``, so only the last frame pair is used below.
            for frame in range(1, cfgs.seq_frames):
                im, last_im = images[frame], images[frame - 1]
                flow = sess.run(self.flow_tensor,
                                feed_dict={
                                    self.flow_img0: im,
                                    self.flow_img1: last_im
                                })
            # Optical flow on the ``images_da`` frames.
            # NOTE(review): this loop also rebinds ``im`` / ``last_im``, so the
            # main ``sess.run`` below is fed the last ``images_da`` pair rather
            # than the ``images`` pair. Confirm this is intended.
            for frame in range(1, cfgs.seq_frames):
                im, last_im = images_da[frame], images_da[frame - 1]
                flow_da = sess.run(self.flow_tensor,
                                   feed_dict={
                                       self.flow_img0: im,
                                       self.flow_img1: last_im
                                   })

            normal_flow, normal_max_v = normal_data(flow)
            inpt_input, flag_ = concat_data(normal_flow, mask, cfgs.grid)
            if not flag_:
                # After gridding, the mask area is 0: nothing to evaluate.
                print(fn)
                continue
            # Only needed for samples that are actually evaluated.
            sd_mask = sdm(rect_param, cfgs.gamma)

            t_count += 1
            (loss, im_warp_loss, flow_loss,
             l_inpt_p,
             l1_e, l2_e, im_l1_e, im_l2_e, psnr,
             l_ori_im, l_warped_im,
             inpt_warped_im, inpt_flow, pred_flow) = sess.run(
                 [self.loss, self.im_warp_loss, self.flow_loss_,
                  self.local_patch_inpt_flow,
                  self.l1_e, self.l2_e, self.w_l1_e, self.w_l2_e, self.w_psnr,
                  self.local_ori_im, self.local_warped_im,
                  self.inpt_warped_im, self.inpt_pred_flow,
                  self.pred_complete_flow],
                 feed_dict={self.flow_img1: last_im,
                            self.flow_img0: im,
                            # NOTE(review): hard-coded frame index 1 — confirm
                            # it matches cfgs.seq_frames.
                            self.ori_cur_im: images[1],
                            self.ori_flow: flow,
                            self.max_v: normal_max_v,
                            self.inpt_data: inpt_input,
                            self.sd_mask: sd_mask,
                            self.rect_param: rect_param})

            if view_every > 0 and count % view_every == 0:
                self.view(np.expand_dims(fn, 0),
                          inpt_warped_im,
                          images_da[cfgs.seq_frames - 2],
                          images_da[cfgs.seq_frames - 1],
                          step,
                          f_path='valid')

                self.view_patch_one(l_ori_im[0],
                                    fn,
                                    step,
                                    'l_ori',
                                    f_path='valid')
                self.view_patch_one(l_warped_im[0],
                                    fn,
                                    step,
                                    'l_warped',
                                    f_path='valid')
                self.view_flow_patch_one(l_inpt_p[0],
                                         fn,
                                         step,
                                         'l_inpt_p',
                                         f_path='valid')
                self.view_flow_one(flow[0], fn, step, f_path='valid')
                self.view_flow_one(inpt_flow[0], fn, step, '_inpt', 'valid')
                self.view_flow_one(flow_da[0], fn, step, '_da', 'valid')
                self.view_flow_one(pred_flow[0],
                                   fn,
                                   step,
                                   'complete',
                                   f_path='valid')

            # Accumulate losses and metrics.
            total_loss += loss
            total_im_warp_loss += im_warp_loss
            total_flow_loss += flow_loss
            total_l1 += l1_e
            total_im_l1 += im_l1_e
            total_l2 += l2_e
            total_im_l2 += im_l2_e
            total_psnr += psnr

            # Progress line. t_count >= 1 here, so division is safe.
            line = 'Valid epoch %2d\t step = %4d\t count = %4d\t m_imW_loss = %.2f\t m_f_loss = %.2f\t l1_e = %.3f\t im_l1_e = %.3f\t l2_e = %.3f\t im_l2_e = %.3f\t psnr = %.4f\t' % (
                epoch, step, count, (total_im_warp_loss / t_count),
                (total_flow_loss / t_count), (total_l1 / t_count),
                (total_im_l1 / t_count), (total_l2 / t_count),
                (total_im_l2 / t_count), (total_psnr / t_count))

            utils.clear_line(len(line))
            print('\r' + line, end='')

        # End of epoch. Avoid ZeroDivisionError when every sample was skipped
        # or the loop did not run at all (data_num <= 1).
        n_eval = max(t_count, 1)
        print(
            '\nepoch %5d\t learning_rate = %g\t mean_loss = %.4f\t m_imW_loss = %.4f\t m_f_loss = %.4f\t '
            % (epoch, cfgs.inpt_lr, (total_loss / n_eval),
               (total_im_warp_loss / n_eval), (total_flow_loss / n_eval)))
        print('Take time %3.1f' % (time.time() - t0))

        return step
Exemplo n.º 2
0
    def valid_one_epoch(self, sess, data_loader, data_num, epoch, step, *,
                        view_every=1):
        '''
        Run one validation pass and print the mean loss.

        For every sequence fetched with ``data_loader.get_next_sequence()``
        the optical flow is computed with ``self.flow_tensor`` for both the
        ``images`` and the ``images_da`` frames, the flow is normalised and
        combined with the mask, and the inpainting graph's loss and outputs
        are evaluated. Means are taken over the samples that were actually
        evaluated; skipped samples (missing previous data, or an empty mask
        after gridding) do not dilute them.

        Args:
            sess: TensorFlow-style session; every graph evaluation goes
                through ``sess.run``.
            data_loader: data loader providing ``get_next_sequence()``.
            data_num: number of data. NOTE(review): the loop runs over
                ``range(1, data_num)``, i.e. ``data_num - 1`` fetches —
                confirm this is intended.
            epoch: epoch index, used only in the printed report.
            step: global step counter, incremented once per fetch.
            view_every: if > 0, dump debug visualisations every
                ``view_every`` loop iterations. The default 1 visualises every
                iteration, which was the previous hard-coded behaviour.
                0 disables visualisation.

        Returns:
            The updated ``step``.

        Side effects:
            Resets ``self.ellip_acc``, ``self.accu`` and ``self.accu_iou`` to 0.
        '''
        sum_acc, sum_acc_iou, sum_acc_ellip, total_loss = 0, 0, 0, 0
        t0 = time.time()
        # These attributes are reset here and never updated elsewhere in this
        # method, so the accuracy sums stay 0. They are kept in case other
        # code reads them.
        self.ellip_acc, self.accu, self.accu_iou = 0, 0, 0

        # Number of samples that actually reached the loss computation.
        t_count = 0

        for count in range(1, data_num):
            step += 1

            images, images_da, mask, fn, flag, mask_sum = \
                data_loader.get_next_sequence()
            if not flag:
                # Can't find prev data.
                print(fn)
                continue

            # Optical flow on the original frames. Each iteration overwrites
            # ``flow``, so only the last frame pair is used below.
            for frame in range(1, cfgs.seq_frames):
                im, last_im = images[frame], images[frame - 1]
                flow = sess.run(self.flow_tensor,
                                feed_dict={
                                    self.flow_img0: im,
                                    self.flow_img1: last_im
                                })
            # Optical flow on the ``images_da`` frames.
            # NOTE(review): this loop also rebinds ``im`` / ``last_im``, so the
            # main ``sess.run`` below is fed the last ``images_da`` pair rather
            # than the ``images`` pair. Confirm this is intended.
            for frame in range(1, cfgs.seq_frames):
                im, last_im = images_da[frame], images_da[frame - 1]
                flow_da = sess.run(self.flow_tensor,
                                   feed_dict={
                                       self.flow_img0: im,
                                       self.flow_img1: last_im
                                   })

            normal_flow, normal_max_v = normal_data(flow)
            inpt_input, flag_ = concat_data(normal_flow, mask, cfgs.grid)
            if not flag_:
                # After gridding, the mask area is 0: nothing to evaluate.
                print(fn)
                continue

            t_count += 1
            loss, inpt_warped_im, inpt_flow, pred_flow = sess.run(
                [self.loss,
                 self.inpt_warped_im, self.inpt_pred_flow,
                 self.pred_complete_flow],
                feed_dict={self.flow_img1: last_im,
                           self.flow_img0: im,
                           self.ori_flow: flow,
                           self.max_v: normal_max_v,
                           self.inpt_data: inpt_input})

            if view_every > 0 and count % view_every == 0:
                self.view(np.expand_dims(fn, 0),
                          inpt_warped_im,
                          images_da[cfgs.seq_frames - 2],
                          images_da[cfgs.seq_frames - 1],
                          step,
                          f_path='valid')

                self.view_flow_one(flow[0], fn, step, f_path='valid')
                self.view_flow_one(inpt_flow[0], fn, step, '_inpt', 'valid')
                self.view_flow_one(flow_da[0], fn, step, '_da', 'valid')
                self.view_flow_one(pred_flow[0],
                                   fn,
                                   step,
                                   'complete',
                                   f_path='valid')

            # Accumulate accuracies (currently always 0, see above) and loss.
            self.ellip_acc = 0
            sum_acc += self.accu
            sum_acc_iou += self.accu_iou
            sum_acc_ellip += self.ellip_acc
            total_loss += loss

            # Wall time per loop iteration, including skipped iterations.
            time_per_batch = (time.time() - t0) / count

            # Progress line. t_count >= 1 here, so division is safe.
            line = 'Valid epoch %2d\t lr = %g\t step = %4d\t count = %4d\t loss = %.4f\t m_loss=%.4f\t  max_v = %.2f\t fn = %s\t time = %.2f' % (
                epoch, cfgs.inpt_lr, step, count, loss,
                (total_loss / t_count), normal_max_v, fn, time_per_batch)
            utils.clear_line(len(line))
            print('\r' + line, end='')

        # End of epoch. Avoid ZeroDivisionError when every sample was skipped
        # or the loop did not run at all (data_num <= 1).
        n_eval = max(t_count, 1)
        print(
            '\nepoch %5d\t learning_rate = %g\t mean_loss = %.4f\t train_acc = %.2f%%\t train_iou_acc = %.2f%%\t train_ellip_acc = %.2f'
            % (epoch, cfgs.inpt_lr, (total_loss / n_eval), (sum_acc / n_eval),
               (sum_acc_iou / n_eval), (sum_acc_ellip / n_eval)))
        print('Take time %3.1f' % (time.time() - t0))

        return step