def test(test_loader, model, configs):
    """Run the model over the test set, print per-stage diagnostics and save
    side-by-side target/prediction segmentation images.

    :param test_loader: data loader yielding (origin_imgs, resized_imgs,
        org_ball_pos_xy, global_ball_pos_xy, event_class, target_seg)
    :param model: multi-task network returning (pred_ball_global,
        pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _)
    :param configs: config namespace; uses .device, .tasks, .input_size,
        .thresh_ball_pos_mask
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    # Deliberately kept in train mode: using eval mode makes performance worse
    # (original author's observation).
    model.train()
    with torch.no_grad():
        start_time = time.time()
        for batch_idx, (origin_imgs, resized_imgs, org_ball_pos_xy, global_ball_pos_xy,
                        event_class, target_seg) in enumerate(tqdm(test_loader)):
            data_time.update(time.time() - start_time)
            batch_size = resized_imgs.size(0)
            target_seg = target_seg.to(configs.device, non_blocking=True)
            resized_imgs = resized_imgs.to(configs.device, non_blocking=True).float()
            # compute output; origin_imgs is only needed on-device when the
            # local ball-detection stage is enabled
            if 'local' in configs.tasks:
                origin_imgs = origin_imgs.to(configs.device, non_blocking=True).float()
                pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
                    origin_imgs, resized_imgs, org_ball_pos_xy, global_ball_pos_xy,
                    event_class, target_seg)
            else:
                pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
                    None, resized_imgs, org_ball_pos_xy, global_ball_pos_xy,
                    event_class, target_seg)
            print('total_loss: {}'.format(total_loss.item()))

            # Transfer output to cpu
            pred_ball_global = pred_ball_global.cpu().numpy()
            global_ball_pos_xy = global_ball_pos_xy.numpy()
            if pred_ball_local is not None:
                pred_ball_local = pred_ball_local.cpu().numpy()
                # Ground truth of the local stage
                local_ball_pos_xy = local_ball_pos_xy.cpu().numpy()
            if pred_events is not None:
                pred_events = pred_events.cpu().numpy()
            if pred_seg is not None:
                pred_seg = pred_seg.cpu().numpy()
                target_seg = target_seg.cpu().numpy()
            org_ball_pos_xy = org_ball_pos_xy.numpy()

            seg_thresh = 0.5
            event_thresh = 0.5
            events_idx_to_names = {
                0: 'bounce',
                1: 'net',
                2: 'empty'
            }

            fig, axes = plt.subplots(nrows=batch_size, ncols=2, figsize=(10, 5))
            plt.tight_layout()
            # FIX: ravel() returns a new flat array; the original discarded it,
            # so axes[2 * sample_idx] indexed the 2-D grid incorrectly for
            # batch_size > 1. Rebind so the flat indexing below is valid.
            axes = axes.ravel()
            saved_dir = '../../docs/test_output_full'
            os.makedirs(saved_dir, exist_ok=True)

            for sample_idx in range(batch_size):
                w, h = configs.input_size
                # Get target
                sample_org_ball_pos_xy = org_ball_pos_xy[sample_idx]
                sample_global_ball_pos_xy = global_ball_pos_xy[sample_idx]  # Target

                # Process the global stage: the heatmap is stored as the x-axis
                # scores (first w entries) followed by the y-axis scores —
                # assumed from the [:w] / [w:] split; TODO confirm upstream.
                sample_pred_ball_global = pred_ball_global[sample_idx]
                sample_pred_ball_global[sample_pred_ball_global < configs.thresh_ball_pos_mask] = 0.
                sample_pred_ball_global_x = np.argmax(sample_pred_ball_global[:w])
                sample_pred_ball_global_y = np.argmax(sample_pred_ball_global[w:])
                print('Global stage: (x, y) - org: ({}, {}), gt = ({}, {}), prediction = ({}, {})'.format(
                    sample_org_ball_pos_xy[0], sample_org_ball_pos_xy[1],
                    sample_global_ball_pos_xy[0], sample_global_ball_pos_xy[1],
                    sample_pred_ball_global_x, sample_pred_ball_global_y))

                # Process event stage
                if pred_events is not None:
                    sample_target_event = event_class[sample_idx].item()
                    # FIX: np.int was removed in NumPy 1.24 — use builtin int.
                    sample_pred_event = (pred_events[sample_idx] > event_thresh).astype(int)
                    print('Event stage: gt = {}, prediction: {}'.format(
                        sample_target_event, pred_events[sample_idx]))

                if pred_seg is not None:
                    # CHW -> HWC for imshow / per-channel stats
                    sample_target_seg = target_seg[sample_idx].transpose(1, 2, 0)
                    sample_pred_seg = pred_seg[sample_idx].transpose(1, 2, 0)
                    print('Segmentation: Shape sample_target_seg: {}, sample_pred_seg: {}'.format(
                        sample_target_seg.shape, sample_pred_seg.shape))
                    print('Segmentation: Max values sample_target_seg: {}, sample_pred_seg: {}'.format(
                        sample_target_seg.max(), sample_pred_seg.max()))
                    print('Before cast Segmentation sample_target_seg R: {}, G: {}, B: {}'.format(
                        sample_target_seg[:, :, 0].sum(), sample_target_seg[:, :, 1].sum(),
                        sample_target_seg[:, :, 2].sum()))
                    print('Before cast Segmentation sample_pred_seg R: {}, G: {}, B: {}'.format(
                        sample_pred_seg[:, :, 0].sum(), sample_pred_seg[:, :, 1].sum(),
                        sample_pred_seg[:, :, 2].sum()))
                    sample_target_seg = sample_target_seg.astype(int)
                    sample_pred_seg = (sample_pred_seg > seg_thresh).astype(int)
                    print('After Segmentation sample_target_seg R: {}, G: {}, B: {}'.format(
                        sample_target_seg[:, :, 0].sum(), sample_target_seg[:, :, 1].sum(),
                        sample_target_seg[:, :, 2].sum()))
                    print('After Segmentation sample_pred_seg R: {}, G: {}, B: {}'.format(
                        sample_pred_seg[:, :, 0].sum(), sample_pred_seg[:, :, 1].sum(),
                        sample_pred_seg[:, :, 2].sum()))

                    axes[2 * sample_idx].imshow(sample_target_seg * 255)
                    axes[2 * sample_idx + 1].imshow(sample_pred_seg * 255)
                    # title
                    target_title = 'target seg'
                    pred_title = 'pred seg'
                    if pred_events is not None:
                        target_title += ', event: {}'.format(events_idx_to_names[sample_target_event])
                        pred_title += ', is bounce: {}, is net: {}'.format(
                            sample_pred_event[0], sample_pred_event[1])
                    axes[2 * sample_idx].set_title(target_title)
                    axes[2 * sample_idx + 1].set_title(pred_title)
                    plt.savefig(os.path.join(
                        saved_dir, 'batch_idx_{}_sample_idx_{}.jpg'.format(batch_idx, sample_idx)))
            # FIX: release the figure — the original leaked one figure per batch.
            plt.close(fig)

            batch_time.update(time.time() - start_time)
            start_time = time.time()
    print('Done testing')
def eval(_run, _log):
    """Evaluate plane-relation networks (angle / contact / segmentation) over the
    test split, run the plane-parameter optimization, dump visualizations, and
    print geometry / relation-classification metrics.

    :param _run: sacred run object; provides .config and ._id for output paths
    :param _log: sacred logger (unused in the body)
    """
    cfg = edict(_run.config)
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)
    # NOTE(review): two CUDA devices are required — angle_net/seg_net run on
    # cuda:0, contact_net on cuda:1. (The original also computed a CPU fallback
    # device that was immediately overwritten; removed as dead code.)
    device = torch.device("cuda:0")
    device1 = torch.device("cuda:1")

    checkpoint_dir = os.path.join('experiments/predict', str(_run._id), 'checkpoints')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    angle_net = AngleNet(cfg.model)
    contact_net = ContactNet(cfg.model)
    seg_net = SegNet(cfg.model)
    # Resume flags are the literal string 'None' when unset.
    if not cfg.resume_angle == 'None':
        model_dict = torch.load(cfg.resume_angle)
        angle_net.load_state_dict(model_dict)
    if not cfg.resume_contact == 'None':
        model_dict = torch.load(cfg.resume_contact)
        contact_net.load_state_dict(model_dict)
    if not cfg.resume_seg == 'None':
        model_dict = torch.load(cfg.resume_seg)
        seg_net.load_state_dict(model_dict)

    # load nets into gpu
    if cfg.num_gpus > 1 and torch.cuda.is_available():
        angle_net = torch.nn.DataParallel(angle_net)
        contact_net = torch.nn.DataParallel(contact_net)
        seg_net = torch.nn.DataParallel(seg_net)
    angle_net.to(device)
    contact_net.to(device1)
    seg_net.to(device)

    if cfg.input_method == "planercnn":
        val_dataset = PlaneDataset(cfg.dataset, split='test', random=False, evaluation=True)
    elif cfg.input_method == "planeae":
        val_dataset = PlaneDatasetAE(cfg.dataset, split='test', random=False, evaluation=True)
    else:
        print('input method ' + cfg.input_method + ' not supported!')
        exit()
    val_loader = data.DataLoader(val_dataset, batch_size=1, shuffle=False,
                                 num_workers=cfg.dataset.num_workers)

    # evaluation toggles
    use_gt_relation = False
    write_relation = False
    assess_seg = False
    depth_only = True

    angle_net.eval()
    contact_net.eval()
    seg_net.eval()

    angle_accuracies = AverageMeter()
    contact_accuracies = AverageMeter()
    contact_ious = AverageMeter()
    stat_parallel = np.zeros(4)  # classification precision recall
    stat_ortho = np.zeros(4)
    stat_contact = np.zeros(4)
    # Each triple is [planercnn baseline, angle-only optimization, full optimization].
    normal_errors = [AverageMeter(), AverageMeter(), AverageMeter()]
    normal_diff_errors = [AverageMeter(), AverageMeter(), AverageMeter()]
    depth_errors = [AverageMeter(), AverageMeter(), AverageMeter()]
    offset_errors = [AverageMeter(), AverageMeter(), AverageMeter()]
    contact_depth_errors = [AverageMeter(), AverageMeter(), AverageMeter()]

    with torch.no_grad():
        for iter_idx, sample in enumerate(val_loader):
            if iter_idx == 100:  # evaluate at most 100 scenes
                break
            sceneIndex = sample["sceneIndex"].item()
            imageIndex = sample["imageIndex"].item()
            save_path = os.path.join('experiments/predict', str(_run._id)) + f'/results/{iter_idx}/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            # batch_size is fixed to 1, so index [0] unwraps the batch dim.
            image = sample["image"][0]  # 3x224x224
            tg_planes = sample["tg_planes"][0]
            pd_planes = sample["pd_planes"][0]  # planenum x 4
            matched_single = sample["matched_single"][0]
            pd_masks_small = sample["pd_masks_small"][0]  # planenum x 224 x 224
            pd_masks_small_c = sample["pd_masks_small_c"][0]
            segment_depths_small = sample["segment_depths_small"][0]
            gt_angle = sample["gt_angle"][0]
            gt_contact = sample["gt_contact"][0]
            gt_contactline = sample["gt_contactline"][0]
            matched_pair = sample["matched_pair"][0]
            matched_contact = sample["matched_contact"][0]
            planepair_index = sample["planepair_index"][0]
            ori_image = sample["ori_image"][0]
            camera = sample["camera"][0]
            ori_pd_points = sample["ori_pd_points"][0]
            pd_masks = sample["pd_masks"][0]  # planenum x 480 x 640
            ransac_masks = sample["ransac_masks"][0]
            ransac_planes = sample["ransac_planes"][0]
            sensor_depth = sample["sensor_depth"][0]
            tg_masks = sample["tg_masks"][0]

            image = image.to(device)
            pd_masks_small = pd_masks_small.to(device)
            pd_masks_small_c = pd_masks_small_c.to(device)  # NOTE(review): unused below
            segment_depths_small = segment_depths_small.to(device)
            gt_angle = gt_angle.to(device)
            gt_contact = gt_contact.to(device1)

            # build the input to network: one stacked tensor per plane pair
            input_tensor = []
            planepair_index = planepair_index.numpy()
            planepair_num = planepair_index.shape[0]
            pd_planes = pd_planes.numpy()
            tg_planes = tg_planes.numpy()
            ori_image = ori_image.numpy().astype(np.uint8)
            ori_magnitude = get_magnitude(ori_image)
            # FIX: np.bool was removed in NumPy 1.24 — use builtin bool.
            matched_single = matched_single.numpy().astype(bool)
            matched_pair = matched_pair.numpy().astype(bool)
            matched_contact = matched_contact.numpy().astype(bool)
            for ppindex in planepair_index:
                p, q = ppindex
                pm = pd_masks_small[p:p + 1]
                qm = pd_masks_small[q:q + 1]
                pdepth = segment_depths_small[p:p + 1]
                qdepth = segment_depths_small[q:q + 1]
                # |n_p . n_q| broadcast to a constant map — normal similarity cue
                dot_map = torch.full_like(
                    pm, np.abs(np.dot(pd_planes[p, :3], pd_planes[q, :3])), device=device)
                if cfg.model.input_channel == 8:
                    input_tensor.append(torch.cat((image, pm, qm, pdepth, qdepth, dot_map), dim=0))
                elif cfg.model.input_channel == 7:
                    input_tensor.append(torch.cat((image, pm, qm, pdepth, qdepth), dim=0))
                elif cfg.model.input_channel == 5:
                    input_tensor.append(torch.cat((image, pm, qm), dim=0))
                else:
                    input_tensor.append(torch.cat((pm, qm), dim=0))
            input_tensor = torch.stack(input_tensor)
            print(iter_idx, input_tensor.size())

            # inference — on CUDA OOM (RuntimeError), retry in two halves.
            # FIX: narrowed the original bare except, which also swallowed
            # KeyboardInterrupt and unrelated errors.
            try:
                angle_prob = angle_net(input_tensor)
            except RuntimeError:
                half_num = int(planepair_num / 2.0 + 0.5)
                angle_prob0 = angle_net(input_tensor[0:half_num, :, :, :])
                angle_prob1 = angle_net(input_tensor[half_num:planepair_num, :, :, :])
                angle_prob = torch.cat((angle_prob0, angle_prob1), dim=0)
            input_tensor = input_tensor.to(device1)
            try:
                contact_prob, contactline_prob = contact_net(input_tensor)
            except RuntimeError:
                half_num = int(planepair_num / 2.0 + 0.5)
                contact_prob0, contactline_prob0 = contact_net(input_tensor[0:half_num, :, :, :])
                contact_prob1, contactline_prob1 = contact_net(input_tensor[half_num:planepair_num, :, :, :])
                contact_prob = torch.cat((contact_prob0, contact_prob1), dim=0)
                contactline_prob = torch.cat((contactline_prob0, contactline_prob1), dim=0)
            del input_tensor

            # relation assessment (only over pairs matched to ground truth)
            acc_angle, pred_angle = accuracy(angle_prob, gt_angle, angle=True)
            acc_contact, pred_contact = accuracy(contact_prob, gt_contact, angle=False)
            matched_pair_num = matched_pair.sum()
            if matched_pair_num > 0:
                acc_angle, acc_contact = np.sum(
                    acc_angle[matched_pair]) * 100 / matched_pair_num, np.sum(
                    acc_contact[matched_pair]) * 100 / matched_pair_num
                angle_accuracies.update(acc_angle, matched_pair_num)
                contact_accuracies.update(acc_contact, matched_pair_num)

            camera = camera.numpy()
            ranges2d = get_ranges2d(camera)
            contactline_prob = contactline_prob.cpu().numpy().squeeze()
            pd_masks = pd_masks.numpy()
            gt_contactline = gt_contactline.numpy()
            gt_angle = gt_angle.cpu().numpy()
            gt_contact = gt_contact.cpu().numpy()

            ## precision and recall
            if matched_pair_num > 0:
                stat_parallel += comp_precision_recall(
                    pred_angle[matched_pair] == 1, gt_angle[matched_pair] == 1)
                stat_ortho += comp_precision_recall(
                    pred_angle[matched_pair] == 0, gt_angle[matched_pair] == 0)
                stat_contact += comp_precision_recall(
                    pred_contact[matched_contact] == 1, gt_contact[matched_contact] == 1)
            iou_flag = (gt_contact == 1) & matched_contact
            if np.sum(iou_flag) > 0:
                pred_iou = comp_conrel_iou(gt_contact, gt_contactline, contactline_prob)
                contact_ious.update(np.mean(pred_iou[iou_flag]), np.sum(iou_flag))

            # collect predicted contact lines + per-pair area weights
            contact_list = []
            pair_areas = np.zeros((planepair_num, 1))
            contact_line_probs = []
            for i, ppindex in enumerate(planepair_index):
                p, q = ppindex
                pm = pd_masks[p]
                qm = pd_masks[q]
                # weight = product of normalized mask areas
                pair_areas[i] = pm.sum() / 640 * qm.sum() / 480
                if write_relation:
                    tmp_img = ori_image.copy()
                    tmp_img[pm > 0.5, 0] = 255
                    tmp_img[qm > 0.5, 2] = 255
                    if (gt_angle[i] if use_gt_relation else pred_angle[i]) == 1:
                        cv2.imwrite(f'{save_path}para_{i}.png', tmp_img)
                        if (gt_contact[i] if use_gt_relation else pred_contact[i]) == 1:
                            cv2.imwrite(f'{save_path}coplane_{i}.png', tmp_img)
                    elif (gt_angle[i] if use_gt_relation else pred_angle[i]) == 0:
                        cv2.imwrite(f'{save_path}ortho_{i}.png', tmp_img)
                if (gt_contact[i] if use_gt_relation else pred_contact[i]) == 0:
                    continue
                gt_mask = gt_contactline[i]
                re_mask = cv2.resize(contactline_prob[i], dsize=(640, 480))
                if use_gt_relation:
                    re_mask = gt_mask
                contact_line_probs.append(re_mask)
                # lower threshold for parallel (coplanar) pairs
                mask_thres = 0.5
                if pred_angle[i] == 1:
                    mask_thres = 0.25
                ylist, xlist = extract_line2d(re_mask, mask_thres)
                raydirs = ranges2d[ylist, xlist, :]
                contact_list.append([p, q, raydirs, i])
                if write_relation:
                    black_img = np.zeros((480, 640, 3), dtype=np.uint8)
                    black_img[:, :, 0] = re_mask * 255
                    black_img[:, :, 1] = gt_mask * 255
                    black_img[re_mask > mask_thres, 2] = 255
                    tmp_img[re_mask > 0.25, 1] = 255
                    cv2.imwrite(f'{save_path}contact_{i}.png',
                                np.concatenate([tmp_img, black_img], 1))
            contact_line_probs = np.asarray(contact_line_probs)

            # optimization: select pair lists/weights from GT or predictions
            if use_gt_relation:
                flag_para = (matched_pair & (gt_angle == 1))
                flag_ortho = (matched_pair & (gt_angle == 0))
                flag_contact = (matched_contact & (gt_contact == 1))
            else:
                flag_para, flag_ortho, flag_contact = pred_angle == 1, pred_angle == 0, pred_contact == 1
            para_list = planepair_index[flag_para, :]
            para_weight = pair_areas[flag_para, :]
            ortho_list = planepair_index[flag_ortho, :]
            ortho_weight = pair_areas[flag_ortho, :]
            contact_weight = pair_areas[flag_contact, :]
            coplane_list = planepair_index[flag_para & flag_contact, :]
            coplane_weight = pair_areas[flag_para & flag_contact, :]
            # NOTE(review): these normalizations emit NaN when a flag set is
            # empty (sum == 0) — presumably tolerated downstream; verify.
            para_weight /= np.sum(para_weight)
            ortho_weight /= np.sum(ortho_weight)
            contact_weight /= np.sum(contact_weight)
            coplane_weight /= np.sum(coplane_weight)

            ori_pd_points = ori_pd_points.numpy()
            point_list = []
            for i in range(pd_masks.shape[0]):
                point_list.append(ori_pd_points[pd_masks[i] > 0.5, :])

            # ---------------------------------
            # -- solve the plane parameters by optimization
            cv2.imwrite(f"{save_path}image.jpg", ori_image)
            sensor_depth = sensor_depth.numpy()
            visualize(ori_image, sensor_depth, [], camera, 'point_sensor', save_path,
                      depthonly=depth_only, sensor=True)
            ransac_masks = ransac_masks.numpy()
            ransac_planes = ransac_planes.numpy()
            p_depth_gt = visualize(ori_image, ransac_masks, ransac_planes, camera,
                                   'point_gt', save_path, depthonly=depth_only)
            seg_gt = blend_image_mask(ori_image, ransac_masks, thres=0.5)
            cv2.imwrite(f"{save_path}seg_gt.png", seg_gt)
            p_depth_planercnn = visualize(ori_image, pd_masks, pd_planes, camera,
                                          'point_planercnn', save_path, depthonly=depth_only)
            seg_planercnn = blend_image_mask(ori_image, pd_masks, thres=0.5)
            cv2.imwrite(f"{save_path}seg_planercnn.png", seg_planercnn)
            # angle-only term weights, then full term weights
            alpha = np.array([1., 0., 10., 1., 1., 0., 0.])
            re_planes_angle = plane_minimize(pd_planes, point_list, para_list, para_weight,
                                             ortho_list, ortho_weight, contact_list,
                                             contact_weight, coplane_list, coplane_weight,
                                             alpha)
            alpha = np.array([1., 10., 10., 1., 1., 1., 0.])
            re_planes = plane_minimize(pd_planes, point_list, para_list, para_weight,
                                       ortho_list, ortho_weight, contact_list,
                                       contact_weight, coplane_list, coplane_weight,
                                       alpha)

            # ---------------------------------
            # -- split the contact using optimized 3d parameters
            contact_split, line_equs, line_flags, along_line_mask = expand_masks1(
                re_planes, contact_list, contact_line_probs, ranges2d, pd_masks)
            seg_contact = blend_image_mask(ori_image, contact_split)
            cv2.imwrite(f"{save_path}seg_contact.png", seg_contact)
            cv2.imwrite(f"{save_path}seg_contact_line.png",
                        along_line_mask.astype(np.uint8) * 255)

            # refine segmentation by network and contact
            image = image.repeat(pd_masks_small.size(0), 1, 1, 1)
            contact_split_small = np.zeros((contact_split.shape[0], 224, 224), dtype=np.float32)
            for i in range(contact_split.shape[0]):
                contact_split_small[i] = cv2.resize(contact_split[i], dsize=(224, 224))
            contact_split_small = torch.cuda.FloatTensor(contact_split_small)
            input_tensor = torch.cat([
                image, pd_masks_small.unsqueeze(1), contact_split_small.unsqueeze(1)
            ], dim=1)
            seg_prob_small = seg_net(input_tensor)
            del input_tensor, pd_masks_small, contact_split_small
            seg_prob_small = seg_prob_small.cpu().numpy().squeeze()
            seg_prob = np.zeros((seg_prob_small.shape[0], 480, 640))
            for i, m in enumerate(seg_prob_small):
                seg_prob[i] = cv2.resize(m, dsize=(640, 480))
            seg_prob = clean_prob_mask(seg_prob)
            seg_refined = blend_image_mask(ori_image, seg_prob, thres=0.5)
            cv2.imwrite(f"{save_path}seg_refined.png", seg_refined)
            p_depth_all = visualize(ori_image, seg_prob, re_planes, camera,
                                    'point_result_ex', save_path, depthonly=depth_only)

            # --------------------------------
            # -- do the evaluation
            # 1. evaluate depth
            tg_masks = tg_masks.numpy()
            comp_depth_error(ori_image, camera, tg_masks[matched_single],
                             tg_planes[matched_single], pd_planes[matched_single],
                             re_planes_angle[matched_single], re_planes[matched_single],
                             depth_errors, save_path)
            # 2. evalute normal
            comp_parameter_error(tg_planes[matched_single], pd_planes[matched_single],
                                 re_planes_angle[matched_single], re_planes[matched_single],
                                 tg_masks[matched_single], normal_errors, offset_errors)
            # 3. contact depth consistency
            flag_contact = (gt_contact == 1)
            comp_contact_error(gt_contactline[flag_contact], planepair_index[flag_contact],
                               ranges2d, pd_planes, re_planes_angle, re_planes,
                               contact_depth_errors)

            ## sensor depth: zero out pixels with no plane coverage before dumping
            semantic_gt = ransac_masks.max(0)
            semantic_pd = pd_masks.max(0)
            semantic_re = seg_prob.max(0)
            p_depth_gt[semantic_gt < 0.5] = 0.
            p_depth_planercnn[semantic_pd < 0.5] = 0.
            p_depth_all[semantic_re < 0.5] = 0.
            cv2.imwrite(f"{save_path}depth_sensor.png", drawDepthImage(sensor_depth))
            cv2.imwrite(f"{save_path}depth_gt.png", drawDepthImage(p_depth_gt))
            cv2.imwrite(f"{save_path}depth_prcnn.png", drawDepthImage(p_depth_planercnn))
            cv2.imwrite(f"{save_path}depth_all.png", drawDepthImage(p_depth_all))

            # evaluate pairwise angular difference (area-weighted)
            diff_flag = matched_pair
            if np.sum(diff_flag) > 0:
                diff_areas = pair_areas[diff_flag, :].reshape(-1)
                normal_diff_errors[0].update(
                    eval_planepair_diff(pd_planes, tg_planes,
                                        planepair_index[diff_flag, :], diff_areas),
                    np.sum(diff_areas))
                normal_diff_errors[1].update(
                    eval_planepair_diff(re_planes_angle, tg_planes,
                                        planepair_index[diff_flag, :], diff_areas),
                    np.sum(diff_areas))
                normal_diff_errors[2].update(
                    eval_planepair_diff(re_planes, tg_planes,
                                        planepair_index[diff_flag, :], diff_areas),
                    np.sum(diff_areas))

    # final summary over all evaluated scenes
    print("-----------geometry accuracy--------------")
    print(
        f'normal error: planercnn {normal_errors[0].avg}, angle {normal_errors[1].avg}, all {normal_errors[2].avg}'
    )
    print(
        f'offset error: planercnn {offset_errors[0].avg}, angle {offset_errors[1].avg}, all {offset_errors[2].avg}'
    )
    print(
        f'depth error: planercnn {depth_errors[0].avg}, angle {depth_errors[1].avg}, all {depth_errors[2].avg}'
    )
    print(
        f'contact depth error: planercnn {contact_depth_errors[0].avg}, angle {contact_depth_errors[1].avg}, all {contact_depth_errors[2].avg}'
    )
    print(
        f'angular diff error: planercnn {normal_diff_errors[0].avg}, angle {normal_diff_errors[1].avg}, all {normal_diff_errors[2].avg}'
    )
    print('\n---------relation classificaiton----------')
    print(angle_accuracies.avg, contact_accuracies.avg)
    # NOTE(review): divisions below raise/NaN when a stat bucket is empty.
    precision, recall = stat_parallel[0] / stat_parallel[1], stat_parallel[
        2] / stat_parallel[3]
    print(
        f'parallel precision: {precision} recall: {recall} f1score: {2*(recall*precision)/(recall+precision)}'
    )
    precision, recall = stat_ortho[0] / stat_ortho[1], stat_ortho[2] / stat_ortho[3]
    print(
        f'ortho precision: {precision} recall: {recall} f1score: {2*(recall*precision)/(recall+precision)}'
    )
    precision, recall = stat_contact[0] / stat_contact[1], stat_contact[
        2] / stat_contact[3]
    print(
        f'contact precision: {precision} recall: {recall} f1score: {2*(recall*precision)/(recall+precision)}'
    )
    print(f'contact iou: {contact_ious.avg}')
def validate(val_loader, net, criterion, optim, epoch,
             calc_metrics=True,
             dump_assets=False,
             dump_all_images=False):
    """
    Run validation for one epoch

    :val_loader: data loader for validation
    :net: the network
    :criterion: loss fn
    :optimizer: optimizer
    :epoch: current epoch
    :calc_metrics: calculate validation score
    :dump_assets: dump attention prediction(s) images
    :dump_all_images: dump all images, not just N
    """
    dumper = ImageDumper(val_len=len(val_loader),
                         dump_all_images=dump_all_images,
                         dump_assets=dump_assets,
                         dump_for_auto_labelling=args.dump_for_auto_labelling,
                         dump_for_submission=args.dump_for_submission)

    net.eval()
    val_loss = AverageMeter()
    iou_acc = 0

    for val_idx, data in enumerate(val_loader):
        input_images, labels, img_names, _ = data

        # In auto-labelling / submission mode, skip images whose output file
        # already exists so an interrupted run can be resumed.
        if args.dump_for_auto_labelling or args.dump_for_submission:
            submit_fn = '{}.png'.format(img_names[0])
            if val_idx % 20 == 0:
                logx.msg(f'validating[Iter: {val_idx + 1} / {len(val_loader)}]')
            if os.path.exists(os.path.join(dumper.save_dir, submit_fn)):
                continue

        # Run network: eval_minibatch updates val_loss in place and returns the
        # per-image assets plus the IoU histogram contribution.
        assets, _iou_acc = \
            eval_minibatch(data, net, criterion, val_loss, calc_metrics,
                           args, val_idx)

        iou_acc += _iou_acc

        # NOTE(review): duplicate unpack of `data` — same values as the one at
        # the top of the loop; kept as-is (harmless).
        input_images, labels, img_names, _ = data

        dumper.dump({'gt_images': labels,
                     'input_images': input_images,
                     'img_names': img_names,
                     'assets': assets}, val_idx)

        # Short-circuit for smoke tests.
        if val_idx > 5 and args.test_mode:
            break

        if val_idx % 20 == 0:
            logx.msg(f'validating[Iter: {val_idx + 1} / {len(val_loader)}]')

    was_best = False
    if calc_metrics:
        was_best = eval_metrics(iou_acc, args, net, optim, val_loss, epoch)

    # Write out a summary html page and tensorboard image table
    if not args.dump_for_auto_labelling and not args.dump_for_submission:
        dumper.write_summaries(was_best)
def train_iae(trainloader, model, class_name, testloader, y_train, device, args):
    """
    model train function.

    :param trainloader:
    :param model:
    :param class_name:
    :param testloader:
    :param y_train: numpy array, sample normal/abnormal labels, [1 1 1 1 0 0] like,
        original sample size.
    :param device: cpu or gpu:0/1/...
    :param args:
    :return:
    """
    global_step = 0
    losses = AverageMeter()
    l2_losses = AverageMeter()
    svdd_losses = AverageMeter()
    start_time = time.time()
    epoch_time = AverageMeter()
    # SVDD state: radius R and center c; svdd_loss stays 0 during pretraining.
    svdd_loss = torch.tensor(0, device=device)
    R = torch.tensor(0, device=device)
    c = torch.randn(256, device=device)
    for epoch in range(1, args.epochs + 1):
        model.train()
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print('{:3d}/{:3d} ----- {:s} {:s}'.format(epoch, args.epochs,
                                                   time_string(), need_time))

        mse = nn.MSELoss(reduction='mean')  # default
        # step-decay learning-rate schedule
        lr = 0.1 / pow(2, np.floor(epoch / args.lr_schedule))
        logger.add_scalar(class_name + "/lr", lr, epoch)
        # NOTE(review): the optimizer is re-created every epoch, so Adam loses
        # its moment estimates each epoch, and Adam is constructed WITHOUT the
        # computed lr (uses its default) — preserved as-is; confirm intent.
        if args.optimizer == 'sgd':
            optimizer = optim.SGD(model.parameters(), lr=lr,
                                  weight_decay=args.weight_decay)
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(model.parameters(), eps=1e-7,
                                   weight_decay=args.weight_decay)
        else:
            # FIX: the original printed 'not implemented.' and then crashed
            # later with NameError on the undefined optimizer — fail fast.
            raise ValueError('optimizer {!r} not implemented.'.format(args.optimizer))

        for batch_idx, (inputs, _, _) in enumerate(trainloader):
            optimizer.zero_grad()
            inputs = inputs.to(device)
            reps, output = model(inputs)
            # After pretraining, add the soft-boundary Deep-SVDD penalty.
            if epoch > args.pretrain_epochs:
                dist = torch.sum((reps - c)**2, dim=1)
                scores = dist - R**2
                svdd_loss = args.para_lambda * (
                    R**2 + (1 / args.para_nu) *
                    torch.mean(torch.max(torch.zeros_like(scores), scores)))
            l2_loss = mse(inputs, output)
            loss = l2_loss + svdd_loss
            l2_losses.update(l2_loss.item(), 1)
            svdd_losses.update(svdd_loss.item(), 1)
            losses.update(loss.item(), 1)
            logger.add_scalar(class_name + '/l2_loss', l2_losses.avg, global_step)
            logger.add_scalar(class_name + '/svdd_loss', svdd_losses.avg, global_step)
            logger.add_scalar(class_name + '/loss', losses.avg, global_step)
            logger.add_scalar(class_name + '/R', R.data, global_step)
            global_step = global_step + 1
            loss.backward()
            optimizer.step()
            # Update hypersphere radius R on mini-batch distances
            if epoch > args.pretrain_epochs:
                R.data = torch.tensor(get_radius(dist, args.para_nu), device=device)

        # print losses
        print('Epoch: [{} | {}], loss: {:.4f}'.format(epoch, args.epochs, losses.avg))

        # log images (uses the last mini-batch of the epoch)
        if epoch % args.log_img_steps == 0:
            os.makedirs(os.path.join(RESULTS_DIR, class_name), exist_ok=True)
            fpath = os.path.join(RESULTS_DIR, class_name,
                                 'pretrain_epoch_' + str(epoch) + '.png')
            visualize(inputs, output, fpath, num=32)

        # test while training
        if epoch % args.log_auc_steps == 0:
            rep, losses_result = test(testloader, model, class_name, args,
                                      device, epoch)
            centroid = torch.mean(rep, dim=0, keepdim=True)
            # min-max normalize reconstruction losses into anomaly scores
            losses_result = losses_result - losses_result.min()
            losses_result = losses_result / (1e-8 + losses_result.max())
            scores = 1 - losses_result
            auroc_rec = roc_auc_score(y_train, scores)
            _, p = dec_loss_fun(rep, centroid)
            score_p = p[:, 0]
            auroc_dec = roc_auc_score(y_train, score_p)
            print("Epoch: [{} | {}], auroc_rec: {:.4f}; auroc_dec: {:.4f}".
                  format(epoch, args.epochs, auroc_rec, auroc_dec))
            logger.add_scalar(class_name + '/auroc_rec', auroc_rec, epoch)
            logger.add_scalar(class_name + '/auroc_dec', auroc_dec, epoch)

        # initial centroid c before pretrain finished
        if epoch == args.pretrain_epochs:
            rep, losses_result = test(testloader, model, class_name, args,
                                      device, epoch)
            c = update_center_c(rep)

        # time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
def train(train_loader, model, criterion, optimizer): ''' 模型训练 :param train_loader: :param model: :param criterion: :param optimizer: :return: ''' # 定义保存更新变量 data_time = AverageMeter() batch_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() end = time.time() ################# # train the model ################# model.train() # 训练每批数据,然后进行模型的训练 ## 定义bar 变量 bar = Bar('Processing', max=len(train_loader)) for batch_index, (inputs, targets) in enumerate(train_loader): data_time.update(time.time() - end) # move tensors to GPU if cuda is_available inputs, targets = inputs.to(device), targets.to(device) # 在进行反向传播之前,我们使用zero_grad方法清空梯度 optimizer.zero_grad() # 模型的预测 outputs = model(inputs) # 计算loss loss = criterion(outputs, targets) # backward pass: loss.backward() # perform as single optimization step (parameter update) optimizer.step() # 计算acc和变量更新 prec1, _ = accuracy(outputs.data, targets.data, topk=(1, 1)) losses.update(loss.item(), inputs.size(0)) top1.update(prec1.item(), inputs.size(0)) batch_time.update(time.time() - end) end = time.time() # plot progress ## 把主要的参数打包放进bar中 # plot progress bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f}'.format( batch=batch_index + 1, size=len(train_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, top1=top1.avg) bar.next() bar.finish() return (losses.avg, top1.avg)
def validate(val_loader, net, criterion, optim, curr_epoch, writer):
    """
    Runs the validation loop after each training epoch
    val_loader: Data loader for validation
    net: thet network
    criterion: loss fn
    optimizer: optimizer
    curr_epoch: current epoch
    writer: tensorboard writer
    return: val_avg for step function if required
    """
    net.eval()
    val_loss = AverageMeter()
    iou_acc = 0
    dump_images = []

    for val_idx, data in enumerate(val_loader):
        # input = torch.Size([1, 3, 713, 713])
        # gt_image = torch.Size([1, 713, 713])
        inputs, gt_image, img_names = data
        assert len(inputs.size()) == 4 and len(gt_image.size()) == 3
        assert inputs.size()[2:] == gt_image.size()[1:]
        batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)
        inputs, gt_cuda = inputs.cuda(), gt_image.cuda()

        with torch.no_grad():
            output = net(inputs)  # output = (1, 19, 713, 713)

        assert output.size()[2:] == gt_image.size()[1:]
        assert output.size()[1] == args.dataset_cls.num_classes

        # loss is weighted by pixel count so val_loss.avg is per-pixel
        val_loss.update(criterion(output, gt_cuda).item(), batch_pixel_size)

        # Collect data from different GPU to a single GPU since
        # encoding.parallel.criterionparallel function calculates distributed loss
        # functions
        predictions = output.data.max(1)[1].cpu()

        # Logging
        if val_idx % 20 == 0:
            if args.local_rank == 0:
                logging.info("validating: %d / %d", val_idx + 1,
                             len(val_loader))
        if val_idx > 10 and args.test_mode:
            break

        # Image Dumps: keep only the first 10 batches for visualization
        if val_idx < 10:
            dump_images.append([gt_image, predictions, img_names])

        # accumulate the confusion-matrix histogram for IoU computation
        iou_acc += fast_hist(predictions.numpy().flatten(),
                             gt_image.numpy().flatten(),
                             args.dataset_cls.num_classes)
        # free GPU output early; deleting the loop vars is redundant but harmless
        del output, val_idx, data

    if args.apex:
        # sum the IoU histograms across all distributed workers
        iou_acc_tensor = torch.cuda.FloatTensor(iou_acc)
        torch.distributed.all_reduce(iou_acc_tensor,
                                     op=torch.distributed.ReduceOp.SUM)
        iou_acc = iou_acc_tensor.cpu().numpy()

    # only rank 0 writes metrics / checkpoints / tensorboard entries
    if args.local_rank == 0:
        evaluate_eval(args, net, optim, val_loss, iou_acc, dump_images,
                      writer, curr_epoch, args.dataset_cls)

    return val_loss.avg
def validate_one_epoch(val_loader, model, epoch, configs, logger):
    """Evaluate the model over one pass of `val_loader` and return the average loss.

    :param val_loader: validation data loader
    :param model: multi-task network
    :param epoch: current epoch number (used only for the progress prefix)
    :param configs: config namespace (.device, .no_local, .distributed,
        .gpu_idx, .print_freq)
    :param logger: optional logger; progress is reported every print_freq batches
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(len(val_loader), [batch_time, data_time, losses],
                             prefix="Validation - Epoch: [{}]".format(epoch))

    # evaluation mode, no gradients needed
    model.eval()
    with torch.no_grad():
        tick = time.time()
        for step, batch in enumerate(tqdm(val_loader)):
            (origin_imgs, resized_imgs, org_ball_pos_xy, global_ball_pos_xy,
             event_class, target_seg) = batch
            data_time.update(time.time() - tick)
            num_samples = resized_imgs.size(0)
            target_seg = target_seg.to(configs.device, non_blocking=True)
            resized_imgs = resized_imgs.to(configs.device, non_blocking=True).float()

            # Only move origin_imgs to the device when the local ball-detection
            # stage is enabled; otherwise the model receives None in its place.
            if configs.no_local:
                local_input = None
            else:
                local_input = origin_imgs.to(configs.device, non_blocking=True).float()

            (pred_ball_global, pred_ball_local, pred_events, pred_seg,
             local_ball_pos_xy, total_loss, _) = model(
                local_input, resized_imgs, org_ball_pos_xy, global_ball_pos_xy,
                event_class, target_seg)

            # For torch.nn.DataParallel case: reduce the per-GPU losses to a scalar
            if (not configs.distributed) and (configs.gpu_idx is None):
                total_loss = torch.mean(total_loss)

            losses.update(total_loss.item(), num_samples)
            batch_time.update(time.time() - tick)

            # Log message
            if logger is not None and ((step + 1) % configs.print_freq) == 0:
                logger.info(progress.get_message(step))

            tick = time.time()

    return losses.avg
