def valid_one_epoch(self, sess, data_loader, data_num, epoch, step, view_every=0):
    '''Run one validation pass of the flow-inpainting model and print metrics.

    For each loader draw, the optical flow of the last frame pair is computed
    for both `images` and `images_da`; the `images` flow is normalised and
    gridded together with `mask`, then the losses and error metrics are
    fetched from the graph. Draws the loader flags as invalid, or whose
    gridded mask is empty, are skipped and excluded from every mean.

    Args:
        sess: session used for every `sess.run` call.
        data_loader: loader whose `get_next_sequence_valid()` is called once
            per iteration.
        data_num: the loop runs over `range(1, data_num)`, i.e. it draws
            `data_num - 1` sequences.
        epoch: epoch index, used only in the printed output.
        step: running step counter, incremented once per draw.
        view_every: dump visualisations every `view_every` draws; 0 (the
            default, equivalent to the former hard-coded `if False:`)
            disables visualisation.

    Returns:
        The updated `step`.

    Raises:
        ValueError: if `cfgs.seq_frames < 2`, since no frame pair exists to
            compute a flow from.
    '''
    if cfgs.seq_frames < 2:
        raise ValueError('cfgs.seq_frames must be >= 2, got %r' % (cfgs.seq_frames,))
    last = cfgs.seq_frames - 1  # index of the last frame in a sequence

    total_loss, total_im_warp_loss, total_flow_loss = 0, 0, 0
    total_l1, total_im_l1, total_l2, total_im_l2, total_psnr = 0, 0, 0, 0, 0
    t0 = time.time()
    # Number of samples actually evaluated; skipped draws are not counted, so
    # every mean below is divided by this rather than by the loop index.
    t_count = 0
    for count in range(1, data_num):
        step += 1
        images, images_da, mask, fn, flag, rect_param = data_loader.get_next_sequence_valid()
        if not flag:
            # Can't find previous data for this sequence.
            continue

        # Only the flow of the last frame pair is used below; the former
        # per-frame loops overwrote `flow`/`flow_da` on every iteration, so
        # the earlier pairs were computed and then discarded.
        flow = sess.run(self.flow_tensor, feed_dict={
            self.flow_img0: images[last],
            self.flow_img1: images[last - 1]})
        # NOTE(review): `im`/`last_im` come from `images_da` and are also fed
        # to flow_img0/flow_img1 in the loss run below (as in the original,
        # where the da loop left them set). Confirm that is intended.
        im, last_im = images_da[last], images_da[last - 1]
        flow_da = sess.run(self.flow_tensor, feed_dict={
            self.flow_img0: im,
            self.flow_img1: last_im})

        normal_flow, normal_max_v = normal_data(flow)
        inpt_input, flag_ = concat_data(normal_flow, mask, cfgs.grid)
        if not flag_:
            # After gridding, the mask area is 0.
            print(fn)
            continue
        sd_mask = sdm(rect_param, cfgs.gamma)
        t_count += 1

        (loss, im_warp_loss, flow_loss,
         l_inpt_p,
         l1_e, l2_e, im_l1_e, im_l2_e, psnr,
         l_ori_im, l_warped_im,
         inpt_warped_im, inpt_flow, pred_flow) = sess.run(
            [self.loss, self.im_warp_loss, self.flow_loss_,
             self.local_patch_inpt_flow,
             self.l1_e, self.l2_e, self.w_l1_e, self.w_l2_e, self.w_psnr,
             self.local_ori_im, self.local_warped_im,
             self.inpt_warped_im, self.inpt_pred_flow, self.pred_complete_flow],
            feed_dict={self.flow_img1: last_im,
                       self.flow_img0: im,
                       # NOTE(review): frame 1 is used regardless of
                       # cfgs.seq_frames; it equals the last frame only when
                       # seq_frames == 2.
                       self.ori_cur_im: images[1],
                       self.ori_flow: flow,
                       self.max_v: normal_max_v,
                       self.inpt_data: inpt_input,
                       self.sd_mask: sd_mask,
                       self.rect_param: rect_param})

        if view_every > 0 and count % view_every == 0:
            self.view(np.expand_dims(fn, 0), inpt_warped_im,
                      images_da[cfgs.seq_frames - 2], images_da[cfgs.seq_frames - 1],
                      step, f_path='valid')
            self.view_patch_one(l_ori_im[0], fn, step, 'l_ori', f_path='valid')
            self.view_patch_one(l_warped_im[0], fn, step, 'l_warped', f_path='valid')
            self.view_flow_patch_one(l_inpt_p[0], fn, step, 'l_inpt_p', f_path='valid')
            self.view_flow_one(flow[0], fn, step, f_path='valid')
            self.view_flow_one(inpt_flow[0], fn, step, '_inpt', 'valid')
            self.view_flow_one(flow_da[0], fn, step, '_da', 'valid')
            self.view_flow_one(pred_flow[0], fn, step, 'complete', f_path='valid')

        # Accumulate losses and metrics for evaluated samples only.
        total_loss += loss
        total_im_warp_loss += im_warp_loss
        total_flow_loss += flow_loss
        total_l1 += l1_e
        total_im_l1 += im_l1_e
        total_l2 += l2_e
        total_im_l2 += im_l2_e
        total_psnr += psnr

        # t_count >= 1 here, so the running means are always defined.
        line = ('Valid epoch %2d\t step = %4d\t count = %4d\t '
                'm_imW_loss = %.2f\t m_f_loss = %.2f\t '
                'l1_e = %.3f\t im_l1_e = %.3f\t l2_e = %.3f\t im_l2_e = %.3f\t '
                'psnr = %.4f\t' % (
                    epoch, step, count,
                    total_im_warp_loss / t_count, total_flow_loss / t_count,
                    total_l1 / t_count, total_im_l1 / t_count,
                    total_l2 / t_count, total_im_l2 / t_count,
                    total_psnr / t_count))
        utils.clear_line(len(line))
        print('\r' + line, end='')

    # End of epoch. Guard the divisor so an epoch in which no sample was
    # evaluated prints zero means instead of raising ZeroDivisionError.
    denom = max(t_count, 1)
    print('\nepoch %5d\t learning_rate = %g\t mean_loss = %.4f\t m_imW_loss = %.4f\t m_f_loss = %.4f\t '
          % (epoch, cfgs.inpt_lr, total_loss / denom,
             total_im_warp_loss / denom, total_flow_loss / denom))
    print('Take time %3.1f' % (time.time() - t0))
    return step