def validate_topn(val_loader, net, criterion, optim, epoch, args):
    """
    Find worst case failures ...

    Only single GPU for now.
    First pass = calculate TP, FP, FN pixels per image per class.
    Take these stats and determine the top20 images to dump per class.
    Second pass = dump all those selected images.

    return: average validation loss from the first pass
    """
    # Per-image metrics only make sense with batch size 1.
    assert args.bs_val == 1

    ######################################################################
    # First pass
    ######################################################################
    logx.msg('First pass')
    image_metrics = {}

    net.eval()
    val_loss = AverageMeter()
    iou_acc = 0

    for val_idx, data in enumerate(val_loader):
        # Run network
        assets, _iou_acc = \
            run_minibatch(data, net, criterion, val_loss, True, args, val_idx)

        # per-class metrics: false-positive / false-negative pixel counts
        input_images, labels, img_names, _ = data
        fp, fn = metrics_per_image(_iou_acc)
        img_name = img_names[0]
        image_metrics[img_name] = (fp, fn)

        iou_acc += _iou_acc

        if val_idx % 20 == 0:
            logx.msg(f'validating[Iter: {val_idx + 1} / {len(val_loader)}]')

        if val_idx > 5 and args.test_mode:
            break

    eval_metrics(iou_acc, args, net, optim, val_loss, epoch)

    ######################################################################
    # Find top 20 worst failures from a pixel count perspective
    ######################################################################
    from collections import defaultdict
    worst_images = defaultdict(dict)      # img_name -> {classid: fail_pixels}
    class_to_images = defaultdict(dict)   # classid -> {img_name: fail_pixels}
    for classid in range(cfg.DATASET.NUM_CLASSES):
        tbl = {}
        for img_name in image_metrics.keys():
            fp, fn = image_metrics[img_name]
            fp = fp[classid]
            fn = fn[classid]
            # Failure score for this class = FP + FN pixel count.
            tbl[img_name] = fp + fn
        worst = sorted(tbl, key=tbl.get, reverse=True)
        for img_name in worst[:args.dump_topn]:
            fail_pixels = tbl[img_name]
            worst_images[img_name][classid] = fail_pixels
            class_to_images[classid][img_name] = fail_pixels

    msg = str(worst_images)
    logx.msg(msg)

    # write out per-gpu jsons
    # barrier
    # make single table

    ######################################################################
    # 2nd pass
    ######################################################################
    logx.msg('Second pass')
    attn_map = None

    for val_idx, data in enumerate(val_loader):
        in_image, gt_image, img_names, _ = data

        # Only process images that were identified in first pass
        if not args.dump_topn_all and img_names[0] not in worst_images:
            continue

        with torch.no_grad():
            inputs = in_image.cuda()
            inputs = {'images': inputs, 'gts': gt_image}
            if cfg.MODEL.MSCALE:
                output, attn_map = net(inputs)
            else:
                output = net(inputs)

        output = torch.nn.functional.softmax(output, dim=1)
        prob_mask, predictions = output.data.max(1)
        predictions = predictions.cpu()  # this has shape [bs, h, w]

        img_name = img_names[0]
        for classid in worst_images[img_name].keys():
            err_mask = calc_err_mask(predictions.numpy(),
                                     gt_image.numpy(),
                                     cfg.DATASET.NUM_CLASSES,
                                     classid)

            class_name = cfg.DATASET_INST.trainid_to_name[classid]
            error_pixels = worst_images[img_name][classid]
            logx.msg(f'{img_name} {class_name}: {error_pixels}')

            # Suffix the class name so dumps for different classes don't collide.
            img_names = [img_name + f'_{class_name}']

            to_dump = {
                'gt_images': gt_image,
                'input_images': in_image,
                'predictions': predictions.numpy(),
                'err_mask': err_mask,
                'prob_mask': prob_mask,
                'img_names': img_names
            }
            if attn_map is not None:
                to_dump['attn_maps'] = attn_map

            # FIXME!
            # do_dump_images([to_dump])

    # Build an HTML index of the dumped failure images, worst-first per class.
    html_fn = os.path.join(args.result_dir, 'best_images', 'topn_failures.html')
    from utils.results_page import ResultsPage
    ip = ResultsPage('topn failures', html_fn)
    for classid in class_to_images:
        class_name = cfg.DATASET_INST.trainid_to_name[classid]
        img_dict = class_to_images[classid]
        for img_name in sorted(img_dict, key=img_dict.get, reverse=True):
            fail_pixels = class_to_images[classid][img_name]
            img_cls = f'{img_name}_{class_name}'
            pred_fn = f'{img_cls}_prediction.png'
            gt_fn = f'{img_cls}_gt.png'
            inp_fn = f'{img_cls}_input.png'
            err_fn = f'{img_cls}_err_mask.png'
            prob_fn = f'{img_cls}_prob_mask.png'
            img_label_pairs = [(pred_fn, 'pred'),
                               (gt_fn, 'gt'),
                               (inp_fn, 'input'),
                               (err_fn, 'errors'),
                               (prob_fn, 'prob')]
            ip.add_table(img_label_pairs,
                         table_heading=f'{class_name}-{fail_pixels}')
    ip.write_page()

    return val_loss.avg
def train_epoch(self, epoch_num):
    """
    Train the corner/edge heatmap model for one epoch.

    epoch_num: current epoch index (used for logging and viz filenames)

    Reads from self: model, train_loader, criterion, optimizer, configs, mode.
    Side effects: optimizer steps, console logging, and periodic image dumps
    to <exp_dir>/training_viz every configs.visualize_iter iterations.
    """
    batch_time = AverageMeter()
    losses_edge = AverageMeter()
    losses_corner = AverageMeter()
    self.model.train()
    end = time.time()
    for iter_i, batch_data in enumerate(self.train_loader):
        image_inputs = batch_data['image']
        if self.mode == 'corner':
            corner_target_maps = batch_data['corner_gt_map']
            edge_target_maps = batch_data['edge_gt_map']
            room_masks_map = batch_data['room_masks_map']
        else:
            raise ValueError('Invalid mode {}'.format(self.mode))
        mean_normal = batch_data['mean_normal']
        # contour_image = batch_data['contour_image']
        if self.configs.use_cuda:
            image_inputs = image_inputs.cuda()
            mean_normal = mean_normal.cuda()
            corner_target_maps = corner_target_maps.cuda()
            edge_target_maps = edge_target_maps.cuda()
            room_masks_map = room_masks_map.cuda()
        # Stack image, normals and room mask into one multi-channel input.
        inputs = torch.cat([
            image_inputs.unsqueeze(1), mean_normal,
            room_masks_map.unsqueeze(1)
        ], dim=1)
        corner_preds_logits, edge_preds_logits, edge_preds, corner_preds = self.model(
            inputs)
        # mask the binning part, only predicting directions for places with corners
        # (weight 5 where channel 0 of the target is on, 1 elsewhere)
        loss_mask_c = corner_target_maps[:, 0, :, :].clone().unsqueeze(
            1) * 4 + 1
        loss_c = self.criterion(corner_preds_logits, corner_target_maps)
        loss_c = loss_c * loss_mask_c
        loss_c = loss_c.mean(2).mean(2).mean(
            0).sum()  # take mean over batch, H, W, sum over C
        loss_mask_e = edge_target_maps[:, 0, :, :].clone().unsqueeze(
            1) * 4 + 1
        loss_e = self.criterion(edge_preds_logits, edge_target_maps)
        loss_e = loss_e * loss_mask_e
        loss_e = loss_e.mean(2).mean(2).mean(
            0).sum()  # take mean over batch, H, W, sum over C
        loss = loss_e + loss_c
        losses_edge.update(loss_e.data, image_inputs.size(0))
        losses_corner.update(loss_c.data, image_inputs.size(0))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Edge pred Loss {loss1.val:.4f} ({loss1.avg:.4f})\t'
              'Corner pred Loss {loss2.val:.4f} ({loss2.avg:.4f})'.format(
                  epoch_num, iter_i, len(self.train_loader),
                  batch_time=batch_time, loss1=losses_edge,
                  loss2=losses_corner))
        if iter_i % self.configs.visualize_iter == 0:
            viz_dir = os.path.join(self.configs.exp_dir, 'training_viz')
            gt_file_path = os.path.join(
                viz_dir, 'epoch_{}_iter_{}_gt.png'.format(epoch_num, iter_i))
            gt_edge_file_path = os.path.join(
                viz_dir, 'epoch_{}_iter_{}_gt_edge.png'.format(epoch_num, iter_i))
            heatmap_path = os.path.join(
                viz_dir, 'epoch_{}_iter_{}_preds.png'.format(epoch_num, iter_i))
            heatmap_edge_path = os.path.join(
                viz_dir, 'epoch_{}_iter_{}_preds_edge.png'.format(
                    epoch_num, iter_i))
            # corner_edge_path = os.path.join(viz_dir, 'epoch_{}_iter_{}_corner_edge.png'.format(epoch_num, iter_i))
            gt_map_path = os.path.join(
                viz_dir, 'epoch_{}_iter_{}_gt_corner_edge.png'.format(
                    epoch_num, iter_i))
            # simply use the first element in the batch
            corner_preds_np = corner_preds[0].detach().cpu().numpy()
            edge_preds_np = edge_preds[0].detach().cpu().numpy()
            edge_gt_np = edge_target_maps[0].cpu().numpy()
            corner_gt_np = corner_target_maps[0].cpu().numpy()
            if self.mode == 'corner':
                _, gt_corner_edge_map = get_corner_dir_map(
                    corner_gt_np, 256)
                imsave(gt_map_path, gt_corner_edge_map)
                imsave(gt_file_path, corner_gt_np[0])
                imsave(heatmap_path, corner_preds_np[0])
                imsave(gt_edge_file_path, edge_gt_np[0])
                imsave(heatmap_edge_path, edge_preds_np[0])
def validate(val_loader, dataset, net, criterion, optim, scheduler, curr_epoch, writer, curr_iter, save_pth=True):
    """
    Runs the validation loop after each training epoch.

    val_loader: Data loader for validation
    dataset: dataset name (str)
    net: the network
    criterion: loss fn
    optim: optimizer (forwarded to evaluate_eval)
    scheduler: LR scheduler (forwarded to evaluate_eval)
    curr_epoch: current epoch
    writer: tensorboard writer
    curr_iter: current global iteration (for logging/visualization)
    save_pth: whether evaluate_eval should save a checkpoint
    return: average validation loss (for a step function if required)
    """
    net.eval()
    val_loss = AverageMeter()
    iou_acc = 0
    error_acc = 0
    dump_images = []

    for val_idx, data in enumerate(val_loader):
        # input = torch.Size([1, 3, 713, 713])
        # gt_image = torch.Size([1, 713, 713])
        inputs, gt_image, img_names, _ = data

        # Multi-source case: flatten the (batch, domain) dims into one batch dim.
        if len(inputs.shape) == 5:
            B, D, C, H, W = inputs.shape
            inputs = inputs.view(-1, C, H, W)
            gt_image = gt_image.view(-1, 1, H, W)
            # NOTE(review): this reshape makes gt_image 4-D, which would trip
            # the 3-D assert below — presumably the 5-D path is unused here;
            # confirm against the callers.

        assert len(inputs.size()) == 4 and len(gt_image.size()) == 3
        assert inputs.size()[2:] == gt_image.size()[1:]
        # Loss is per-pixel, so weight the meter by pixel count.
        batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)
        inputs, gt_cuda = inputs.cuda(), gt_image.cuda()

        with torch.no_grad():
            if args.use_wtloss:
                output, f_cor_arr = net(inputs, visualize=True)
            else:
                output = net(inputs)

        del inputs

        assert output.size()[2:] == gt_image.size()[1:]
        assert output.size()[1] == datasets.num_classes

        val_loss.update(criterion(output, gt_cuda).item(), batch_pixel_size)

        del gt_cuda

        # Collect data from different GPU to a single GPU since
        # encoding.parallel.criterionparallel function calculates distributed loss
        # functions
        predictions = output.data.max(1)[1].cpu()

        # Logging
        if val_idx % 20 == 0:
            if args.local_rank == 0:
                logging.info("validating: %d / %d", val_idx + 1, len(val_loader))
        if val_idx > 10 and args.test_mode:
            break

        # Image Dumps: keep the first 10 batches for visualization.
        if val_idx < 10:
            dump_images.append([gt_image, predictions, img_names])

        iou_acc += fast_hist(predictions.numpy().flatten(),
                             gt_image.numpy().flatten(),
                             datasets.num_classes)
        del output, val_idx, data

    # Sum the confusion matrix across distributed ranks.
    iou_acc_tensor = torch.cuda.FloatTensor(iou_acc)
    torch.distributed.all_reduce(iou_acc_tensor, op=torch.distributed.ReduceOp.SUM)
    iou_acc = iou_acc_tensor.cpu().numpy()

    if args.local_rank == 0:
        evaluate_eval(args, net, optim, scheduler, val_loss, iou_acc, dump_images,
                      writer, curr_epoch, dataset, None, curr_iter,
                      save_pth=save_pth)

        if args.use_wtloss:
            visualize_matrix(writer, f_cor_arr, curr_iter, '/Covariance/Feature-')

    return val_loss.avg
def train():
    """Train the global CNN classifier on MNIST-sized (1x28x28) inputs.

    Uses the module-level CNN, opt, device, train_dataloader and MODEL_PATH.
    Side effects: creates the checkpoint directory, prints per-batch progress
    every 15 batches, and saves the final state dict to MODEL_PATH.
    """
    # Make sure the checkpoint directory exists (ignore "already exists").
    try:
        os.makedirs(opt.checkpoints_dir)
    except OSError:
        pass

    CNN.to(device)
    CNN.train()
    torchsummary.summary(CNN, (1, 28, 28))

    ################################################
    # Set loss function and Adam optimier
    ################################################
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(CNN.parameters(), lr=opt.lr)

    for epoch in range(opt.epochs):
        # train for one epoch
        print(f"\nBegin Training Epoch {epoch + 1}")

        # Meters tracking timing and top-k accuracy so the learning
        # process can be followed during the epoch.
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        tick = time.time()
        for i, (images, labels) in enumerate(train_dataloader):
            images, labels = images.to(device), labels.to(device)

            # Forward pass and loss.
            output = CNN(images)
            loss = criterion(output, labels)

            # Track accuracy and loss for this batch.
            prec1, prec5 = accuracy(output, labels, topk=(1, 5))
            batch_count = images.size(0)
            losses.update(loss.item(), batch_count)
            top1.update(prec1, batch_count)
            top5.update(prec5, batch_count)

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Timing bookkeeping.
            batch_time.update(time.time() - tick)
            tick = time.time()

            if i % 15 == 0:
                print(
                    f"Epoch [{epoch + 1}] [{i}/{len(train_dataloader)}]\t"
                    f"Loss {loss.item():.4f}\t"
                    f"Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
                    f"Prec@5 {top5.val:.3f} ({top5.avg:.3f})", end="\r")

    # save model file
    torch.save(CNN.state_dict(), MODEL_PATH)
def train(train_loader, net, optim, curr_epoch, writer, scheduler, max_iter):
    """
    Runs the training loop per epoch.

    train_loader: Data loader for train
    net: the network
    optim: optimizer
    curr_epoch: current epoch
    writer: tensorboard writer
    scheduler: LR scheduler, stepped once per batch
    max_iter: hard cap on the global iteration counter
    return: updated global iteration counter (curr_iter)
    """
    net.train()

    train_total_loss = AverageMeter()
    time_meter = AverageMeter()

    curr_iter = curr_epoch * len(train_loader)

    for i, data in enumerate(train_loader):
        if curr_iter >= max_iter:
            break

        inputs, gts, _, aux_gts = data

        # Multi source and AGG case: a 5-D batch is (B, D, C, H, W) with one
        # slice per source domain; split it into a list of per-domain batches.
        if len(inputs.shape) == 5:
            B, D, C, H, W = inputs.shape
            num_domains = D
            inputs = inputs.transpose(0, 1)
            gts = gts.transpose(0, 1).squeeze(2)
            aux_gts = aux_gts.transpose(0, 1).squeeze(2)

            inputs = [input.squeeze(0) for input in torch.chunk(inputs, num_domains, 0)]
            gts = [gt.squeeze(0) for gt in torch.chunk(gts, num_domains, 0)]
            aux_gts = [aux_gt.squeeze(0) for aux_gt in torch.chunk(aux_gts, num_domains, 0)]
        else:
            B, C, H, W = inputs.shape
            num_domains = 1
            inputs = [inputs]
            gts = [gts]
            aux_gts = [aux_gts]

        batch_pixel_size = C * H * W

        # One optimization step per source domain.
        for di, ingredients in enumerate(zip(inputs, gts, aux_gts)):
            input, gt, aux_gt = ingredients

            start_ts = time.time()

            img_gt = None
            input, gt = input.cuda(), gt.cuda()

            optim.zero_grad()
            if args.use_isw:
                # Whitening loss is only applied after the covariance
                # statistics warm-up epochs.
                outputs = net(input, gts=gt, aux_gts=aux_gt, img_gt=img_gt,
                              visualize=args.visualize_feature,
                              apply_wtloss=False if curr_epoch <= args.cov_stat_epoch else True)
            else:
                outputs = net(input, gts=gt, aux_gts=aux_gt, img_gt=img_gt,
                              visualize=args.visualize_feature)

            # The network returns a variable-length tuple; consume it in order.
            outputs_index = 0
            main_loss = outputs[outputs_index]
            outputs_index += 1
            aux_loss = outputs[outputs_index]
            outputs_index += 1
            total_loss = main_loss + (0.4 * aux_loss)

            if args.use_wtloss and (not args.use_isw or (args.use_isw and curr_epoch > args.cov_stat_epoch)):
                wt_loss = outputs[outputs_index]
                outputs_index += 1
                total_loss = total_loss + (args.wt_reg_weight * wt_loss)
            else:
                wt_loss = 0

            if args.visualize_feature:
                f_cor_arr = outputs[outputs_index]
                outputs_index += 1

            # Average the (detached) loss across ranks for logging only.
            log_total_loss = total_loss.clone().detach_()
            torch.distributed.all_reduce(log_total_loss, torch.distributed.ReduceOp.SUM)
            log_total_loss = log_total_loss / args.world_size
            train_total_loss.update(log_total_loss.item(), batch_pixel_size)

            total_loss.backward()
            optim.step()

            time_meter.update(time.time() - start_ts)

            del total_loss, log_total_loss

            if args.local_rank == 0:
                if i % 50 == 49:
                    if args.visualize_feature:
                        visualize_matrix(writer, f_cor_arr, curr_iter, '/Covariance/Feature-')

                    msg = '[epoch {}], [iter {} / {} : {}], [loss {:0.6f}], [lr {:0.6f}], [time {:0.4f}]'.format(
                        curr_epoch, i + 1, len(train_loader), curr_iter, train_total_loss.avg,
                        optim.param_groups[-1]['lr'], time_meter.avg / args.train_batch_size)

                    logging.info(msg)
                    if args.use_wtloss:
                        print("Whitening Loss", wt_loss)

                    # Log tensorboard metrics for each iteration of the training phase
                    writer.add_scalar('loss/train_loss', (train_total_loss.avg), curr_iter)
                    train_total_loss.reset()
                    time_meter.reset()

        curr_iter += 1
        scheduler.step()

        if i > 5 and args.test_mode:
            return curr_iter

    return curr_iter