def valid_one_epoch(self, sess, data_loader, data_num, epoch, step, view_every=1):
    '''Run one validation pass of the flow-inpainting model and print the loss.

    For each loader draw, the optical flow of the last frame pair is computed
    for both `images` and `images_da`; the `images` flow is normalised and
    gridded together with `mask`, then the loss and predictions are fetched
    from the graph. Draws the loader flags as invalid, or whose gridded mask
    is empty, are skipped (their file name is printed) and excluded from
    every mean.

    Args:
        sess: session used for every `sess.run` call.
        data_loader: loader whose `get_next_sequence()` is called once per
            iteration. NOTE(review): confirm this is the intended method for
            validation data.
        data_num: the loop runs over `range(1, data_num)`, i.e. it draws
            `data_num - 1` sequences.
        epoch: epoch index, used only in the printed output.
        step: running step counter, incremented once per draw.
        view_every: dump visualisations every `view_every` draws; the default
            1 matches the former hard-coded `if True:`, and 0 disables it.

    Returns:
        The updated `step`.

    Raises:
        ValueError: if `cfgs.seq_frames < 2`, since no frame pair exists to
            compute a flow from.
    '''
    if cfgs.seq_frames < 2:
        raise ValueError('cfgs.seq_frames must be >= 2, got %r' % (cfgs.seq_frames,))
    last = cfgs.seq_frames - 1  # index of the last frame in a sequence

    sum_acc, sum_acc_iou, sum_acc_ellip, total_loss = 0, 0, 0, 0
    t0 = time.time()
    # These attributes are reset here and nothing in this method assigns them
    # afterwards, so the accuracy sums stay 0 unless a callee changes them.
    # NOTE(review): looks vestigial; kept to preserve instance state.
    self.ellip_acc, self.accu, self.accu_iou = 0, 0, 0
    # Number of samples actually evaluated; skipped draws are not counted, so
    # the means below are divided by this rather than by the loop index.
    t_count = 0
    for count in range(1, data_num):
        step += 1
        images, images_da, mask, fn, flag, _mask_sum = data_loader.get_next_sequence()
        if not flag:
            # Can't find previous data for this sequence.
            print(fn)
            continue

        # Only the flow of the last frame pair is used below; the former
        # per-frame loops overwrote `flow`/`flow_da` on every iteration, so
        # the earlier pairs were computed and then discarded.
        flow = sess.run(self.flow_tensor, feed_dict={
            self.flow_img0: images[last],
            self.flow_img1: images[last - 1]})
        # NOTE(review): `im`/`last_im` come from `images_da` and are also fed
        # to flow_img0/flow_img1 in the loss run below (as in the original,
        # where the da loop left them set). Confirm that is intended.
        im, last_im = images_da[last], images_da[last - 1]
        flow_da = sess.run(self.flow_tensor, feed_dict={
            self.flow_img0: im,
            self.flow_img1: last_im})

        normal_flow, normal_max_v = normal_data(flow)
        inpt_input, flag_ = concat_data(normal_flow, mask, cfgs.grid)
        if not flag_:
            # After gridding, the mask area is 0.
            print(fn)
            continue
        t_count += 1

        loss, inpt_warped_im, inpt_flow, pred_flow = sess.run(
            [self.loss, self.inpt_warped_im, self.inpt_pred_flow, self.pred_complete_flow],
            feed_dict={self.flow_img1: last_im,
                       self.flow_img0: im,
                       self.ori_flow: flow,
                       self.max_v: normal_max_v,
                       self.inpt_data: inpt_input})

        if view_every > 0 and count % view_every == 0:
            self.view(np.expand_dims(fn, 0), inpt_warped_im,
                      images_da[cfgs.seq_frames - 2], images_da[cfgs.seq_frames - 1],
                      step, f_path='valid')
            self.view_flow_one(flow[0], fn, step, f_path='valid')
            self.view_flow_one(inpt_flow[0], fn, step, '_inpt', 'valid')
            self.view_flow_one(flow_da[0], fn, step, '_da', 'valid')
            self.view_flow_one(pred_flow[0], fn, step, 'complete', f_path='valid')

        # Accuracy bookkeeping (see the NOTE above: these add the current
        # attribute values, which this method itself only ever sets to 0).
        self.ellip_acc = 0
        sum_acc += self.accu
        sum_acc_iou += self.accu_iou
        sum_acc_ellip += self.ellip_acc

        total_loss += loss

        # Wall time per loop iteration, skipped draws included.
        time_per_batch = (time.time() - t0) / count

        # t_count >= 1 here, so the running mean is always defined.
        line = ('Valid epoch %2d\t lr = %g\t step = %4d\t count = %4d\t '
                'loss = %.4f\t m_loss=%.4f\t max_v = %.2f\t fn = %s\t time = %.2f' % (
                    epoch, cfgs.inpt_lr, step, count, loss, total_loss / t_count,
                    normal_max_v, fn, time_per_batch))
        utils.clear_line(len(line))
        print('\r' + line, end='')

    # End of epoch. Guard the divisor so an epoch in which no sample was
    # evaluated (or data_num <= 1) prints zero means instead of raising
    # ZeroDivisionError.
    denom = max(t_count, 1)
    print('\nepoch %5d\t learning_rate = %g\t mean_loss = %.4f\t train_acc = %.2f%%\t '
          'train_iou_acc = %.2f%%\t train_ellip_acc = %.2f'
          % (epoch, cfgs.inpt_lr, total_loss / denom,
             sum_acc / denom, sum_acc_iou / denom, sum_acc_ellip / denom))
    print('Take time %3.1f' % (time.time() - t0))
    return step