def test(args, model, video_val=None):
    """
    Evaluate the RL tracking policy on a list of VOT videos.

    args: namespace with gamma, tau, entropy_coef, value_loss_coef
    model: recurrent actor-critic; forward returns
           (action_prob, action_logprob, action_sample, value, hidden)
    video_val: iterable of video names to evaluate
    return: (avg reward, avg loss, avg value loss, avg policy loss)

    NOTE(review): this uses the legacy pre-0.4 PyTorch API (Variable,
    loss.data[0]); on PyTorch >= 0.5 `loss.data[0]` raises and should be
    `loss.item()` — confirm the target PyTorch version before changing.
    """
    reward_avg = AverageMeter()
    loss_avg = AverageMeter()
    value_loss_avg = AverageMeter()
    policy_loss_avg = AverageMeter()
    root_dir = '/home/youngfly/DL_project/RL_Tracking/dataset/VOT'
    data_type = 'VOT'
    model.eval()
    env = Env(seqs_path=root_dir, data_set_type=data_type,
              save_path='/dataset/Result/VOT')
    for video_name in video_val:
        actions = []
        rewards = []
        values = []
        entropies = []
        logprobs = []
        # reset for new video
        observation1, observation2 = env.reset(video_name)
        img1 = ReadSingleImage(observation2)
        img1 = Variable(img1).cuda()
        hidden_prev = model.init_hidden_state(
            batch_size=1)  # variable cuda tensor
        # Warm up the recurrent state on the first frame; the outputs
        # themselves are discarded.
        _, _, _, _, hidden_pres = model(imgs=img1, hidden_prev=hidden_prev)
        # for loop init parameter
        hidden_prev = hidden_pres
        observation = observation2
        FLAG = 1
        i = 2  # frame counter starts at the second frame
        while FLAG:
            img = ReadSingleImage(observation)
            img = Variable(img).cuda()
            action_prob, action_logprob, action_sample, value, hidden_pres = model(
                imgs=img, hidden_prev=hidden_prev)
            # Policy entropy, used as an exploration bonus in the loss.
            entropy = -(action_logprob * action_prob).sum(1, keepdim=True)
            entropies.append(entropy)
            actions.append(action_sample.long())  # list, Variable cuda inner
            action_np = action_sample.data.cpu().numpy()
            sample = Variable(torch.LongTensor(action_np).cuda()).unsqueeze(0)
            hidden_prev = hidden_pres
            # Log-probability of the action actually taken.
            logprob = action_logprob.gather(1, sample)
            logprobs.append(logprob)
            reward, new_observation, done = env.step(action=action_np)
            env.show_all()
            # env.show_tracking_result()
            print(
                'test:', 'frame:%d' % (i), 'Action:%d' % action_np[0],
                'rewards:%.6f' % reward,
                'probability:%.6f, %.6f' %
                (action_prob.data.cpu().numpy()[0, 0],
                 action_prob.data.cpu().numpy()[0, 1]))
            i = i + 1
            rewards.append(reward)  # just list
            values.append(value)  # list, Variable cuda inner
            observation = new_observation
            if done:
                FLAG = 0
        num_seqs = len(rewards)
        running_add = Variable(torch.FloatTensor([0])).cuda()
        value_loss = 0
        policy_loss = 0
        gae = torch.FloatTensor([0]).cuda()
        # Bootstrap value of 0 for the terminal state (values[i + 1] below).
        values.append(running_add)
        # Walk the episode backwards computing discounted returns and GAE.
        for i in reversed(range(len(rewards))):
            # if rewards[i] < 0.2:
            #     rewards[i] = rewards[i] ** 2
            running_add = args.gamma * running_add + rewards[i]
            advantage = running_add - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)
            delta_t = rewards[i] + args.gamma * values[i + 1].data - values[i].data
            gae = gae * args.gamma * args.tau + delta_t
            policy_loss = policy_loss - logprobs[i] * Variable(
                gae) - args.entropy_coef * entropies[i]
        # value_loss = value_loss / num_seqs
        # policy_loss = policy_loss / num_seqs
        #
        # values.append(running_add)
        # for i in reversed(range(len(rewards))):
        #     running_add = args.gamma * running_add + rewards[i]
        #     advantage = running_add - values[i]
        #     value_loss = value_loss + 0.5 * advantage.pow(2)
        #     policy_loss = policy_loss - logprobs[i] * advantage - args.entropy_coef * entropies[i]
        # value_loss = value_loss / num_seqs
        # NOTE(review): only policy_loss is normalized by episode length here;
        # the value_loss normalization is commented out — presumably deliberate,
        # but worth confirming.
        policy_loss = policy_loss / num_seqs
        loss = args.value_loss_coef * value_loss + policy_loss
        print(video_name, 'rewards:%.6f' % np.mean(rewards),
              'loss:%.6f' % loss.data[0],
              'value_loss:%6f' % value_loss.data[0],
              'policy_loss:%.6f' % policy_loss.data[0])
        # update the loss
        loss_avg.update(loss.data.cpu().numpy())
        value_loss_avg.update(value_loss.data.cpu().numpy())
        policy_loss_avg.update(policy_loss.data.cpu().numpy())
        reward_avg.update(np.mean(rewards))
    return reward_avg.avg, loss_avg.avg, value_loss_avg.avg, policy_loss_avg.avg
def test(test_loader, model, configs):
    """
    Run the full test loop for the multi-task model and print per-stage metrics.

    test_loader: test data loader yielding
        (resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events, target_seg)
    model: network; forward returns global/local ball heatmaps, event logits,
        segmentation, local GT positions, total loss and an extra value
    configs: namespace with input_size, device, thresh_ball_pos_mask,
        event_thresh, seg_thresh, save_test_output, saved_dir, print_freq

    Fixes vs. the previous version:
      - `axes = axes.ravel()` (the return value was discarded, so indexing a
        2-D axes grid with a flat index broke for batch_size > 1)
      - the local-stage GT tensor is converted to numpy once per batch (the
        old per-sample `.cpu().numpy()` crashed on the second sample)
      - `astype(int)` replaces `astype(np.int)` (alias removed in NumPy 1.24)
      - figures are closed after saving to avoid unbounded memory growth
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    iou_seg = AverageMeter('IoU_Seg', ':6.4f')
    mse_global = AverageMeter('MSE_Global', ':6.4f')
    mse_local = AverageMeter('MSE_Local', ':6.4f')
    mse_overall = AverageMeter('MSE_Overall', ':6.4f')
    pce = AverageMeter('PCE', ':6.4f')
    spce = AverageMeter('Smooth_PCE', ':6.4f')

    # Original video resolution; predictions are made on (w, h)-resized frames
    # and scaled back up for the overall metric.
    w_original = 1920.
    h_original = 1080.
    w, h = configs.input_size

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        for batch_idx, (resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events,
                        target_seg) in enumerate(tqdm(test_loader)):
            print('\n===================== batch_idx: {} ================================'
                  .format(batch_idx))
            data_time.update(time.time() - start_time)
            batch_size = resized_imgs.size(0)
            target_seg = target_seg.to(configs.device, non_blocking=True)
            resized_imgs = resized_imgs.to(configs.device, non_blocking=True).float()

            # compute output
            pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
                resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events, target_seg)

            org_ball_pos_xy = org_ball_pos_xy.numpy()
            global_ball_pos_xy = global_ball_pos_xy.numpy()
            # Transfer output to cpu
            target_seg = target_seg.cpu().numpy()
            if pred_ball_local is not None:
                # Ground truth of the local stage; convert ONCE per batch
                # (a per-sample conversion would fail on the second sample).
                local_ball_pos_xy = local_ball_pos_xy.cpu().numpy()

            for sample_idx in range(batch_size):
                # Get target
                sample_org_ball_pos_xy = org_ball_pos_xy[sample_idx]
                sample_global_ball_pos_xy = global_ball_pos_xy[sample_idx]  # Target

                # Process the global stage
                sample_pred_ball_global = pred_ball_global[sample_idx]
                sample_prediction_ball_global_xy = get_prediction_ball_pos(
                    sample_pred_ball_global, w, configs.thresh_ball_pos_mask)

                # Calculate the MSE only when both GT and prediction are valid.
                if (sample_global_ball_pos_xy[0] > 0) and (sample_global_ball_pos_xy[1] > 0) and (
                        sample_prediction_ball_global_xy[0] > 0) and (sample_prediction_ball_global_xy[1] > 0):
                    mse = (sample_prediction_ball_global_xy[0] - sample_global_ball_pos_xy[0]) ** 2 + \
                          (sample_prediction_ball_global_xy[1] - sample_global_ball_pos_xy[1]) ** 2
                    mse_global.update(mse)

                print('\nBall Detection - \t Global stage: \t (x, y) - gt = ({}, {}), prediction = ({}, {})'
                      .format(sample_global_ball_pos_xy[0], sample_global_ball_pos_xy[1],
                              sample_prediction_ball_global_xy[0], sample_prediction_ball_global_xy[1]))

                # Scale the global prediction back to the original resolution.
                sample_pred_org_x = sample_prediction_ball_global_xy[0] * (w_original / w)
                sample_pred_org_y = sample_prediction_ball_global_xy[1] * (h_original / h)

                # Process local ball stage
                if pred_ball_local is not None:
                    sample_local_ball_pos_xy = local_ball_pos_xy[sample_idx]  # Target
                    # Process the local stage
                    sample_pred_ball_local = pred_ball_local[sample_idx]
                    sample_prediction_ball_local_xy = get_prediction_ball_pos(
                        sample_pred_ball_local, w, configs.thresh_ball_pos_mask)
                    # Calculate the MSE
                    if (sample_local_ball_pos_xy[0] > 0) and (sample_local_ball_pos_xy[1] > 0):
                        mse = (sample_prediction_ball_local_xy[0] - sample_local_ball_pos_xy[0]) ** 2 + (
                            sample_prediction_ball_local_xy[1] - sample_local_ball_pos_xy[1]) ** 2
                        mse_local.update(mse)
                    # Refine the upscaled position by the local-stage offset
                    # (local prediction is relative to the crop center).
                    sample_pred_org_x += sample_prediction_ball_local_xy[0] - w / 2
                    sample_pred_org_y += sample_prediction_ball_local_xy[1] - h / 2
                    print('Ball Detection - \t Local stage: \t (x, y) - gt = ({}, {}), prediction = ({}, {})'
                          .format(sample_local_ball_pos_xy[0], sample_local_ball_pos_xy[1],
                                  sample_prediction_ball_local_xy[0], sample_prediction_ball_local_xy[1]))

                print('Ball Detection - \t Overall: \t (x, y) - org: ({}, {}), prediction = ({}, {})'
                      .format(sample_org_ball_pos_xy[0], sample_org_ball_pos_xy[1],
                              int(sample_pred_org_x), int(sample_pred_org_y)))
                mse = (sample_org_ball_pos_xy[0] - sample_pred_org_x) ** 2 + (
                    sample_org_ball_pos_xy[1] - sample_pred_org_y) ** 2
                mse_overall.update(mse)

                # Process event stage
                if pred_events is not None:
                    sample_target_events = target_events[sample_idx].numpy()
                    sample_prediction_events = prediction_get_events(
                        pred_events[sample_idx], configs.event_thresh)
                    print('Event Spotting - \t gt = (is bounce: {}, is net: {}), prediction: (is bounce: {:.4f}, is net: {:.4f})'
                          .format(sample_target_events[0], sample_target_events[1],
                                  pred_events[sample_idx][0], pred_events[sample_idx][1]))
                    # Compute metrics
                    spce.update(SPCE(sample_prediction_events, sample_target_events, thresh=0.5))
                    pce.update(PCE(sample_prediction_events, sample_target_events))

                # Process segmentation stage
                if pred_seg is not None:
                    # np.int was removed in NumPy 1.24; plain int is equivalent.
                    sample_target_seg = target_seg[sample_idx].transpose(1, 2, 0).astype(int)
                    sample_prediction_seg = get_prediction_seg(pred_seg[sample_idx], configs.seg_thresh)
                    # Calculate the IoU (NOTE: 2*|A∩B|/(|A|+|B|) is the Dice
                    # coefficient; kept as-is to preserve reported numbers).
                    iou = 2 * np.sum(sample_target_seg * sample_prediction_seg) / (
                        np.sum(sample_target_seg) + np.sum(sample_prediction_seg) + 1e-9)
                    iou_seg.update(iou)
                    print('Segmentation - \t \t IoU = {:.4f}'.format(iou))

                    if configs.save_test_output:
                        fig, axes = plt.subplots(nrows=batch_size, ncols=2, figsize=(10, 5))
                        plt.tight_layout()
                        # ravel() returns a flattened VIEW - it must be
                        # assigned, otherwise 2-D indexing breaks below.
                        axes = axes.ravel()
                        axes[2 * sample_idx].imshow(sample_target_seg * 255)
                        axes[2 * sample_idx + 1].imshow(sample_prediction_seg * 255)
                        # title
                        target_title = 'target seg'
                        pred_title = 'pred seg'
                        if pred_events is not None:
                            target_title += ', is bounce: {}, is net: {}'.format(
                                sample_target_events[0], sample_target_events[1])
                            pred_title += ', is bounce: {}, is net: {}'.format(
                                sample_prediction_events[0], sample_prediction_events[1])
                        axes[2 * sample_idx].set_title(target_title)
                        axes[2 * sample_idx + 1].set_title(pred_title)
                        plt.savefig(os.path.join(
                            configs.saved_dir,
                            'batch_idx_{}_sample_idx_{}.jpg'.format(batch_idx, sample_idx)))
                        # Close the figure to stop matplotlib accumulating
                        # one open figure per saved sample.
                        plt.close(fig)

            if ((batch_idx + 1) % configs.print_freq) == 0:
                print('batch_idx: {} - Average iou_seg: {:.4f}, mse_global: {:.1f}, mse_local: {:.1f}, mse_overall: {:.1f}, pce: {:.4f} spce: {:.4f}'
                      .format(batch_idx, iou_seg.avg, mse_global.avg, mse_local.avg,
                              mse_overall.avg, pce.avg, spce.avg))

            batch_time.update(time.time() - start_time)
            start_time = time.time()

    print('Average iou_seg: {:.4f}, mse_global: {:.1f}, mse_local: {:.1f}, mse_overall: {:.1f}, pce: {:.4f} spce: {:.4f}'
          .format(iou_seg.avg, mse_global.avg, mse_local.avg, mse_overall.avg, pce.avg, spce.avg))
    print('Done testing')
def evaluate_mAP(val_loader, model, configs, logger):
    """
    Evaluate the 3D detector and compute per-class AP over the loader.

    val_loader: validation loader yielding (metadatas, targets); metadatas
        holds voxelized point-cloud inputs, targets holds rotated-box labels
    model: center-based detection network
    configs: namespace with device, K, num_classes, down_ratio, peak_thresh,
        iou_thresh, print_freq
    logger: optional logger for periodic progress messages
    return: (precision, recall, AP, f1, ap_class) from ap_per_class
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    progress = ProgressMeter(len(val_loader), [batch_time, data_time],
                             prefix="Evaluation phase...")
    labels = []
    sample_metrics = []  # List of tuples (TP, confs, pred)
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        for batch_idx, batch_data in enumerate(tqdm(val_loader)):
            metadatas, targets = batch_data
            batch_size = len(metadatas['img_path'])
            voxelinput = metadatas['voxels']
            coorinput = metadatas['coors']
            numinput = metadatas['num_points']
            dtype = torch.float32
            # NOTE(review): the trailing .to(dtype) is redundant — the tensor
            # is already created as float32.
            voxelinputr = torch.tensor(
                voxelinput, dtype=torch.float32, device=configs.device).to(dtype)
            coorinputr = torch.tensor(
                coorinput, dtype=torch.int32, device=configs.device)
            numinputr = torch.tensor(
                numinput, dtype=torch.int32, device=configs.device)
            t1 = time_synchronized()
            outputs = model(voxelinputr, coorinputr, numinputr)
            outputs = outputs._asdict()
            # Heatmap and offset heads are trained with logits; squash here.
            outputs['hm_cen'] = _sigmoid(outputs['hm_cen'])
            outputs['cen_offset'] = _sigmoid(outputs['cen_offset'])
            # detections size (batch_size, K, 10)
            detections = decode(outputs['hm_cen'], outputs['cen_offset'],
                                outputs['direction'], outputs['z_coor'],
                                outputs['dim'], K=configs.K)
            detections = detections.cpu().numpy().astype(np.float32)
            detections = post_processingv2(detections, configs.num_classes,
                                           configs.down_ratio, configs.peak_thresh)

            # Collect ground-truth class ids for the AP computation.
            for sample_i in range(len(detections)):
                # print(output.shape)
                num = targets['count'][sample_i]
                # print(targets['batch'][sample_i][:num].shape)
                target = targets['batch'][sample_i][:num]
                # print(target[:, 8].tolist())
                labels += target[:, 8].tolist()

            sample_metrics += get_batch_statistics_rotated_bbox(
                detections, targets, iou_threshold=configs.iou_thresh)

            t2 = time_synchronized()
            # measure elapsed time
            # torch.cuda.synchronize()
            batch_time.update(time.time() - start_time)

            # Log message
            if logger is not None:
                if ((batch_idx + 1) % configs.print_freq) == 0:
                    logger.info(progress.get_message(batch_idx))

            start_time = time.time()

    # Concatenate sample statistics
    true_positives, pred_scores, pred_labels = [
        np.concatenate(x, 0) for x in list(zip(*sample_metrics))
    ]
    precision, recall, AP, f1, ap_class = ap_per_class(
        true_positives, pred_scores, pred_labels, labels)

    return precision, recall, AP, f1, ap_class
def evaluate(val_loader, net):
    '''
    Runs the evaluation loop and prints F score.

    val_loader: Data loader for validation
    net: the network (returns segmentation and edge outputs)
    return: dict with mean_iu, acc, acc_cls, fwavacc
    '''
    net.eval()
    # Boundary-threshold timing measurements (kept for reference):
    # 0.0005 13.0 it/sec
    # 0.001875 4.80 it/sec
    # 0.00375 1.70 it/sec
    # 0.005 1.03 it/sec
    # NOTE(review): thresh, the mf/ap meters and Fpc/Fc feed only the
    # commented-out boundary F-score code below; they are dead until that
    # code is re-enabled.
    thresh = 0.0001
    mf_score1 = AverageMeter()
    mf_pc_score1 = AverageMeter()
    ap_score1 = AverageMeter()
    ap_pc_score1 = AverageMeter()
    IOU_acc = 0
    Fpc = np.zeros((args.dataset_cls.num_classes))
    Fc = np.zeros((args.dataset_cls.num_classes))
    for vi, data in enumerate(val_loader):
        input, mask, edge, img_names = data
        assert len(input.size()) == 4 and len(mask.size()) == 3
        assert input.size()[2:] == mask.size()[1:]
        h, w = mask.size()[1:]
        batch_pixel_size = input.size(0) * input.size(2) * input.size(3)
        input, mask_cuda, edge_cuda = input.cuda(), mask.cuda(), edge.cuda()

        with torch.no_grad():
            seg_out, edge_out = net(input)

        seg_predictions = seg_out.data.max(1)[1].cpu()
        edge_predictions = edge_out.max(1)[0].cpu()

        logging.info('evaluating: %d / %d' % (vi + 1, len(val_loader)))

        '''
        _Fpc, _Fc = eval_mask_boundary(seg_predictions.numpy(), mask.numpy(),
                                       args.dataset_cls.num_classes,
                                       bound_th=float(thresh))
        Fc += _Fc
        Fpc += _Fpc
        logging.info('F_Score: ' + str(np.sum(Fpc/Fc)/args.dataset_cls.num_classes))
        '''

        # Accumulate the confusion matrix for IoU-style metrics.
        IOU_acc += fast_hist(seg_predictions.numpy().flatten(),
                             mask.numpy().flatten(),
                             args.dataset_cls.num_classes)

        del seg_out, edge_out, vi, data

    # Standard semantic-segmentation summary statistics from the
    # accumulated confusion matrix.
    acc = np.diag(IOU_acc).sum() / IOU_acc.sum()
    acc_cls = np.diag(IOU_acc) / IOU_acc.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    iu = np.diag(IOU_acc) / (IOU_acc.sum(axis=1) + IOU_acc.sum(axis=0) - np.diag(IOU_acc))
    freq = IOU_acc.sum(axis=1) / IOU_acc.sum()
    mean_iu = np.nanmean(iu)
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()

    # logging.info('F_Score: ' + str(np.sum(Fpc/Fc)/args.dataset_cls.num_classes))
    # logging.info('F_Score (Classwise): ' + str(Fpc/Fc))

    results = {
        "mean_iu": mean_iu,
        "acc": acc,
        "acc_cls": acc_cls,
        "fwavacc": fwavacc
    }

    return results
def train(train_loader, net, optim, curr_epoch, writer):
    """
    Runs the training loop per epoch.

    train_loader: Data loader for train
    net: the network (its forward computes the loss when given gts=...)
    optim: optimizer
    curr_epoch: current epoch
    writer: tensorboard writer
    return: None
    """
    net.train()

    train_main_loss = AverageMeter()
    curr_iter = curr_epoch * len(train_loader)

    for i, data in enumerate(train_loader):
        # inputs = (2,3,713,713)
        # gts = (2,713,713)
        inputs, gts, _img_name = data

        # Loss is per-pixel, so weight the meter by pixel count.
        batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)

        inputs, gts = inputs.cuda(), gts.cuda()

        optim.zero_grad()
        main_loss = net(inputs, gts=gts)

        if args.apex and not args.local_computer:
            # Distributed case: average the detached loss across ranks
            # for logging only; the backward uses the local loss.
            log_main_loss = main_loss.clone().detach_()
            torch.distributed.all_reduce(log_main_loss,
                                         torch.distributed.ReduceOp.SUM)
            log_main_loss = log_main_loss / args.world_size
        else:
            # DataParallel returns a per-GPU loss vector; reduce to a scalar.
            main_loss = main_loss.mean()
            log_main_loss = main_loss.clone().detach_()

        train_main_loss.update(log_main_loss.item(), batch_pixel_size)
        if args.fp16:  # and 0:
            with amp.scale_loss(main_loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            main_loss.backward()

        optim.step()

        curr_iter += 1

        if args.local_rank == 0:
            msg = '[epoch {}], [iter {} / {}], [train main loss {:0.6f}], [lr {:0.6f}]'.format(
                curr_epoch, i + 1, len(train_loader), train_main_loss.avg,
                optim.param_groups[-1]['lr'])
            logging.info(msg)

            # Log tensorboard metrics for each iteration of the training phase
            writer.add_scalar('training/loss', (train_main_loss.val), curr_iter)
            writer.add_scalar('training/lr', optim.param_groups[-1]['lr'], curr_iter)

        if i > 5 and args.test_mode:
            return
def main():
    """s4GAN-style semi-supervised segmentation training entry point.

    Builds the segmentation network and discriminator, constructs labeled /
    unlabeled / ground-truth data loaders (optionally from a labeled-ratio
    split), then alternates generator and discriminator updates for
    ``args.num_steps`` iterations, logging to tensorboard and periodically
    validating and checkpointing.

    Fixes vs. the original:
      * ``images_gt`` was normalized by ``torch.max(images) - torch.min(images)``
        (the labeled batch) instead of its own min/max.
      * ``loss_fm_value`` accumulated a live tensor; it now uses ``.item()``
        like the other running counters.
      * bare ``except:`` around loader restarts narrowed to ``StopIteration``.
    """
    print(args)
    os.makedirs(args.out, exist_ok=True)
    args.writer = SummaryWriter(args.out)
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    cudnn.enabled = True
    gpu = args.gpu

    # create network
    #model = Res_Deeplab(num_classes=args.num_classes)
    model = DeepV3PlusW38(num_classes=args.num_classes)

    # Load pretrained parameters, copying only tensors whose shapes match
    # (the classifier head may differ in number of classes).
    saved_state_dict = torch.load(args.restore_from)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)
    model.train()
    model.cuda(args.gpu)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    # init D
    model_D = s4GAN_discriminator(num_classes=args.num_classes,
                                  dataset=args.dataset)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D = torch.nn.DataParallel(model_D).cuda()
    cudnn.benchmark = True
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    # ---- training dataset ----
    if args.dataset == 'pascal_voc':
        train_dataset = VOCDataSet(args.data_dir, args.data_list,
                                   crop_size=input_size,
                                   scale=args.random_scale,
                                   mirror=args.random_mirror, mean=IMG_MEAN)
    elif args.dataset == 'pascal_context':
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.406, .456, .485], [.229, .224, .225])
        ])
        data_kwargs = {
            'transform': input_transform,
            'base_size': 505,
            'crop_size': 321
        }
        data_loader = get_loader('pascal_context')
        data_path = get_data_path('pascal_context')
        train_dataset = data_loader(data_path, split='train', mode='train',
                                    **data_kwargs)
    elif args.dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        data_aug = Compose([
            RandomCrop_city((input_size[0], input_size[1])),
            RandomHorizontallyFlip()
        ])
        train_dataset = data_loader(data_path, is_transform=True,
                                    img_size=(input_size[0], input_size[1]),
                                    augmentations=data_aug)
    elif args.dataset == 'ade20k':
        train_dataset = ADE20K(mode='train', crop_size=input_size)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    # ---- labeled / unlabeled / gt loaders ----
    if args.labeled_ratio is None:
        # Fully-supervised setting: every loader draws from the whole dataset.
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, shuffle=True,
                                      num_workers=4, pin_memory=True,
                                      drop_last=True)
        trainloader_gt = data.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True, num_workers=4,
                                         pin_memory=True, drop_last=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size_unlab,
                                             shuffle=True, num_workers=4,
                                             pin_memory=True, drop_last=True)
        trainloader_remain_iter = iter(trainloader_remain)
    else:
        # Semi-supervised split: first `partial_size` ids are labeled.
        partial_size = int(args.labeled_ratio * train_dataset_size)
        print('labeled data: ', partial_size)
        print('unlabeled data: ', train_dataset_size - partial_size)
        if args.split_id is not None:
            train_ids = pickle.load(open(args.split_id, 'rb'))
            print('loading train ids from {}'.format(args.split_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)
            pickle.dump(
                train_ids,
                open(os.path.join(args.checkpoint_dir, 'train_voc_split.pkl'),
                     'wb'))
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler, num_workers=4,
                                      pin_memory=True, drop_last=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size_unlab,
                                             sampler=train_remain_sampler,
                                             num_workers=4, pin_memory=True,
                                             drop_last=True)
        trainloader_gt = data.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=4, pin_memory=True,
                                         drop_last=True)
        trainloader_remain_iter = iter(trainloader_remain)
    print('train dataloader created!')
    trainloader_iter = iter(trainloader)
    trainloader_gt_iter = iter(trainloader_gt)

    # ---- validation loader + matching upsampler ----
    if args.dataset == 'pascal_voc':
        valloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list,
                                               crop_size=(505, 505),
                                               mean=IMG_MEAN, scale=False,
                                               mirror=False),
                                    batch_size=1, shuffle=False,
                                    pin_memory=True)
        interp_val = nn.Upsample(size=(505, 505), mode='bilinear',
                                 align_corners=True)
    elif args.dataset == 'cityscapes':
        val_dataset = data_loader(data_path, img_size=(512, 1024),
                                  is_transform=True, split='val')
        valloader = data.DataLoader(val_dataset, batch_size=args.batch_size,
                                    shuffle=False, pin_memory=True)
        interp_val = nn.Upsample(size=(512, 1024), mode='bilinear',
                                 align_corners=True)
    elif args.dataset == 'ade20k':
        val_dataset = ADE20K(mode='val', crop_size=(505, 505))
        valloader = data.DataLoader(val_dataset, batch_size=args.batch_size,
                                    shuffle=False, num_workers=4,
                                    pin_memory=True, drop_last=True)
        interp_val = nn.Upsample(size=(505, 505), mode='bilinear',
                                 align_corners=True)
    print('val dataloader created!')

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.num_steps,
                                  eta_min=args.learning_rate / args.eta_min_factor)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    scheduler_D = CosineAnnealingLR(optimizer_D, T_max=args.num_steps,
                                    eta_min=args.learning_rate_D / args.eta_min_factor)
    optimizer_D.zero_grad()

    # Upsample predictions back to the crop resolution before loss/concat.
    interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear',
                         align_corners=True)

    # labels for adversarial training (the per-batch targets are rebuilt below)
    pred_label = 0
    gt_label = 1
    y_real_, y_fake_ = Variable(torch.ones(args.batch_size, 1).cuda()), Variable(
        torch.zeros(args.batch_size, 1).cuda())

    losses_ce = AverageMeter()
    losses_st = AverageMeter()
    losses_S = AverageMeter()
    losses_D = AverageMeter()
    losses_fm = AverageMeter()
    counts = AverageMeter()

    for i_iter in range(args.num_steps):
        model.train()
        loss_ce_value = 0
        loss_D_value = 0
        loss_fm_value = 0
        loss_S_value = 0
        #args.threshold_st = adjust_threshold_st(i_iter)
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train Segmentation Network
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # training loss for labeled data only
        try:
            batch = next(trainloader_iter)
        except StopIteration:
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)
        images, labels, _, _, _ = batch
        images = images.cuda()
        pred = interp(model(images))
        loss_ce = loss_calc(pred, labels, args.gpu)  # Cross entropy loss for labeled data

        # training loss for remaining unlabeled data
        try:
            batch_remain = next(trainloader_remain_iter)
        except StopIteration:
            trainloader_remain_iter = iter(trainloader_remain)
            batch_remain = next(trainloader_remain_iter)
        images_remain, _, _, _, _ = batch_remain
        images_remain = Variable(images_remain).cuda(args.gpu)
        pred_remain = interp(model(images_remain))

        # Min-max normalize the images before concatenating with predictions.
        images_remain = (images_remain - torch.min(images_remain)) / (
            torch.max(images_remain) - torch.min(images_remain))
        pred_cat = torch.cat((F.softmax(pred_remain, dim=1), images_remain),
                             dim=1)
        # D predicts the real/fake score and a feature map for the FM loss.
        D_out_z, D_out_y_pred = model_D(pred_cat)

        # find predicted segmentation maps above threshold
        pred_sel, labels_sel, count = find_good_maps(D_out_z, pred_remain)

        # training loss on above threshold segmentation predictions (Cross Entropy Loss)
        if count > 0 and i_iter > 0:
            loss_st = loss_calc(pred_sel, labels_sel, args.gpu)
            losses_st.update(loss_st.item())
        else:
            loss_st = 0.0

        # Concatenate input images and ground-truth maps for the D 'real' input.
        try:
            batch_gt = next(trainloader_gt_iter)
        except StopIteration:
            trainloader_gt_iter = iter(trainloader_gt)
            batch_gt = next(trainloader_gt_iter)
        images_gt, labels_gt, _, _, _ = batch_gt
        # Converts ground truth segmentation into 'num_classes' segmentation maps.
        D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
        images_gt = images_gt.cuda()
        # BUGFIX: normalize images_gt by its own min/max (the original divided
        # by the range of the labeled batch `images`).
        images_gt = (images_gt - torch.min(images_gt)) / (
            torch.max(images_gt) - torch.min(images_gt))
        D_gt_v_cat = torch.cat((D_gt_v, images_gt), dim=1)
        D_out_z_gt, D_out_y_gt = model_D(D_gt_v_cat)

        # L1 loss for Feature Matching Loss
        loss_fm = torch.mean(
            torch.abs(torch.mean(D_out_y_gt, 0) - torch.mean(D_out_y_pred, 0)))

        if count > 0 and i_iter > 0:
            # if any good predictions found for self-training loss
            loss_S = loss_ce + args.lambda_fm * loss_fm + args.lambda_st * loss_st
        else:
            loss_S = loss_ce + args.lambda_fm * loss_fm
        loss_S.backward()
        # BUGFIX: detach via .item() so the counter is a float like the others.
        loss_fm_value += args.lambda_fm * loss_fm.item()
        loss_ce_value += loss_ce.item()
        loss_S_value += loss_S.item()

        # train D
        for param in model_D.parameters():
            param.requires_grad = True

        # train with pred; detach stops gradients flowing back into G.
        pred_cat = pred_cat.detach()
        D_out_z, _ = model_D(pred_cat)
        y_fake_ = Variable(torch.zeros(D_out_z.size(0), 1).cuda())
        loss_D_fake = criterion(D_out_z, y_fake_)

        # train with gt
        D_out_z_gt, _ = model_D(D_gt_v_cat)
        y_real_ = Variable(torch.ones(D_out_z_gt.size(0), 1).cuda())
        loss_D_real = criterion(D_out_z_gt, y_real_)

        loss_D = (loss_D_fake + loss_D_real) / 2.0
        loss_D.backward()
        loss_D_value += loss_D.item()

        optimizer.step()
        #scheduler.step()
        optimizer_D.step()
        #scheduler_D.step()

        losses_ce.update(loss_ce.item())
        losses_S.update(loss_S.item())
        losses_D.update(loss_D.item())
        losses_fm.update(loss_fm.item())
        counts.update(count)

        if i_iter % 10 == 0:
            # Flush window averages to tensorboard, then reset the meters.
            log_idx = i_iter / 10
            args.writer.add_scalar('train/1.train_loss_ce', losses_ce.avg, log_idx)
            args.writer.add_scalar('train/2.train_loss_st', losses_st.avg, log_idx)
            args.writer.add_scalar('train/3.train_loss_fm', losses_fm.avg, log_idx)
            args.writer.add_scalar('train/4.train_loss_S', losses_S.avg, log_idx)
            args.writer.add_scalar('train/5.train_loss_D', losses_D.avg, log_idx)
            args.writer.add_scalar('train/6.count', counts.avg, log_idx)
            args.writer.add_scalar('train/7.lr',
                                   optimizer.param_groups[0]['lr'], log_idx)
            losses_ce = AverageMeter()
            losses_st = AverageMeter()
            losses_S = AverageMeter()
            losses_D = AverageMeter()
            losses_fm = AverageMeter()
            counts = AverageMeter()

        print(
            'iter = {0:8d}/{1:8d}, loss_ce = {2:.3f}, loss_fm = {3:.3f}, loss_S = {4:.3f}, loss_D = {5:.3f}'
            .format(i_iter, args.num_steps, loss_ce_value, loss_fm_value,
                    loss_S_value, loss_D_value))

        if i_iter % 200 == 0:
            miou_val, loss_val = validate(valloader, interp_val, model)
            print('miou_val: ', miou_val, ' loss_val; ', loss_val)
            #mious.update(miou_val)
            #losses_val.update(loss_val)
            args.writer.add_scalar('val/1.val_miou', miou_val, i_iter / 1000)
            args.writer.add_scalar('val/2.val_loss', loss_val, i_iter / 1000)
            #mious = AverageMeter()
            #losses_val = AverageMeter()

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('saving checkpoint ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                os.path.join(args.checkpoint_dir,
                             'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    # `start` is assumed to be a module-level timestamp set at startup.
    print(end - start, 'seconds')
def train(train_dataloader, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch over ``train_dataloader``.

    Tracks data/batch timing, running loss and top-1/top-5 precision, prints
    a progress line every 5 batches, and saves a checkpoint when the epoch
    finishes.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # enable training-mode behavior (dropout, batchnorm updates)
    model.train()

    tick = time.time()
    for i, (inputs, targets) in enumerate(train_dataloader):
        # time spent waiting on the data pipeline
        data_time.update(time.time() - tick)

        inputs, targets = inputs.to(device), targets.to(device)

        # forward pass and loss
        output = model(inputs)
        loss = criterion(output, targets)

        # bookkeeping: accuracy and loss meters, weighted by batch size
        prec1, prec5 = accuracy(output, targets, topk=(1, 5))
        n = inputs.size(0)
        losses.update(loss.item(), n)
        top1.update(prec1, n)
        top5.update(prec5, n)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # wall-clock time for the full batch (data + compute)
        batch_time.update(time.time() - tick)
        tick = time.time()

        if i % 5 == 0:
            print(f"Epoch [{epoch + 1}] [{i}/{len(train_dataloader)}]\t"
                  f"Time {data_time.val:.3f} ({data_time.avg:.3f})\t"
                  f"Loss {loss.item():.4f}\t"
                  f"Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
                  f"Prec@5 {top5.val:.3f} ({top5.avg:.3f})",
                  end="\r")

    # persist the epoch checkpoint
    torch.save(model.state_dict(),
               f"./checkpoints/{opt.datasets}_epoch_{epoch + 1}.pth")
def train_with_clustering(save_folder, tmp_seg_folder, startnet, args):
    """Self-supervised segmentation training driven by deep clustering.

    Repeatedly: extract features with the current network, k-means-cluster
    them (saving centroids to HDF5), then train the network to predict the
    cluster assignments as pseudo segmentation labels until the next
    re-clustering interval.

    Fixes vs. the original:
      * the periodic log line used a 6-placeholder %-format string with only
        4 arguments (TypeError at runtime); the corr/feature loss meters are
        now supplied.
      * removed ``corr_loader_train = input_transform(corr_loader_train)``,
        which replaced the DataLoader with the result of applying an image
        transform to it (the dataset already applies ``input_transform``).
      * ``CrossEntropyLoss(reduction='elementwise_mean')`` updated to the
        current name ``'mean'`` (same semantics; old name removed from torch).
    """
    print(save_folder.split('/')[-1])
    skip_clustering = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)
    check_mkdir(tmp_seg_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network(
        n_classes=args['n_clusters'],
        for_clustering=True,
        output_features=True,
        use_original_base=args['use_original_base']).to(device)
    state_dict = torch.load(startnet)
    if 'resnet101' in startnet:
        load_resnet101_weights(net, state_dict)
    else:
        # needed since we slightly changed the structure of the network in pspnet
        state_dict = rename_keys_to_match(state_dict)
        # different amount of classes
        init_last_layers(state_dict, args['n_clusters'])
        net.load_state_dict(state_dict)  # load original weights

    start_iter = 0
    args['best_record'] = {
        'iter': 0,
        'val_loss_feat': 1e10,
        'val_loss_out': 1e10,
        'val_loss_cluster': 1e10
    }

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'pola':
        corr_set_config = data_configs.PolaConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    elif args['corr_set'] == 'both':
        # NOTE(review): these two configs are never used afterwards, and
        # corr_set_config is unbound in this branch — confirm intent.
        corr_set_config1 = data_configs.CmuConfig()
        corr_set_config2 = data_configs.RobotcarConfig()

    ref_image_lists = corr_set_config.reference_image_list
    input_transform = model_config.input_transform
    scales = [0, 1, 2, 3]

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomCrop(713)
    ])
    sliding_crop = joint_transforms.SlidingCrop(713, 2 / 3., 255)

    corr_set_train = Poladata.MonoDataset(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        im_file_ending=".jpg",
        id_to_trainid=None,
        joint_transform=train_joint_transform,
        sliding_crop=sliding_crop,
        transform=input_transform,
        target_transform=None,
        transform_before_sliding=None)
    corr_loader_train = DataLoader(corr_set_train,
                                   batch_size=1,
                                   num_workers=args['n_workers'],
                                   shuffle=True)

    seg_loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')

    # Optimizer setup: biases get 2x lr and no weight decay (PSPNet convention).
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr': 2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    # Clustering
    deepcluster = clustering.Kmeans(args['n_clusters'])
    if skip_clustering:
        # NOTE(review): cluster_centroids is unbound here unless clustering
        # already ran; skip_clustering is hard-coded False above.
        deepcluster.set_index(cluster_centroids)

    # Record the run configuration and open a line-buffered log.
    open(os.path.join(save_folder, str(datetime.datetime.now()) + '.txt'),
         'w').write(str(args) + '\n\n')
    f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)

    val_iter = 0
    curr_iter = start_iter
    while curr_iter <= args['max_iter']:
        # ---- feature extraction + clustering phase ----
        net.eval()
        net.output_features = True
        print(
            'ici on a la len de la ref im list --------------------------------------------------------'
        )
        print(len(ref_image_lists))
        features = extract_features_for_reference_nocorr(
            net, model_config, corr_set_train, 10,
            max_num_features_per_image=args['max_features_per_image'])
        cluster_features = np.vstack(features)
        del features

        # cluster the features
        cluster_indices, clustering_loss, cluster_centroids, pca_info = deepcluster.cluster_imfeatures(
            cluster_features, verbose=True, use_gpu=False)

        # save cluster centroids
        h5f = h5py.File(
            os.path.join(save_folder, 'centroids_%d.h5' % curr_iter), 'w')
        h5f.create_dataset('cluster_centroids', data=cluster_centroids)
        h5f.create_dataset('pca_transform_Amat', data=pca_info[0])
        h5f.create_dataset('pca_transform_bvec', data=pca_info[1])
        h5f.close()

        # Print distribution of clusters
        cluster_distribution, _ = np.histogram(
            cluster_indices,
            bins=np.arange(args['n_clusters'] + 1),
            density=True)
        str2write = 'cluster distribution ' + \
            np.array2string(cluster_distribution, formatter={
                'float_kind': '{0:.8f}'.format}).replace('\n', ' ')
        print(str2write)
        f_handle.write(str2write + "\n")

        # set last layer weight to a normal distribution
        reinit_last_layers(net)

        # make a copy of current network state to do cluster assignment
        net_for_clustering = copy.deepcopy(net)

        # Poly learning-rate decay based on overall progress.
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        net.train()
        freeze_bn(net)
        net.output_features = False
        cluster_training_count = 0

        # Train using the training correspondence set
        corr_train_loss = AverageMeter()
        seg_train_loss = AverageMeter()
        feature_train_loss = AverageMeter()

        while cluster_training_count < args[
                'cluster_interval'] and curr_iter <= args['max_iter']:
            # First extract cluster labels using saved network checkpoint
            print(
                'on rentre dans la boucle extract cluster_______________________________________________'
            )
            net.to("cpu")
            net_for_clustering.to(device)
            net_for_clustering.eval()
            net_for_clustering.output_features = True

            data_samples = []
            extract_label_count = 0
            while (extract_label_count < args['chunk_size']) and (
                    cluster_training_count + extract_label_count <
                    args['cluster_interval']
            ) and (val_iter + extract_label_count < args['val_interval']) and (
                    extract_label_count + curr_iter <= args['max_iter']):
                print(
                    f'la valeur de corr loader train es de {corr_loader_train} lors de l iteration : {curr_iter}'
                )
                img_ref, img_other, pts_ref, pts_other, _ = next(
                    iter(corr_loader_train))

                # Transfer data to device
                img_ref = img_ref.to(device)
                with torch.no_grad():
                    features = net_for_clustering(img_ref)

                # assign feature to clusters for entire patch
                output = features.cpu().numpy()
                output_flat = output.reshape(
                    (output.shape[0], output.shape[1], -1))
                cluster_image = np.zeros(
                    (output.shape[0], output.shape[2], output.shape[3]),
                    dtype=np.int64)
                for b in range(output_flat.shape[0]):
                    out_f = output_flat[b]
                    out_f2, _ = preprocess_features(np.swapaxes(out_f, 0, 1),
                                                    pca_info=pca_info)
                    cluster_labels = deepcluster.assign(out_f2)
                    cluster_image[b] = cluster_labels.reshape(
                        (output.shape[2], output.shape[3]))
                cluster_image = torch.from_numpy(cluster_image).to(device)

                # assign cluster to correspondence positions
                cluster_labels = assign_cluster_ids_to_correspondence_points(
                    features,
                    pts_ref, (deepcluster, pca_info),
                    inds_other=pts_other,
                    orig_im_size=(713, 713))

                # Transfer data to cpu
                img_ref = img_ref.cpu()
                cluster_labels = [p.cpu() for p in cluster_labels]
                cluster_image = cluster_image.cpu()
                data_samples.append((img_ref, cluster_labels, cluster_image))
                extract_label_count += 1

            net_for_clustering.to("cpu")
            net.to(device)

            for data_sample in data_samples:
                img_ref, cluster_labels, cluster_image = data_sample

                # Transfer data to device
                img_ref = img_ref.to(device)
                cluster_labels = [p.to(device) for p in cluster_labels]
                cluster_image = cluster_image.to(device)

                optimizer.zero_grad()
                outputs_ref, aux_ref = net(img_ref)
                # Main + auxiliary (0.4-weighted) segmentation losses against
                # the cluster pseudo-labels.
                seg_main_loss = seg_loss_fct(outputs_ref, cluster_image)
                seg_aux_loss = seg_loss_fct(aux_ref, cluster_image)
                loss = args['seg_loss_weight'] * \
                    (seg_main_loss + 0.4 * seg_aux_loss)
                loss.backward()
                optimizer.step()
                cluster_training_count += 1

                if type(seg_main_loss) == torch.Tensor:
                    seg_train_loss.update(seg_main_loss.item(), 1)

                ####################################################################################################
                # LOGGING ETC
                ####################################################################################################
                curr_iter += 1
                val_iter += 1

                writer.add_scalar('train_seg_loss', seg_train_loss.avg,
                                  curr_iter)
                writer.add_scalar('lr', optimizer.param_groups[1]['lr'],
                                  curr_iter)

                if (curr_iter + 1) % args['print_freq'] == 0:
                    # BUGFIX: supply all six values the format string expects.
                    str2write = '[iter %d / %d], [train seg loss %.5f], [train corr loss %.5f], [train feature loss %.5f]. [lr %.10f]' % (
                        curr_iter + 1, args['max_iter'], seg_train_loss.avg,
                        corr_train_loss.avg, feature_train_loss.avg,
                        optimizer.param_groups[1]['lr'])
                    print(str2write)
                    f_handle.write(str2write + "\n")

                if curr_iter > args['max_iter']:
                    break

    # Post training
    f_handle.close()
    writer.close()
def train(train_loader, L, D, T, optim_L, optim_D, optim_T, epoch, device, args):
    """One epoch of adversarial 2D-to-3D pose lifting training.

    L lifts 2D poses to per-joint depths, D discriminates rotated
    reprojections from real 2D poses, and T discriminates temporal pose
    differences. Returns the epoch averages
    ``(loss_2d, loss_3d, loss_adv, loss_t)``.

    Fix vs. the original: ``D_fake_score`` has ``batch_sz * (args.length - 1)``
    rows (the fake pose is repeated), but the BCE target was built with only
    ``batch_sz`` rows — ``nn.BCELoss`` requires matching shapes, so this
    failed whenever ``args.length != 2``. The target is now sized from the
    score tensor itself (identical behavior when ``args.length == 2``).
    """
    L.train()
    D.train()
    T.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_2ds = AverageMeter()
    loss_3ds = AverageMeter()
    loss_advs = AverageMeter()
    loss_ts = AverageMeter()
    end = time.time()
    L2_loss = nn.MSELoss(reduction='mean')
    BCE_loss = nn.BCELoss(reduction='mean')
    for batch_idx, (xy, X, scale) in enumerate(train_loader):
        data_time.update(time.time() - end)

        # ---- train D ----
        D.zero_grad()
        batch_sz = xy.size(0)
        pose_2d = xy[:, 0].to(device)   # (bs, 17*2)
        xy_real = xy[:, 1:].to(device)  # (bs, length-1, 17*2)
        z_pred = L(pose_2d)             # (bs, 17) predicted per-joint depths

        # random rotation angle per sample
        theta = np.random.uniform(-np.pi, np.pi, batch_sz).astype(np.float32)
        cos_theta = np.cos(theta)[:, None]
        sin_theta = np.sin(theta)[:, None]
        cos_theta = torch.from_numpy(cos_theta).to(device)
        sin_theta = torch.from_numpy(sin_theta).to(device)

        # Rotate the lifted skeleton and re-project to 2D.
        x = pose_2d[:, 0::2]
        y = pose_2d[:, 1::2]
        new_x = x * cos_theta + z_pred * sin_theta  # create projection
        trans_3d_z = -x * sin_theta + z_pred * cos_theta
        xy_fake = torch.cat((new_x[:, :, None], y[:, :, None]), dim=2)
        xy_fake = xy_fake.view(batch_sz, -1)  # (bs, 17*2)
        trans_3d_1 = torch.cat(
            (new_x[:, :, None], y[:, :, None], trans_3d_z[:, :, None]), dim=2)
        trans_3d_1 = trans_3d_1.view(batch_sz, -1)  # (bs, 17*3)

        D_real_score = D(xy_real.view(batch_sz * (args.length - 1), -1))
        y_real_ = torch.ones(batch_sz * (args.length - 1), 1).to(device)
        D_real_loss = BCE_loss(D_real_score, y_real_)

        fake_pose_2d_t_repeat = xy_fake.repeat(args.length - 1, 1)
        D_fake_score = D(fake_pose_2d_t_repeat)
        # BUGFIX: target sized to match the repeated fake batch.
        y_fake_ = torch.zeros(D_fake_score.size(0), 1).to(device)
        D_fake_loss = BCE_loss(D_fake_score, y_fake_)

        D_train_loss = D_real_loss + D_fake_loss
        D_train_loss.backward(retain_graph=True)
        optim_D.step()

        # ---- train T (temporal discriminator) ----
        T.zero_grad()
        pose_next = xy[:, 1].to(device)  # (bs, 17*2)
        # Predict next pose depth and re-project with the same rotation.
        z_pred_next = L(pose_next)       # (bs, 17)
        x_next = pose_next[:, 0::2]
        y_next = pose_next[:, 1::2]
        new_x_next = x_next * cos_theta + z_pred_next * sin_theta
        xy_fake_next = torch.cat((new_x_next[:, :, None], y_next[:, :, None]),
                                 dim=2)
        xy_fake_next = xy_fake_next.view(batch_sz, -1)  # (bs, 17*2)

        T_real_score = T(pose_2d - pose_next)
        y_real_ = torch.ones(batch_sz, 1).to(device)
        T_real_loss = BCE_loss(T_real_score, y_real_)
        T_fake_score = T(xy_fake - xy_fake_next)
        y_fake_ = torch.zeros(batch_sz, 1).to(device)
        T_fake_loss = BCE_loss(T_fake_score, y_fake_)

        T_train_loss_gp = T_fake_loss + T_real_loss
        T_train_loss_gp.backward(retain_graph=True)
        optim_T.step()

        # ---- train L (lifter) ----
        L.zero_grad()
        z_pred_fake_3d = L(xy_fake)  # (bs, 17)
        # Inverse 3D transformation: rotate the fake view back and compare.
        cos_theta_inv = np.cos(-theta)[:, None]
        sin_theta_inv = np.sin(-theta)[:, None]
        cos_theta_inv = torch.from_numpy(cos_theta_inv).to(device)
        sin_theta_inv = torch.from_numpy(sin_theta_inv).to(device)
        x_fake = xy_fake[:, 0::2]
        y_fake = xy_fake[:, 1::2]
        recover_new_x = x_fake * cos_theta_inv + z_pred_fake_3d * sin_theta_inv
        recover_xy = torch.cat((recover_new_x[:, :, None], y_fake[:, :, None]),
                               dim=2)
        recover_xy = recover_xy.view(batch_sz, -1)  # (bs, 17*2)
        trans_3d_2 = torch.cat(
            (x_fake[:, :, None], y_fake[:, :, None], z_pred_fake_3d[:, :, None]),
            dim=2)
        trans_3d_2 = trans_3d_2.view(batch_sz, -1)

        # Cycle consistency in 2D and 3D, plus adversarial terms.
        loss_2d = L2_loss(recover_xy, pose_2d)
        loss_3d = L2_loss(trans_3d_1, trans_3d_2)
        D_result = D(xy_fake)
        y_ = torch.ones(batch_sz, 1).to(device)
        loss_adv = BCE_loss(D_result, y_)
        T_result = T(xy_fake - xy_fake_next)
        loss_t = BCE_loss(T_result, y_)

        L_train_loss = loss_adv + args.weight_2d * loss_2d + args.weight_3d * loss_3d + args.weight_wt * loss_t
        L_train_loss.backward(retain_graph=True)
        optim_L.step()

        loss_2ds.update(loss_2d.item(), batch_sz)
        loss_3ds.update(loss_3d.item(), batch_sz)
        loss_advs.update(loss_adv.item(), batch_sz)
        loss_ts.update(loss_t.item(), batch_sz)
        batch_time.update(time.time() - end)
        end = time.time()

        if args.verbose:
            if batch_idx % 5 == 0:
                outstr = '[{batch}/{size}], Data: {data:.3f}s | Batch: {bt:.3f}s | loss_2d: {l2d:.6f} | loss_3d: {l3d:.6f} | loss_adv: {ladv:.6f} | loss_t: {lt:.6f}'.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    l2d=loss_2ds.val,
                    l3d=loss_3ds.val,
                    ladv=loss_advs.val,
                    lt=loss_ts.val,
                )
                print(outstr)
    return loss_2ds.avg, loss_3ds.avg, loss_advs.avg, loss_ts.avg
def train(train_loader, L, D, T, optim_L, optim_D, optim_T, epoch, device, args):
    """One epoch of WGAN-GP-style adversarial 2D-to-3D pose lifting training.

    Variant of the BCE version above: D and T output unbounded critic scores
    (mean score used directly, Wasserstein-style) with a gradient penalty,
    and the lift/rotate/project geometry is delegated to ``lift_proj``.
    Returns the epoch averages ``(loss_2d, loss_3d, loss_adv, loss_t)``.
    """
    L.train()
    D.train()
    T.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_2ds = AverageMeter()
    loss_3ds = AverageMeter()
    loss_advs = AverageMeter()
    loss_ts = AverageMeter()
    end = time.time()
    L2_loss = nn.MSELoss(reduction='mean')
    # NOTE(review): BCE_loss is created but unused in this WGAN-GP variant.
    BCE_loss = nn.BCELoss(reduction='mean')
    for batch_idx, (xy, X, ls) in enumerate(train_loader):
        data_time.update(time.time() - end)
        # ---- train D (critic) ----
        D.zero_grad()
        batch_sz = xy.size(0)
        pose_2d = xy[:, 0].to(device)   # (bs, 17*2)
        xy_real = xy[:, 1:].to(device)  # (bs, length-1, 17*2)
        # Lift to 3D, apply a random rotation, re-project to a fake 2D pose.
        fake_pose_2d_t, _, trans_3d_1, rot_mat = lift_proj(
            L, pose_2d, batch_sz, device, None, True)
        D_real_loss = D(xy_real.view(batch_sz * (args.length - 1),
                                     -1)).mean()  # (bs*(length-1),17*2)
        #y_real_ = torch.ones(batch_sz*(args.length-1), 1).to(device)
        #D_real_loss = BCE_loss(D_real_score, y_real_)
        fake_pose_2d_t_repeat = fake_pose_2d_t.repeat(args.length - 1, 1)
        D_fake_loss = D(fake_pose_2d_t_repeat).mean()
        #y_fake_ = torch.zeros(batch_sz, 1).to(device)
        #D_fake_loss = BCE_loss(D_fake_score, y_fake_)
        gradient_penalty_D = compute_gradient_penalty(
            D, xy_real.view(batch_sz * (args.length - 1), -1),
            fake_pose_2d_t_repeat, args.lamda_gp, device)
        # Wasserstein critic objective: fake - real + gradient penalty.
        D_train_loss = D_fake_loss - D_real_loss + gradient_penalty_D
        D_train_loss.backward(retain_graph=True)
        optim_D.step()
        # ---- train T (temporal critic) ----
        T.zero_grad()
        xy_pose_next = xy[:, 1].to(device)
        # Re-project the next frame with the SAME rotation matrix.
        fake_pose_2d_next_t, _, _, _ = lift_proj(L, xy_pose_next, batch_sz,
                                                 device, rot_mat, True)
        T_real_loss = T(pose_2d - xy_pose_next).mean()
        #y_real_ = torch.ones(batch_sz, 1).to(device)
        #T_real_loss = BCE_loss(T_real_score, y_real_)
        T_fake_loss = T(fake_pose_2d_t - fake_pose_2d_next_t).mean()
        #y_fake_ = torch.zeros(batch_sz, 1).to(device)
        #T_fake_loss = BCE_loss(T_fake_score, y_fake_)
        gradient_penalty_T = compute_gradient_penalty(
            T, pose_2d - xy_pose_next, fake_pose_2d_t - fake_pose_2d_next_t,
            args.lamda_gp, device)
        T_train_loss = T_fake_loss - T_real_loss + gradient_penalty_T
        T_train_loss.backward(retain_graph=True)
        optim_T.step()
        # ---- train L (lifter) ----
        L.zero_grad()
        # Inverse rotation closes the 2D/3D consistency cycle.
        rot_mat_inv = rot_mat.inverse()
        recon_pose_2d_t, trans_3d_2, _, _ = lift_proj(L, fake_pose_2d_t,
                                                      batch_sz, device,
                                                      rot_mat_inv, False)
        # print(trans_3d_1[0], trans_3d_2[0])
        loss_2d = L2_loss(recon_pose_2d_t, pose_2d)
        loss_3d = L2_loss(trans_3d_1, trans_3d_2)
        #D_result = D(fake_pose_2d_t)
        #y_ = torch.ones(batch_sz, 1).to(device)
        #loss_adv = BCE_loss(D_result, y_)
        # Generator maximizes the critic scores (hence the negation).
        loss_adv = -D(fake_pose_2d_t).mean()
        #T_result = T(fake_pose_2d_t - fake_pose_2d_next_t)
        #loss_t = BCE_loss(T_result, y_)
        loss_t = -T(fake_pose_2d_t - fake_pose_2d_next_t).mean()
        L_train_loss = loss_adv + args.weight_2d * loss_2d + args.weight_3d * loss_3d + args.weight_wt * loss_t
        L_train_loss.backward(retain_graph=True)
        optim_L.step()
        loss_2ds.update(loss_2d.item(), batch_sz)
        loss_3ds.update(loss_3d.item(), batch_sz)
        loss_advs.update(loss_adv.item(), batch_sz)
        loss_ts.update(loss_t.item(), batch_sz)
        batch_time.update(time.time() - end)
        end = time.time()
        if args.verbose:
            if batch_idx % 5 == 0:
                outstr = '[{batch}/{size}], Data: {data:.3f}s | Batch: {bt:.3f}s | loss_2d: {l2d:.6f} | loss_3d: {l3d:.6f} | loss_adv: {ladv:.6f} | loss_t: {lt:.6f}'.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    l2d=loss_2ds.val,
                    l3d=loss_3ds.val,
                    ladv=loss_advs.val,
                    lt=loss_ts.val,
                )
                print(outstr)
    return loss_2ds.avg, loss_3ds.avg, loss_advs.avg, loss_ts.avg
def train_cae(trainloader, model, class_name, testloader, y_train, device, args):
    """ model train function.

    Trains a convolutional autoencoder with per-pixel MSE reconstruction loss,
    periodically saves reconstruction images and evaluates anomaly-detection
    AUROC (reconstruction-based and DEC-cluster-based) on the training labels.

    :param trainloader:
    :param model:
    :param class_name:
    :param testloader:
    :param y_train: numpy array, sample normal/abnormal labels, [1 1 1 1 0 0] like,
        original sample size.
    :param device: cpu or gpu:0/1/...
    :param args:
    :return:
    """
    global_step = 0
    losses = AverageMeter()
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(1, args.epochs + 1):
        model.train()
        # ETA estimate from the running average epoch duration.
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins,
                                                          need_secs)
        print('{:3d}/{:3d} ----- {:s} {:s}'.format(epoch, args.epochs,
                                                   time_string(), need_time))
        mse = nn.MSELoss(reduction='mean')
        # default
        # Manual step schedule: halve lr every args.lr_schedule epochs (from 0.1).
        lr = 0.1 / pow(2, np.floor(epoch / args.lr_schedule))
        logger.add_scalar(class_name + "/lr", lr, epoch)
        # NOTE(review): the optimizer is re-created every epoch, which resets
        # momentum/Adam state; the Adam branch also ignores the computed `lr`
        # (uses Adam's default) — confirm both are intentional.
        if args.optimizer == 'SGD':
            optimizer = optim.SGD(model.parameters(), lr=lr,
                                  weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), eps=1e-7,
                                   weight_decay=0.0005)
        for batch_idx, (input, _, _) in enumerate(trainloader):
            optimizer.zero_grad()
            input = input.to(device)
            _, output = model(input)
            loss = mse(input, output)  # reconstruction loss
            losses.update(loss.item(), 1)
            logger.add_scalar(class_name + '/loss', losses.avg, global_step)
            global_step = global_step + 1
            loss.backward()
            optimizer.step()
        # print losses
        print('Epoch: [{} | {}], loss: {:.4f}'.format(epoch, args.epochs,
                                                      losses.avg))
        # log images (uses the last batch's input/output from the loop above)
        if epoch % args.log_img_steps == 0:
            os.makedirs(os.path.join(RESULTS_DIR, class_name), exist_ok=True)
            fpath = os.path.join(RESULTS_DIR, class_name,
                                 'pretrain_epoch_' + str(epoch) + '.png')
            visualize(input, output, fpath, num=32)
        # test while training
        if epoch % args.log_auc_steps == 0:
            rep, losses_result = test(testloader, model, class_name, args, device,
                                      epoch)
            centroid = torch.mean(rep, dim=0, keepdim=True)
            # Min-max normalize reconstruction losses; low loss => high score.
            losses_result = losses_result - losses_result.min()
            losses_result = losses_result / (1e-8 + losses_result.max())
            scores = 1 - losses_result
            auroc_rec = roc_auc_score(y_train, scores)
            _, p = dec_loss_fun(rep, centroid)
            score_p = p[:, 0]
            auroc_dec = roc_auc_score(y_train, score_p)
            print("Epoch: [{} | {}], auroc_rec: {:.4f}; auroc_dec: {:.4f}".
                  format(epoch, args.epochs, auroc_rec, auroc_dec))
            logger.add_scalar(class_name + '/auroc_rec', auroc_rec, epoch)
            logger.add_scalar(class_name + '/auroc_dec', auroc_dec, epoch)
        # time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
def train(train_loader, model, criterion, optimizer, args):
    """Run one training epoch for the text/kernel segmentation network.

    Computes the PSE-style loss over text and kernel maps, optionally applies
    BN sparsity regularization, and logs running accuracy/IoU metrics.

    :param train_loader: yields (imgs, gt_texts, gt_kernels, training_masks).
    :param model: segmentation network; output channel 0 is the text map,
        channels 1: are the kernel maps.
    :param criterion: loss over (texts, gt_texts, kernels, gt_kernels, masks).
    :param optimizer: optimizer for model parameters.
    :param args: needs .sr_lr (enables updateBN when not None).
    :return: (avg loss, text Mean Acc, kernel Mean Acc, text Mean IoU,
        kernel Mean IoU) from the last computed scores.
    """
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    running_metric_text = runningScore(2)
    running_metric_kernel = runningScore(2)
    end = time.time()
    for batch_idx, (imgs, gt_texts, gt_kernels, training_masks) in enumerate(train_loader):
        data_time.update(time.time() - end)
        # Fix: torch.autograd.Variable is a deprecated no-op since PyTorch 0.4;
        # move tensors to the GPU directly.
        imgs = imgs.cuda()
        gt_texts = gt_texts.cuda()
        gt_kernels = gt_kernels.cuda()
        training_masks = training_masks.cuda()
        outputs = model(imgs)
        texts = outputs[:, 0, :, :]      # full text region map
        kernels = outputs[:, 1:, :, :]   # shrunk kernel maps
        loss = criterion(texts, gt_texts, kernels, gt_kernels, training_masks)
        losses.update(loss.item(), imgs.size(0))
        optimizer.zero_grad()
        loss.backward()
        if args.sr_lr is not None:
            # Sparsity regularization on BN scale factors (network slimming).
            updateBN(model, args)
        optimizer.step()
        score_text = cal_text_score(texts, gt_texts, training_masks,
                                    running_metric_text)
        score_kernel = cal_kernel_score(kernels, gt_kernels, gt_texts,
                                        training_masks, running_metric_kernel)
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 20 == 0:
            output_log = '({batch}/{size}) Batch: {bt:.3f}s | TOTAL: {total:.0f}min | ETA: {eta:.0f}min | Loss: {loss:.4f} | Acc_t: {acc: .4f} | IOU_t: {iou_t: .4f} | IOU_k: {iou_k: .4f}'.format(
                batch=batch_idx + 1,
                size=len(train_loader),
                bt=batch_time.avg,
                total=batch_time.avg * batch_idx / 60.0,
                eta=batch_time.avg * (len(train_loader) - batch_idx) / 60.0,
                loss=losses.avg,
                acc=score_text['Mean Acc'],
                iou_t=score_text['Mean IoU'],
                iou_k=score_kernel['Mean IoU'])
            print(output_log)
            sys.stdout.flush()
    return (losses.avg, score_text['Mean Acc'], score_kernel['Mean Acc'],
            score_text['Mean IoU'], score_kernel['Mean IoU'])
def evaluate(val_loader, model, criterion, test=None):
    '''
    Model evaluation.

    Runs the model over the validation set, accumulating average loss, top-1
    accuracy, and (when test is truthy) the full prediction/label arrays for
    confusion-matrix analysis.

    :param val_loader:
    :param model:
    :param criterion:
    :param test: when truthy, also return the accumulated predictions/labels.
    :return: (losses.avg, top1.avg) or, with test, additionally
        (predict_all, labels_all).
    '''
    global best_acc
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    #################
    # val the model
    #################
    model.eval()
    end = time.time()
    # Iterate over validation batches.
    ## progress-bar setup
    bar = Bar('Processing', max=len(val_loader))
    for batch_index, (inputs, targets) in enumerate(val_loader):
        data_time.update(time.time() - end)
        # move tensors to GPU if cuda is_available
        inputs, targets = inputs.to(device), targets.to(device)
        # forward pass
        outputs = model(inputs)
        # compute loss
        loss = criterion(outputs, targets)
        # compute accuracy and update meters
        # NOTE(review): topk=(1, 1) requests top-1 twice; the second value is
        # discarded — presumably to match accuracy()'s two-value return.
        prec1, _ = accuracy(outputs.data, targets.data, topk=(1, 1))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        # accumulate data for the confusion matrix
        targets = targets.data.cpu().numpy()  # ground-truth labels
        predic = torch.max(outputs.data, 1)[1].cpu().numpy()  # predicted labels
        labels_all = np.append(labels_all, targets)  # append to accumulators
        predict_all = np.append(predict_all, predic)
        ## pack the key metrics into the progress bar
        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f}'.format(
            batch=batch_index + 1,
            size=len(val_loader),
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            top1=top1.avg)
        bar.next()
    bar.finish()
    if test:
        return (losses.avg, top1.avg, predict_all, labels_all)
    else:
        return (losses.avg, top1.avg)
def train_with_correspondences(save_folder, startnet, args):
    """Train a PSPNet segmentation network jointly on segmentation datasets and
    a cross-sequence correspondence loss.

    Sets up Cityscapes (and optionally Vistas) plus an "extra" dataset derived
    from the correspondence set; each training iteration performs one
    segmentation update per loader and, once enabled, one correspondence
    update, with periodic logging and validation.

    :param save_folder: directory for snapshots, logs, and tensorboard events.
    :param startnet: path to the initial weights (used when starting fresh).
    :param args: dict of hyperparameters and paths (snapshot, corr_set, lr,
        max_iter, corr_loss_type, ... — see usages below).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network().to(device)
    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)
    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        # needed since we slightly changed the structure of the network in
        # pspnet
        state_dict = rename_keys_to_match(state_dict)
        net.load_state_dict(state_dict)  # load original weights
        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder, args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')
        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        # NOTE(review): eval() on file contents — fine for a trusted local
        # checkpoint dir, but consider ast.literal_eval for safety.
        args['best_record'] = eval(best_val_dict_str.rstrip())
    net.train()
    freeze_bn(net)  # BN statistics stay frozen throughout training

    # Data loading setup
    # NOTE(review): no else branch — any corr_set other than 'rc'/'cmu' leaves
    # corr_set_config undefined and raises NameError below; confirm inputs.
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()
    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])
    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform
    target_transform = extended_transforms.MaskToTensor()
    train_joint_transform_seg = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(713)
    ])
    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # keep list of segmentation loaders and validators
    seg_loaders = list()
    validators = list()

    # Correspondences
    corr_set = correspondences.Correspondences(
        corr_set_config.correspondence_path,
        corr_set_config.correspondence_im_path,
        input_size=(713, 713),
        mean_std=model_config.mean_std,
        input_transform=input_transform,
        joint_transform=train_joint_transform_corr)
    corr_loader = DataLoader(corr_set,
                             batch_size=args['train_batch_size'],
                             num_workers=args['n_workers'],
                             shuffle=True)

    # Cityscapes Training
    c_config = data_configs.CityscapesConfig()
    seg_set_cs = cityscapes.CityScapes(
        c_config.train_im_folder,
        c_config.train_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    seg_loader_cs = DataLoader(seg_set_cs,
                               batch_size=args['train_batch_size'],
                               num_workers=args['n_workers'],
                               shuffle=True)
    seg_loaders.append(seg_loader_cs)

    # Cityscapes Validation
    val_set_cs = cityscapes.CityScapes(
        c_config.val_im_folder,
        c_config.val_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    val_loader_cs = DataLoader(val_set_cs,
                               batch_size=1,
                               num_workers=args['n_workers'],
                               shuffle=False)
    validator_cs = Validator(val_loader_cs,
                             n_classes=c_config.n_classes,
                             save_snapshot=False,
                             extra_name_str='Cityscapes')
    validators.append(validator_cs)

    # Vistas Training and Validation
    if args['include_vistas']:
        v_config = data_configs.VistasConfig(
            use_subsampled_validation_set=True, use_cityscapes_classes=True)
        seg_set_vis = cityscapes.CityScapes(
            v_config.train_im_folder,
            v_config.train_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            joint_transform=train_joint_transform_seg,
            sliding_crop=None,
            transform=input_transform,
            target_transform=target_transform)
        seg_loader_vis = DataLoader(seg_set_vis,
                                    batch_size=args['train_batch_size'],
                                    num_workers=args['n_workers'],
                                    shuffle=True)
        seg_loaders.append(seg_loader_vis)
        val_set_vis = cityscapes.CityScapes(
            v_config.val_im_folder,
            v_config.val_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            sliding_crop=sliding_crop_im,
            transform=input_transform,
            target_transform=target_transform,
            transform_before_sliding=pre_validation_transform)
        val_loader_vis = DataLoader(val_set_vis,
                                    batch_size=1,
                                    num_workers=args['n_workers'],
                                    shuffle=False)
        validator_vis = Validator(val_loader_vis,
                                  n_classes=v_config.n_classes,
                                  save_snapshot=False,
                                  extra_name_str='Vistas')
        validators.append(validator_vis)
    else:
        seg_loader_vis = None
        map_validator = None

    # Extra Training (segmentation data from the correspondence dataset)
    extra_seg_set = cityscapes.CityScapes(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    extra_seg_loader = DataLoader(extra_seg_set,
                                  batch_size=args['train_batch_size'],
                                  num_workers=args['n_workers'],
                                  shuffle=True)
    seg_loaders.append(extra_seg_loader)

    # Extra Validation
    extra_val_set = cityscapes.CityScapes(
        corr_set_config.val_im_folder,
        corr_set_config.val_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    extra_val_loader = DataLoader(extra_val_set,
                                  batch_size=1,
                                  num_workers=args['n_workers'],
                                  shuffle=False)
    extra_validator = Validator(extra_val_loader,
                                n_classes=corr_set_config.n_classes,
                                save_snapshot=True,
                                extra_name_str='Extra')
    validators.append(extra_validator)

    # Loss setup
    if args['corr_loss_type'] == 'class':
        corr_loss_fct = CorrClassLoss(input_size=[713, 713])
    else:
        corr_loss_fct = FeatureLoss(
            input_size=[713, 713],
            loss_type=args['corr_loss_type'],
            feat_dist_threshold_match=args['feat_dist_threshold_match'],
            feat_dist_threshold_nomatch=args['feat_dist_threshold_nomatch'],
            n_not_matching=0)
    # NOTE(review): reduction='elementwise_mean' is the pre-1.0 PyTorch
    # spelling; newer versions require 'mean'.
    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='elementwise_mean',
        ignore_index=cityscapes.ignore_label).to(device)

    # Optimizer setup: biases get double lr and no weight decay.
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr': 2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)
    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    # Record the run's arguments, timestamped.
    open(os.path.join(save_folder, str(datetime.datetime.now()) + '.txt'),
         'w').write(str(args) + '\n\n')
    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    ##########################################################################
    #
    # MAIN TRAINING CONSISTS OF ALL SEGMENTATION LOSSES AND A CORRESPONDENCE LOSS
    #
    ##########################################################################
    softm = torch.nn.Softmax2d()
    val_iter = 0
    train_corr_loss = AverageMeter()
    train_seg_cs_loss = AverageMeter()
    train_seg_extra_loss = AverageMeter()
    train_seg_vis_loss = AverageMeter()
    seg_loss_meters = list()
    seg_loss_meters.append(train_seg_cs_loss)
    if args['include_vistas']:
        seg_loss_meters.append(train_seg_vis_loss)
    seg_loss_meters.append(train_seg_extra_loss)  # order matches seg_loaders
    curr_iter = start_iter

    for i in range(args['max_iter']):
        # Poly learning-rate decay.
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        #######################################################################
        # SEGMENTATION UPDATE STEP
        #######################################################################
        for si, seg_loader in enumerate(seg_loaders):
            # get segmentation training sample
            # NOTE(review): next(iter(loader)) re-creates the iterator each
            # time, so each step draws a fresh shuffled batch.
            inputs, gts = next(iter(seg_loader))
            slice_batch_pixel_size = inputs.size(0) * inputs.size(
                2) * inputs.size(3)
            inputs = inputs.to(device)
            gts = gts.to(device)
            optimizer.zero_grad()
            outputs, aux = net(inputs)
            main_loss = args['seg_loss_weight'] * seg_loss_fct(outputs, gts)
            aux_loss = args['seg_loss_weight'] * seg_loss_fct(aux, gts)
            loss = main_loss + 0.4 * aux_loss  # standard PSPNet aux weighting
            loss.backward()
            optimizer.step()
            seg_loss_meters[si].update(main_loss.item(),
                                       slice_batch_pixel_size)

        #######################################################################
        # CORRESPONDENCE UPDATE STEP
        #######################################################################
        if args['corr_loss_weight'] > 0 and args[
                'n_iterations_before_corr_loss'] < curr_iter:
            img_ref, img_other, pts_ref, pts_other, weights = next(
                iter(corr_loader))
            # Transfer data to device
            # img_ref is from the "good" sequence with generally better
            # segmentation results
            img_ref = img_ref.to(device)
            img_other = img_other.to(device)
            pts_ref = [p.to(device) for p in pts_ref]
            pts_other = [p.to(device) for p in pts_other]
            weights = [w.to(device) for w in weights]

            # Forward pass
            if args['corr_loss_type'] == 'hingeF':  # Works on features
                net.output_all = True
                with torch.no_grad():
                    output_feat_ref, aux_feat_ref, output_ref, aux_ref = net(
                        img_ref)
                output_feat_other, aux_feat_other, output_other, aux_other = net(
                    img_other
                )  # output1 must be last to backpropagate derivative correctly
                net.output_all = False
            else:  # Works on class probs
                with torch.no_grad():
                    output_ref, aux_ref = net(img_ref)
                if args['corr_loss_type'] != 'hingeF' and args[
                        'corr_loss_type'] != 'hingeC':
                    output_ref = softm(output_ref)
                    aux_ref = softm(aux_ref)
                # output1 must be last to backpropagate derivative correctly
                output_other, aux_other = net(img_other)
                if args['corr_loss_type'] != 'hingeF' and args[
                        'corr_loss_type'] != 'hingeC':
                    output_other = softm(output_other)
                    aux_other = softm(aux_other)

            # Correspondence filtering on the main outputs.
            pts_ref_orig, pts_other_orig, weights_orig, batch_inds_to_keep_orig = correspondences.refine_correspondence_sample(
                output_ref,
                output_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_ref_orig)
                if b.item() > 0
            ]
            pts_other_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_other_orig)
                if b.item() > 0
            ]
            weights_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, weights_orig)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                output_vals_ref = output_feat_ref[batch_inds_to_keep_orig]
                output_vals_other = output_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                output_vals_ref = output_ref[batch_inds_to_keep_orig]
                output_vals_other = output_other[batch_inds_to_keep_orig]

            # Correspondence filtering on the auxiliary outputs.
            pts_ref_aux, pts_other_aux, weights_aux, batch_inds_to_keep_aux = correspondences.refine_correspondence_sample(
                aux_ref,
                aux_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_ref_aux)
                if b.item() > 0
            ]
            pts_other_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_other_aux)
                if b.item() > 0
            ]
            weights_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, weights_aux)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                # NOTE(review): indexes with batch_inds_to_keep_orig (not
                # _aux), unlike the else branch — confirm intended.
                aux_vals_ref = aux_feat_ref[batch_inds_to_keep_orig]
                aux_vals_other = aux_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                aux_vals_ref = aux_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_other[batch_inds_to_keep_aux]

            optimizer.zero_grad()
            # correspondence loss (zero-valued but graph-connected when all
            # samples were filtered out, so backward() remains valid)
            if output_vals_ref.size(0) > 0:
                loss_corr_hr = corr_loss_fct(output_vals_ref,
                                             output_vals_other, pts_ref_orig,
                                             pts_other_orig, weights_orig)
            else:
                loss_corr_hr = 0 * output_vals_other.sum()
            if aux_vals_ref.size(0) > 0:
                loss_corr_aux = corr_loss_fct(
                    aux_vals_ref, aux_vals_other, pts_ref_aux, pts_other_aux,
                    weights_aux)  # use output from img1 as "reference"
            else:
                loss_corr_aux = 0 * aux_vals_other.sum()
            loss_corr = args['corr_loss_weight'] * \
                (loss_corr_hr + 0.4 * loss_corr_aux)
            loss_corr.backward()
            optimizer.step()
            train_corr_loss.update(loss_corr.item())

        #######################################################################
        # LOGGING ETC
        #######################################################################
        curr_iter += 1
        val_iter += 1
        writer.add_scalar('train_seg_loss_cs', train_seg_cs_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_extra', train_seg_extra_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_vis', train_seg_vis_loss.avg,
                          curr_iter)
        writer.add_scalar('train_corr_loss', train_corr_loss.avg, curr_iter)
        writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)
        if (i + 1) % args['print_freq'] == 0:
            str2write = '[iter %d / %d], [train corr loss %.5f] , [seg cs loss %.5f], [seg vis loss %.5f], [seg extra loss %.5f]. [lr %.10f]' % (
                curr_iter, len(corr_loader), train_corr_loss.avg,
                train_seg_cs_loss.avg, train_seg_vis_loss.avg,
                train_seg_extra_loss.avg, optimizer.param_groups[1]['lr'])
            print(str2write)
            f_handle.write(str2write + "\n")
        if val_iter >= args['val_interval']:
            val_iter = 0
            for validator in validators:
                validator.run(net,
                              optimizer,
                              args,
                              curr_iter,
                              save_folder,
                              f_handle,
                              writer=writer)

    # Post training
    f_handle.close()
    writer.close()
def train(train_loader, net, optim, curr_epoch):
    """
    Runs the training loop per epoch
    train_loader: Data loader for train
    net: thet network
    optimizer: optimizer
    curr_epoch: current epoch
    return:
    """
    net.train()
    train_main_loss = AverageMeter()
    start_time = None
    # Skip the first few iterations when measuring per-batch time, so model
    # warm-up doesn't skew the average.
    warmup_iter = 10
    for i, data in enumerate(train_loader):
        if i <= warmup_iter:
            start_time = time.time()
        # inputs = (bs,3,713,713)
        # gts    = (bs,713,713)
        images, gts, _img_name, scale_float = data
        batch_pixel_size = images.size(0) * images.size(2) * images.size(3)
        images, gts, scale_float = images.cuda(), gts.cuda(), scale_float.cuda()
        inputs = {'images': images, 'gts': gts}
        optim.zero_grad()
        # The network computes its own loss when given ground truth.
        main_loss = net(inputs)
        if args.apex:
            # Distributed: average the loss across workers for logging only.
            log_main_loss = main_loss.clone().detach_()
            torch.distributed.all_reduce(log_main_loss,
                                         torch.distributed.ReduceOp.SUM)
            log_main_loss = log_main_loss / args.world_size
        else:
            main_loss = main_loss.mean()
            log_main_loss = main_loss.clone().detach_()
        train_main_loss.update(log_main_loss.item(), batch_pixel_size)
        if args.fp16:
            # Mixed precision: scale the loss to avoid fp16 gradient underflow.
            with amp.scale_loss(main_loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            main_loss.backward()
        optim.step()
        if i >= warmup_iter:
            curr_time = time.time()
            batches = i - warmup_iter + 1
            batchtime = (curr_time - start_time) / batches
        else:
            batchtime = 0
        msg = ('[epoch {}], [iter {} / {}], [train main loss {:0.6f}],'
               ' [lr {:0.6f}] [batchtime {:0.3g}]')
        msg = msg.format(curr_epoch, i + 1, len(train_loader),
                         train_main_loss.avg, optim.param_groups[-1]['lr'],
                         batchtime)
        logx.msg(msg)
        metrics = {
            'loss': train_main_loss.avg,
            'lr': optim.param_groups[-1]['lr']
        }
        curr_iter = curr_epoch * len(train_loader) + i
        logx.metric('train', metrics, curr_iter)
        if i >= 10 and args.test_mode:
            # Quick sanity-check mode: bail out after a few iterations.
            del data, inputs, gts
            return
        del data
def test(test_loader, model, configs):
    """Evaluate the multi-task ball-detection model on a test set.

    Computes and prints running metrics per batch: event-classification
    accuracy, segmentation IoU, and MSE of the predicted ball position for the
    global and (optional) local stages. Optionally saves side-by-side
    target/prediction segmentation images.

    :param test_loader: yields (origin_imgs, resized_imgs, org_ball_pos_xy,
        global_ball_pos_xy, event_class, target_seg).
    :param model: multi-task network returning (pred_ball_global,
        pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _).
    :param configs: needs .device, .tasks, .input_size, .thresh_ball_pos_mask,
        .event_thresh, .seg_thresh, .save_test_output, .saved_dir, .print_freq.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    acc_event = AverageMeter('Acc_Event', ':6.4f')
    iou_seg = AverageMeter('IoU_Seg', ':6.4f')
    mse_global = AverageMeter('MSE_Global', ':6.4f')
    mse_local = AverageMeter('MSE_Local', ':6.4f')

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        for batch_idx, (origin_imgs, resized_imgs, org_ball_pos_xy,
                        global_ball_pos_xy, event_class,
                        target_seg) in enumerate(tqdm(test_loader)):
            data_time.update(time.time() - start_time)
            batch_size = resized_imgs.size(0)
            target_seg = target_seg.to(configs.device, non_blocking=True)
            resized_imgs = resized_imgs.to(configs.device,
                                           non_blocking=True).float()
            # compute output; the local stage additionally needs the full-size
            # images
            if 'local' in configs.tasks:
                origin_imgs = origin_imgs.to(configs.device,
                                             non_blocking=True).float()
                pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
                    origin_imgs, resized_imgs, org_ball_pos_xy,
                    global_ball_pos_xy, event_class, target_seg)
            else:
                pred_ball_global, pred_ball_local, pred_events, pred_seg, local_ball_pos_xy, total_loss, _ = model(
                    None, resized_imgs, org_ball_pos_xy, global_ball_pos_xy,
                    event_class, target_seg)
            # Transfer output to cpu
            pred_ball_global = pred_ball_global.cpu().numpy()
            global_ball_pos_xy = global_ball_pos_xy.numpy()
            if pred_ball_local is not None:
                pred_ball_local = pred_ball_local.cpu().numpy()
                local_ball_pos_xy = local_ball_pos_xy.cpu().numpy(
                )  # Ground truth of the local stage
            if pred_events is not None:
                pred_events = pred_events.cpu().numpy()
            if pred_seg is not None:
                pred_seg = pred_seg.cpu().numpy()
                target_seg = target_seg.cpu().numpy()
            org_ball_pos_xy = org_ball_pos_xy.numpy()

            for sample_idx in range(batch_size):
                w, h = configs.input_size
                # Get target
                sample_org_ball_pos_xy = org_ball_pos_xy[sample_idx]
                sample_global_ball_pos_xy = global_ball_pos_xy[
                    sample_idx]  # Target
                # Process the global stage: heatmap is [x-scores | y-scores],
                # take argmax of each half.
                sample_pred_ball_global = pred_ball_global[sample_idx]
                sample_pred_ball_global[
                    sample_pred_ball_global < configs.thresh_ball_pos_mask] = 0.
                sample_pred_ball_global_x = np.argmax(
                    sample_pred_ball_global[:w])
                sample_pred_ball_global_y = np.argmax(
                    sample_pred_ball_global[w:])
                # Calculate the MSE (only when the ball is visible in the GT)
                if (sample_global_ball_pos_xy[0] > 0) and (
                        sample_global_ball_pos_xy[1] > 0):
                    mse = (sample_pred_ball_global_x -
                           sample_global_ball_pos_xy[0])**2 + (
                               sample_pred_ball_global_y -
                               sample_global_ball_pos_xy[1])**2
                    mse_global.update(mse)
                print(
                    'Global stage: (x, y) - org: ({}, {}), gt = ({}, {}), prediction = ({}, {})'
                    .format(sample_org_ball_pos_xy[0],
                            sample_org_ball_pos_xy[1],
                            sample_global_ball_pos_xy[0],
                            sample_global_ball_pos_xy[1],
                            sample_pred_ball_global_x,
                            sample_pred_ball_global_y))

                # Process local ball stage
                if pred_ball_local is not None:
                    # Get target
                    sample_local_ball_pos_xy = local_ball_pos_xy[
                        sample_idx]  # Target
                    # Process the local stage
                    sample_pred_ball_local = pred_ball_local[sample_idx]
                    sample_pred_ball_local[
                        sample_pred_ball_local < configs.thresh_ball_pos_mask] = 0.
                    sample_pred_ball_local_x = np.argmax(
                        sample_pred_ball_local[:w])
                    sample_pred_ball_local_y = np.argmax(
                        sample_pred_ball_local[w:])
                    # Calculate the MSE
                    if (sample_local_ball_pos_xy[0] > 0) and (
                            sample_local_ball_pos_xy[1] > 0):
                        mse = (sample_pred_ball_local_x -
                               sample_local_ball_pos_xy[0])**2 + (
                                   sample_pred_ball_local_y -
                                   sample_local_ball_pos_xy[1])**2
                        mse_local.update(mse)
                    print(
                        'Local stage: (x, y) - gt = ({}, {}), prediction = ({}, {})'
                        .format(sample_local_ball_pos_xy[0],
                                sample_local_ball_pos_xy[1],
                                sample_pred_ball_local_x,
                                sample_pred_ball_local_y))

                # Process event stage (multi-label: [bounce, net])
                if pred_events is not None:
                    sample_target_event = event_class[sample_idx].item()
                    # Fix: np.int was removed in NumPy 1.24; use the builtin.
                    vec_sample_target_event = np.zeros((2, ), dtype=int)
                    if sample_target_event < 2:  # class 2 == 'empty' (no event)
                        vec_sample_target_event[sample_target_event] = 1
                    sample_pred_event = (pred_events[sample_idx] >
                                         configs.event_thresh).astype(int)
                    print('Event stage: gt = {}, prediction: {}'.format(
                        sample_target_event, pred_events[sample_idx]))
                    diff = sample_pred_event - vec_sample_target_event
                    # Check correct or not
                    if np.sum(diff) != 0:
                        # Incorrect
                        acc_event.update(0)
                    else:
                        # Correct
                        acc_event.update(1)

                # Process segmentation stage
                if pred_seg is not None:
                    sample_target_seg = target_seg[sample_idx].transpose(
                        1, 2, 0)
                    sample_pred_seg = pred_seg[sample_idx].transpose(1, 2, 0)
                    sample_target_seg = sample_target_seg.astype(int)
                    sample_pred_seg = (sample_pred_seg >
                                       configs.seg_thresh).astype(int)
                    # Calculate the IoU (Dice-style: 2*|A∩B| / (|A|+|B|))
                    iou = 2 * np.sum(sample_target_seg * sample_pred_seg) / (
                        np.sum(sample_target_seg) + np.sum(sample_pred_seg) +
                        1e-9)
                    iou_seg.update(iou)
                    if configs.save_test_output:
                        fig, axes = plt.subplots(nrows=batch_size,
                                                 ncols=2,
                                                 figsize=(10, 5))
                        plt.tight_layout()
                        # Fix: ravel() returns a new array; the original
                        # discarded it, so 2-D indexing broke for batch_size>1.
                        axes = axes.ravel()
                        axes[2 * sample_idx].imshow(sample_target_seg * 255)
                        axes[2 * sample_idx + 1].imshow(sample_pred_seg * 255)
                        # title
                        target_title = 'target seg'
                        pred_title = 'pred seg'
                        if pred_events is not None:
                            target_title += ', is bounce: {}, is net: {}'.format(
                                vec_sample_target_event[0],
                                vec_sample_target_event[1])
                            pred_title += ', is bounce: {}, is net: {}'.format(
                                sample_pred_event[0], sample_pred_event[1])
                        axes[2 * sample_idx].set_title(target_title)
                        axes[2 * sample_idx + 1].set_title(pred_title)
                        plt.savefig(
                            os.path.join(
                                configs.saved_dir,
                                'batch_idx_{}_sample_idx_{}.jpg'.format(
                                    batch_idx, sample_idx)))
                        # Fix: close the figure so repeated saves don't leak
                        # memory across the whole test set.
                        plt.close(fig)

            if ((batch_idx + 1) % configs.print_freq) == 0:
                print(
                    'batch_idx: {} - Average acc_event: {}, iou_seg: {}, mse_global: {}, mse_local: {}'
                    .format(batch_idx, acc_event.avg, iou_seg.avg,
                            mse_global.avg, mse_local.avg))
            batch_time.update(time.time() - start_time)
            start_time = time.time()

    print('Average acc_event: {}, iou_seg: {}, mse_global: {}, mse_local: {}'.
          format(acc_event.avg, iou_seg.avg, mse_global.avg, mse_local.avg))
    print('Done testing')
def train_one_epoch(train_dataloader, model, optimizer, lr_scheduler, epoch,
                    configs, logger, tb_writer):
    """Train the model for a single epoch with gradient accumulation.

    Gradients are accumulated over configs.subdivisions mini-batches before
    each optimizer step; losses are (optionally) all-reduced for distributed
    runs, and metrics are logged to tensorboard and the text logger.

    :param train_dataloader: training data loader.
    :param model: network returning (total_loss, outputs).
    :param optimizer: optimizer stepped every configs.subdivisions batches.
    :param lr_scheduler: stepped together with the optimizer.
    :param epoch: 1-based epoch index (used for the global step).
    :param configs: needs .num_epochs, .device, .distributed, .gpu_idx,
        .subdivisions, .world_size, .batch_size, .tensorboard_freq, .print_freq.
    :param logger: text logger (may be None on non-master ranks).
    :param tb_writer: tensorboard writer (may be None).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(len(train_dataloader),
                             [batch_time, data_time, losses],
                             prefix="Train - Epoch: [{}/{}]".format(
                                 epoch, configs.num_epochs))
    num_iters_per_epoch = len(train_dataloader)
    # switch to train mode
    model.train()
    start_time = time.time()
    for batch_idx, batch_data in enumerate(tqdm(train_dataloader)):
        data_time.update(time.time() - start_time)
        _, imgs, targets = batch_data
        global_step = num_iters_per_epoch * (epoch - 1) + batch_idx + 1
        batch_size = imgs.size(0)
        targets = targets.to(configs.device, non_blocking=True)
        imgs = imgs.to(configs.device, non_blocking=True)
        total_loss, outputs = model(imgs, targets)
        # For torch.nn.DataParallel case: loss comes back per-GPU, average it.
        if (not configs.distributed) and (configs.gpu_idx is None):
            total_loss = torch.mean(total_loss)
        # compute gradient and perform backpropagation; gradients accumulate
        # across subdivisions until the optimizer step below.
        total_loss.backward()
        if global_step % configs.subdivisions == 0:
            optimizer.step()
            # Adjust learning rate
            lr_scheduler.step()
            # zero the parameter gradients
            optimizer.zero_grad()
        if configs.distributed:
            # Average the loss across workers so every rank logs the same value.
            reduced_loss = reduce_tensor(total_loss.data, configs.world_size)
        else:
            reduced_loss = total_loss.data
        losses.update(to_python_float(reduced_loss), batch_size)
        # measure elapsed time
        # torch.cuda.synchronize()
        batch_time.update(time.time() - start_time)
        if tb_writer is not None:
            if (global_step % configs.tensorboard_freq) == 0:
                tensorboard_log = get_tensorboard_log(model)
                # Report the effective lr for the full (accumulated) batch.
                tensorboard_log['lr'] = lr_scheduler.get_lr()[
                    0] * configs.batch_size * configs.subdivisions
                tensorboard_log['avg_loss'] = losses.avg
                tb_writer.add_scalars('Train', tensorboard_log, global_step)
        # Log message
        if logger is not None:
            if (global_step % configs.print_freq) == 0:
                logger.info(progress.get_message(batch_idx))
        start_time = time.time()
def train(train_loader, net, optim, curr_epoch, writer):
    """
    Runs the training loop per epoch
    train_loader: Data loader for train
    net: thet network
    optimizer: optimizer
    curr_epoch: current epoch
    writer: tensorboard writer
    return:
    """
    net.train()
    train_main_loss = AverageMeter()
    curr_iter = curr_epoch * len(train_loader)
    for i, data in enumerate(train_loader):
        edges = None
        # PFNet joint edge loss provides extra body/edge targets in the batch.
        if args.joint_edge_loss_pfnet:
            inputs, gts, bodys, edges, _img_name = data
        else:
            inputs, gts, _img_name = data
        batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)
        inputs, gts = inputs.cuda(), gts.cuda()
        optim.zero_grad()
        if args.joint_edge_loss_pfnet:
            # The network returns a dict of named losses; sum them for backprop.
            main_loss_dic = net(inputs, gts=(gts, edges))
            main_loss = 0.0
            for v in main_loss_dic.values():
                main_loss = main_loss + v
        else:
            main_loss = net(inputs, gts=gts)
        if args.apex:
            # Distributed: average the loss across workers for logging only.
            log_main_loss = main_loss.clone().detach_()
            torch.distributed.all_reduce(log_main_loss,
                                         torch.distributed.ReduceOp.SUM)
            log_main_loss = log_main_loss / args.world_size
        else:
            main_loss = main_loss.mean()
            log_main_loss = main_loss.clone().detach_()
        train_main_loss.update(log_main_loss.item(), batch_pixel_size)
        if args.fp16:
            # Mixed precision: scale the loss to avoid fp16 gradient underflow.
            with amp.scale_loss(main_loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            # Fail fast on NaN/inf instead of silently corrupting the weights.
            if not torch.isfinite(main_loss).all():
                raise FloatingPointError(
                    "Loss became infinite or NaN at iteration={}!\nloss_dict = {}"
                    .format(curr_iter, main_loss))
            main_loss.backward()
        optim.step()
        curr_iter += 1
        # Only the master rank prints/logs.
        if args.local_rank == 0 and i % args.print_freq == 0:
            if args.joint_edge_loss_pfnet:
                msg = f'[epoch {curr_epoch}], [iter {i + 1} / {len(train_loader)}], '
                msg += '[seg_main_loss:{:0.5f}]'.format(
                    main_loss_dic['seg_loss'])
                for j in range(3):
                    temp_msg = '[layer{}:, [edge loss {:0.5f}] '.format(
                        (3 - j), main_loss_dic[f'edge_loss_layer{3-j}'])
                    msg += temp_msg
                msg += ', [lr {:0.5f}]'.format(optim.param_groups[-1]['lr'])
            else:
                msg = '[epoch {}], [iter {} / {}], [train main loss {:0.6f}], [lr {:0.6f}]'.format(
                    curr_epoch, i + 1, len(train_loader),
                    train_main_loss.avg, optim.param_groups[-1]['lr'])
            logging.info(msg)
            # Log tensorboard metrics for each iteration of the training phase
            writer.add_scalar('training/loss', (train_main_loss.val),
                              curr_iter)
            writer.add_scalar('training/lr', optim.param_groups[-1]['lr'],
                              curr_iter)
        if i > 5 and args.test_mode:
            # Quick sanity-check mode: bail out after a few iterations.
            return